#!/usr/bin/python3
#
# Univention Monitoring
#
# SPDX-FileCopyrightText: 2007-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only


import os.path
import re
import shlex
import string
import subprocess
from enum import IntEnum
from os import listdir

from univention.config_registry import ucr
from univention.monitoring import Alert


class RC(IntEnum):
    """Status codes written to the ``univention_join_status`` metric.

    ``JoinStatusCheck.write_metrics`` writes exactly one of these values per
    run: ``OK`` if every check passed, ``OS_ERROR`` if a check raised
    :class:`OSError`, otherwise the code paired with the first failing check.
    """

    OK = 0  # all checks passed
    OS_ERROR = -1  # a check raised OSError
    SECRET_NOT_FOUND = 1  # check_machine_secret_exists failed
    LDAP_AUTH = 2  # check_ldapsearch failed
    LDAP_AUTH_TLS = 3  # check_ldapsearch_tls failed
    NO_JOIN_FILE = 4  # check_joined_file_exists failed
    LDAP_HOST_DN = 5  # check_ldapsearch_without_hostdn failed
    JOIN_SCRIPT_MISSING = 6  # check_missing_joinscripts failed


class JoinStatusError(Exception):
    """Raised by a join status check to signal that the check failed.

    The exception message describes the failure; it is only used for logging.
    """


class JoinStatusCheck(Alert):
    """Check the domain join state of this system and report it as one metric.

    The individual checks run in a fixed order. The first failing check
    determines the value (see :class:`RC`) written to the
    ``univention_join_status`` metric; later checks are then skipped.
    """

    def write_metrics(self):
        """Run all checks in order and write ``univention_join_status``.

        Writes ``RC.OK`` if every check passes, ``RC.OS_ERROR`` if a check
        raises :class:`OSError`, or the status code paired with the first
        check that raises :class:`JoinStatusError`. The reason for a failure
        is logged at debug level.
        """
        for func, status in (
            (self.check_machine_secret_exists, RC.SECRET_NOT_FOUND),
            (self.check_ldapsearch, RC.LDAP_AUTH),
            (self.check_ldapsearch_tls, RC.LDAP_AUTH_TLS),
            (self.check_ldapsearch_without_hostdn, RC.LDAP_HOST_DN),
            (self.check_joined_file_exists, RC.NO_JOIN_FILE),
            (self.check_missing_joinscripts, RC.JOIN_SCRIPT_MISSING),
        ):
            try:
                func()
            except OSError as exc:
                self.write_metric('univention_join_status', int(RC.OS_ERROR))
                # Log the cause as well, otherwise OS_ERROR cannot be diagnosed.
                self.log.debug(str(exc))
                break
            except JoinStatusError as exc:
                self.write_metric('univention_join_status', int(status))
                self.log.debug(str(exc))
                break
        else:
            # No check broke out of the loop: everything passed.
            self.write_metric('univention_join_status', int(RC.OK))

    def _ldap_server_uri(self):
        """Return the ``ldap://`` URI built from ``ldap/server/name`` and ``ldap/server/port``."""
        return f"ldap://{ucr['ldap/server/name']}:{ucr['ldap/server/port']}"

    def _run_ldapsearch(self, options, failure_prefix):
        """Run a base-scope ``ldapsearch`` on ``ldap/base``, bound as ``ldap/hostdn``.

        The bind password is read by ``ldapsearch`` from ``/etc/machine.secret``.

        :param options: extra command line arguments inserted right after ``-x``.
        :param failure_prefix: text placed before the quoted command line in the error message.
        :raises JoinStatusError: if ``ldapsearch`` exits with a non-zero status.
        :raises OSError: if the command cannot be executed (propagated from :mod:`subprocess`).
        """
        cmd = [
            'ldapsearch',
            '-x',
            *options,
            '-D', ucr['ldap/hostdn'],
            '-y', '/etc/machine.secret',
            '-b', ucr['ldap/base'],
            '-s', 'base',
        ]
        try:
            subprocess.check_call(cmd, stdout=subprocess.DEVNULL)
        except subprocess.CalledProcessError as exc:
            # Include a copy-and-paste ready command line to ease debugging.
            quoted = ' '.join(shlex.quote(arg) for arg in cmd)
            raise JoinStatusError(f'{failure_prefix}: {quoted}') from exc

    def check_machine_secret_exists(self):
        """Verify that ``/etc/machine.secret`` can be opened and read.

        :raises JoinStatusError: if opening or reading the file raises :class:`OSError`.
        """
        try:
            with open('/etc/machine.secret') as fd:
                fd.read()
        except OSError as exc:
            raise JoinStatusError('/etc/machine.secret not found - system not joined yet?') from exc

    def check_ldapsearch(self):
        """Verify an authenticated search against the configured LDAP server URI.

        :raises JoinStatusError: if ``ldapsearch`` exits with a non-zero status.
        """
        self._run_ldapsearch(['-H', self._ldap_server_uri()], 'auth failed')

    def check_ldapsearch_tls(self):
        """Verify the same search as :meth:`check_ldapsearch` with ``-ZZ`` (StartTLS) added.

        :raises JoinStatusError: if ``ldapsearch`` exits with a non-zero status.
        """
        self._run_ldapsearch(['-ZZ', '-H', self._ldap_server_uri()], 'auth or TLS failed')

    def check_joined_file_exists(self):
        """Verify that both join marker files exist.

        :raises JoinStatusError: if ``/usr/share/univention-join/.joined`` or
            ``/var/univention-join/joined`` is missing.
        """
        if not os.path.exists('/usr/share/univention-join/.joined') or not os.path.exists('/var/univention-join/joined'):
            raise JoinStatusError("Cannot find /usr/share/univention-join/.joined or /var/univention-join/joined")

    def check_ldapsearch_without_hostdn(self):
        """Verify a StartTLS search without passing an explicit server URI.

        Despite the method name the bind still uses ``ldap/hostdn``; what is
        omitted is ``-H``, so the server is chosen by the LDAP client's own
        default configuration.

        :raises JoinStatusError: if ``ldapsearch`` exits with a non-zero status.
        """
        self._run_ldapsearch(['-ZZ'], 'auth failed')

    def check_missing_joinscripts(self):
        """Verify that every installed join script is recorded as successful.

        For each ``*.inst`` file in ``/usr/lib/univention-install/`` the
        service name is the file name without extension and without leading
        digits, and the version is the first number found on the first
        ``VERSION=`` line (empty if there is none). The script counts as
        pending unless ``"<service> v<version> successful"`` occurs in
        ``/usr/lib/univention-install/.index.txt``.

        :raises JoinStatusError: if at least one join script is pending.
        :raises OSError: if the directory, the index or a script cannot be read.
        """
        install_dir = '/usr/lib/univention-install/'
        install_files = [
            os.path.join(install_dir, name)
            for name in listdir(install_dir)
            if os.path.splitext(name)[1] == '.inst'
        ]
        with open('/usr/lib/univention-install/.index.txt') as index_fd:
            index_file_content = index_fd.read()

        # Compiled once instead of handing the pattern to re.match() per line.
        version_re = re.compile(r'^VERSION=[^0-9]*([0-9]+).*?$')

        join_error_count = 0
        for install_file in install_files:
            service = os.path.splitext(os.path.basename(install_file))[0].lstrip(string.digits)
            version = ""
            with open(install_file) as script_fd:
                # Iterate lazily; only the first VERSION= line is relevant.
                for line in script_fd:
                    match = version_re.match(line)
                    if match:
                        version = match.group(1)
                        break
            if f"{service} v{version} successful" not in index_file_content:
                join_error_count += 1

        if join_error_count > 0:
            raise JoinStatusError(f"{join_error_count} join scripts have to be called")


if __name__ == '__main__':
    # Executed as a script: hand over to the ``main`` entry point, which is not
    # defined in this file (presumably inherited from Alert - confirm there).
    JoinStatusCheck.main()
