#!/usr/bin/python3
# pylint: disable-msg=E0611,W0621,C0103,E1101,W0601
#
# SPDX-FileCopyrightText: 2004-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

"""Create and modify dns objects easily."""


import sys
import time
from argparse import REMAINDER, ArgumentParser, FileType, Namespace, RawTextHelpFormatter
from logging import getLogger

import univention.logging


# Configure logging to the directory-manager command line log file.
# NOTE(review): this runs before the univention.admin imports below (which
# therefore carry "noqa: E402") - presumably so that those modules already
# log to this file; confirm before reordering.
univention.logging.basicConfig(filename='/var/log/univention/directory-manager-cmd.log', univention_debug_level=1)

log = getLogger('ADMIN')

import univention.admin.config  # noqa: E402
import univention.admin.filter  # noqa: E402
import univention.admin.modules  # noqa: E402
import univention.admin.objects  # noqa: E402
import univention.admin.uexceptions  # noqa: E402
from univention.admin import uldap  # noqa: E402
from univention.admin.handlers.dns import alias  # noqa: E402
from univention.admin.handlers.dns import forward_zone  # noqa: E402
from univention.admin.handlers.dns import host_record  # noqa: E402
from univention.admin.handlers.dns import ns_record  # noqa: E402
from univention.admin.handlers.dns import ptr_record  # noqa: E402
from univention.admin.handlers.dns import reverse_zone  # noqa: E402
from univention.admin.handlers.dns import srv_record  # noqa: E402
from univention.admin.handlers.dns import txt_record  # noqa: E402
from univention.config_registry import ucr  # noqa: E402


def parse() -> Namespace:
    """
    Parse the command line.

    The created parser is also stored in the module-level ``parser`` variable.

    All options must be given *before* the zone name: everything following
    the record type is collected verbatim into ``options.args``
    (``nargs=REMAINDER``), including strings that look like options.

    If ``--bindpwdfile`` is given, the file is read, closed and its content
    (stripped of surrounding whitespace) is stored as ``options.bindpwd``.
    Giving ``--binddn`` without a non-empty password is reported through
    ``parser.error()``.

    :returns: the parsed command line.
    """
    # "\t" escapes instead of literal tab characters keep the layout visible
    # in editors; the PTR example puts --reverse first because a trailing
    # option would end up in the REMAINDER arguments.
    description = '''
%(prog)s my.dns.zone add\tsrv   ldap tcp 0 100 7389 master.fqdn.
%(prog)s my.dns.zone remove\tsrv   ldap tcp 0 100 7389 master.fqdn.
%(prog)s my.dns.zone add\ttxt   entry-name  "Some text"
%(prog)s --reverse 192.168.122 add\tptr   42  host.fqdn.
%(prog)s my.dns.zone add\tcname univention-repository apt.knut.univention.de.
%(prog)s my.dns.zone add\ta     host  1.2.3.4  5.6.7.8
%(prog)s my.dns.zone add\tzone  root@fqdn 1 28800 7200 25200 10800 ns.fqdn.
'''
    global parser
    parser = ArgumentParser(formatter_class=RawTextHelpFormatter, description=description)
    parser.add_argument(
        '--ignore-missing-zone',
        action='store_true',
        help='Skip if zone does not exist')
    parser.add_argument(
        '--ignore-exists',
        action='store_true',
        help='Skip if entry already exists')
    parser.add_argument(
        '--quiet',
        action='store_true',
        help='Turn off verbose messages')  # not implemented
    parser.add_argument(
        '--reverse',
        action='store_true',
        help='Modify reverse zone instead of forward zone')
    parser.add_argument(
        '--overwrite',
        action='store_true',
        help='Overwrite existing record')
    parser.add_argument(
        '--stoptls',
        action='store_true',
        help='Disable TLS')

    group_ldap = parser.add_argument_group("LDAP BIND")
    group_ldap.add_argument(
        '--binddn',
        help='LDAP bind DN')
    # Password and password file are alternatives; argparse rejects both at once.
    group_pwd = group_ldap.add_mutually_exclusive_group()
    group_pwd.add_argument(
        '--bindpwd',
        help='LDAP bind password')
    group_pwd.add_argument(
        '--bindpwdfile',
        type=FileType("r"),
        help='LDAP bind password file')

    parser.add_argument(
        '--timeout', type=int, default=120,
        help='LDAP connection timeout')

    parser.add_argument(
        'zone',
        help='name of the DNS zone',
    )
    parser.add_argument(
        'command',
        choices=['add', 'remove'],
        help='command',
    )
    parser.add_argument(
        'typ',
        choices=['srv', 'txt', 'ns', 'ptr', 'cname', 'a', 'zone'],
        help='DNS record type',
    )
    # Record specific values; their number depends on the record type.
    parser.add_argument("args", nargs=REMAINDER)

    options = parser.parse_args()
    if options.bindpwdfile:
        with options.bindpwdfile as fd:
            options.bindpwd = fd.read().strip()
    if options.binddn and not options.bindpwd:
        parser.error('authentication error: missing any of --bindpwdfile or --bindpwd')

    return options


def bind() -> tuple[uldap.access, uldap.position]:
    """
    Bind to the LDAP service, retrying on LDAP errors.

    If both ``--binddn`` and a bind password were given, a connection to the
    host named by the UCR variable ``ldap/master`` is opened with these
    credentials; otherwise :py:func:`uldap.getAdminConnection` is used.

    On an LDAP error the attempt is repeated every 10 seconds until
    ``options.timeout`` seconds have passed since the first attempt; after
    that the process exits with the error message. An authentication failure
    exits immediately without retry.

    :returns: a 2-tuple (LDAP connection, LDAP position).
    """
    start_tls = 0 if options.stoptls else 2
    # Monotonic clock: the retry deadline must not move when the wall clock
    # is adjusted while waiting for the server.
    deadline = time.monotonic() + options.timeout
    while True:
        try:
            if options.binddn and options.bindpwd:
                lo = uldap.access(
                    host=ucr['ldap/master'],
                    port=ucr.get_int('ldap/master/port', 7389),
                    base=ucr['ldap/base'],
                    binddn=options.binddn,
                    bindpw=options.bindpwd,
                    start_tls=start_tls)
                position = uldap.position(lo.base)
            else:
                lo, position = uldap.getAdminConnection(start_tls)
            return lo, position
        except univention.admin.uexceptions.authFail as ex:
            # Wrong credentials will not get better by retrying.
            log.warning('authentication error: %s', ex)
            sys.exit('authentication error: %s' % (ex,))
        except (univention.admin.uexceptions.ldapError, uldap.ldap.LDAPError) as ex:
            msg = '%s: timeout while trying to contact LDAP server %s: %s' % \
                (sys.argv[0], ucr['ldap/master'], ex)
            log.warning('%s', msg)
            if time.monotonic() >= deadline:
                sys.exit(msg)
            print(msg, file=sys.stderr)
            time.sleep(10)


def lookup_zone(zone_name: str) -> forward_zone.object | reverse_zone.object:
    """
    Find the zone called `zone_name` and return its UDM object.

    With ``--reverse`` a reverse zone is searched by its ``subnet``,
    otherwise a forward zone by its ``zone`` name.

    If nothing is found the process exits: with status 0 when
    ``--ignore-missing-zone`` was given, else with status 1 after printing
    an error to stderr.

    :param zone_name: name (or subnet) of the zone.
    :returns: the first matching zone object.
    """
    if options.reverse:
        module, ldap_filter = reverse_zone, '(subnet=%s)' % (zone_name,)
    else:
        module, ldap_filter = forward_zone, '(zone=%s)' % (zone_name,)

    found = module.lookup(co, lo, ldap_filter, scope='domain', base=position.getDomain(), unique=True)
    if found:
        return found[0]

    # A missing zone is only tolerated on explicit request.
    if options.ignore_missing_zone:
        sys.exit(0)
    print('E: Zone %s does not exist.' % (zone_name,), file=sys.stderr)
    sys.exit(1)


def add_srv_record(service: str, protocol: str, priority: str, weight: str, port: str, host: str) -> None:
    """
    Add a location to the SRV record `service`/`protocol` of the current zone.

    The record is created when the lookup finds none. Nothing is written
    when the location is already part of the record.

    :param service: service part of the record name.
    :param protocol: protocol part of the record name.
    :param priority: first element of the location entry.
    :param weight: second element of the location entry.
    :param port: third element of the location entry.
    :param host: fourth element of the location entry.
    """
    srv_name = [service, protocol]
    new_location = [priority, weight, port, host]
    name_filter = univention.admin.filter.expression('name', srv_name)

    matches = srv_record.lookup(co, lo, name_filter, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if not matches:
        entry = srv_record.object(co, lo, position, superordinate=zone)
        entry['name'] = srv_name
    else:
        entry = matches[0]

    locations = entry['location']
    if new_location in locations:
        return
    locations.append(new_location)
    entry['location'] = locations

    # Existing objects are modified, fresh ones have to be created.
    save = entry.modify if matches else entry.create
    save()


def remove_srv_record(service: str, protocol: str, priority: str, weight: str, port: str, host: str) -> None:
    """
    Remove a location from the SRV record `service`/`protocol` of the current zone.

    The whole record is removed when no location is left afterwards. A
    missing record or location is only reported on stderr.

    :param service: service part of the record name.
    :param protocol: protocol part of the record name.
    :param priority: first element of the location entry.
    :param weight: second element of the location entry.
    :param port: third element of the location entry.
    :param host: fourth element of the location entry.
    """
    srv_name = [service, protocol]
    target = [priority, weight, port, host]
    name_filter = univention.admin.filter.expression('name', srv_name)

    matches = srv_record.lookup(co, lo, name_filter, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if not matches:
        print("No record found", file=sys.stderr)
        return
    entry = matches[0]

    if target not in entry['location']:
        print("Does not exist", file=sys.stderr)
        return

    entry['location'].remove(target)
    if not entry['location']:
        # Last location gone: drop the record itself.
        entry.remove()
    else:
        entry.modify()


def add_txt_record(name: str, text: str) -> None:
    """
    Add `text` to the TXT record `name` of the current zone.

    A new record (with ``zonettl`` set to ``['80600']``) is created when the
    lookup finds none. Nothing is written when the text is already present.

    :param name: name of the TXT record.
    :param text: text value to add.
    """
    name_filter = univention.admin.filter.expression('name', name)
    matches = txt_record.lookup(co, lo, name_filter, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if not matches:
        entry = txt_record.object(co, lo, position, superordinate=zone)
        entry['name'] = name
        entry['zonettl'] = ['80600']
    else:
        entry = matches[0]

    texts = entry['txt']
    if text in texts:
        return
    texts.append(text)
    entry['txt'] = texts

    # Existing objects are modified, fresh ones have to be created.
    save = entry.modify if matches else entry.create
    save()


def add_a_record(name: str, *adresses: str) -> None:
    """
    Add addresses to the host record `name` of the current zone.

    A new record (with ``zonettl`` set to ``['80600']``) is created when the
    lookup finds none. Addresses already stored in the record are skipped;
    the record is saved in any case.

    :param name: name of the host record.
    :param adresses: addresses to add.
    """
    name_filter = univention.admin.filter.expression('name', name)
    matches = host_record.lookup(co, lo, name_filter, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if not matches:
        entry = host_record.object(co, lo, position, superordinate=zone)
        entry['name'] = name
        entry['zonettl'] = ['80600']
    else:
        entry = matches[0]

    known = entry['a']
    for candidate in adresses:
        if candidate in known:
            continue
        known.append(candidate)
    entry['a'] = known

    # Existing objects are modified, fresh ones have to be created.
    save = entry.modify if matches else entry.create
    save()


def add_cname_record(name: str, cname: str) -> None:
    """
    Add the alias `name` pointing to `cname` to the current zone.

    A new record (with ``zonettl`` set to ``['80600']``) is created when the
    lookup finds none. If an alias of that name already exists:

    * pointing to `cname`: nothing is done.
    * pointing elsewhere and ``--overwrite`` is given: the target is replaced.
    * pointing elsewhere and only ``--ignore-exists`` is given: the existing
      alias is kept untouched ("Skip if entry already exists").
    * otherwise an error is printed and the process exits with status 1.

    :param name: name of the alias.
    :param cname: canonical name the alias points to.
    """
    filt = univention.admin.filter.expression('name', name)
    records = alias.lookup(co, lo, filt, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if records:
        record = records[0]
    else:
        record = alias.object(co, lo, position, superordinate=zone)
        record['name'] = name
        record['zonettl'] = ['80600']

    current = record['cname']
    if current == cname:
        return
    if current and not options.overwrite:
        if options.ignore_exists:
            # Skip instead of silently redirecting an existing alias;
            # replacing it is what --overwrite is for.
            return
        print('E: Record exists and points to different address', file=sys.stderr)
        sys.exit(1)
    record['cname'] = cname

    if records:
        record.modify()
    else:
        record.create()


def add_ns_record(name: str, value: str) -> None:
    """
    Add the name server `value` to the NS record `name` of the current zone.

    A new record (with ``zonettl`` set to ``['80600']``) is created when the
    lookup finds none. Nothing is written when the name server is already
    listed.

    :param name: value of the ``zone`` property identifying the NS record.
    :param value: name server to add.
    """
    zone_filter = univention.admin.filter.expression('zone', name)
    matches = ns_record.lookup(co, lo, zone_filter, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if not matches:
        entry = ns_record.object(co, lo, position, superordinate=zone)
        entry['zone'] = name
        entry['zonettl'] = ['80600']
    else:
        entry = matches[0]

    servers = entry['nameserver']
    if value in servers:
        return
    servers.append(value)
    entry['nameserver'] = servers

    # Existing objects are modified, fresh ones have to be created.
    save = entry.modify if matches else entry.create
    save()


def remove_ns_record(name: str, value: str) -> None:
    """
    Remove the name server `value` from the NS record `name` of the current zone.

    The whole record is removed when no name server is left afterwards. A
    missing record or name server is only reported on stderr.

    :param name: value of the ``zone`` property identifying the NS record.
    :param value: name server to remove.
    """
    zone_filter = univention.admin.filter.expression('zone', name)
    matches = ns_record.lookup(co, lo, zone_filter, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if not matches:
        print("No record found", file=sys.stderr)
        return
    entry = matches[0]

    if value not in entry['nameserver']:
        print("Does not exist", file=sys.stderr)
        return

    entry['nameserver'].remove(value)
    if not entry['nameserver']:
        # Last name server gone: drop the record itself.
        entry.remove()
    else:
        entry.modify()


def add_ptr_record(address: str, ptr: str) -> None:
    """
    Add the pointer record `address` pointing to `ptr` to the current zone.

    The record is created when the lookup finds none. If a record for that
    address already exists:

    * and already points to `ptr`: nothing is done, other pointers stored in
      the record are left alone.
    * pointing elsewhere and ``--overwrite`` is given: the stored value is
      replaced by `ptr`.
    * otherwise an error is printed and the process exits with status 1.

    :param address: value of the ``address`` property identifying the record.
    :param ptr: host name the address points to.
    """
    filt = univention.admin.filter.expression('address', address)
    records = ptr_record.lookup(co, lo, filt, scope='domain', base=position.getDomain(), superordinate=zone, unique=True)
    if records:
        record = records[0]
    else:
        record = ptr_record.object(co, lo, position, superordinate=zone)
        record['address'] = address

    # 'ptr_record' may hold a single string or a list of strings depending on
    # the UDM property definition (not visible here), so normalise to a list:
    # this avoids substring matches on a string and avoids collapsing a list
    # to one value when `ptr` is already part of it.
    current = record['ptr_record']
    existing = [current] if isinstance(current, str) else list(current or [])

    if ptr in existing:
        return
    if existing and not options.overwrite:
        print('E: Record exists and points to different address', file=sys.stderr)
        sys.exit(1)
    record['ptr_record'] = ptr

    if records:
        record.modify()
    else:
        record.create()


def add_zone(contact: str, serial: str, refresh: str, retry: str, expire: str, ttl: str, *nameserver: str) -> None:
    """
    Create the zone named by the module-level ``zone_name``.

    With ``--reverse`` a reverse zone is created and the name is stored as
    its ``subnet``, otherwise a forward zone with the name as ``zone``.

    :param contact: value for the ``contact`` property.
    :param serial: value for the ``serial`` property.
    :param refresh: value for the ``refresh`` property.
    :param retry: value for the ``retry`` property.
    :param expire: value for the ``expire`` property.
    :param ttl: value for the ``ttl`` property.
    :param nameserver: name servers of the zone.
    """
    if options.reverse:
        new_zone = reverse_zone.object(co, lo, position)
        name_property = 'subnet'
    else:
        new_zone = forward_zone.object(co, lo, position)
        name_property = 'zone'
    new_zone[name_property] = zone_name

    # Timing values are stored wrapped in a list, contact and serial as given.
    properties = (
        ('contact', contact),
        ('serial', serial),
        ('refresh', [refresh]),
        ('retry', [retry]),
        ('expire', [expire]),
        ('ttl', [ttl]),
        ('nameserver', list(nameserver)),
    )
    for key, value in properties:
        new_zone[key] = value
    new_zone.create()


# Shared state used by the helper functions above.
options = Namespace()  # replaced by the parsed command line in main()
parser = None  # set by parse()
co = None  # handed to every lookup()/object() call; never assigned another value in this file
lo: uldap.access | None = None  # LDAP connection, set by main() from bind()
position: uldap.position | None = None  # LDAP position, set by main() from bind()
zone_name = ""  # zone name from the command line
zone: forward_zone.object | reverse_zone.object | None = None  # set by main() when an existing zone is looked up


def main() -> None:
    """
    Run the DNS edit described by the command line.

    Parses the command line, binds to LDAP, selects the zone (or, when a
    zone is to be added, the default DNS container) and dispatches to the
    add/remove function of the requested record type.

    An ``objectExists`` error is swallowed when ``--ignore-exists`` is given;
    otherwise it and the handled value/type errors are reported on stderr
    and re-raised.
    """
    global options, zone_name, lo, position, zone

    options = parse()
    zone_name = options.zone
    record_type = options.typ
    record_args = options.args

    lo, position = bind()

    univention.admin.modules.update()
    creating_zone = record_type == 'zone' and options.command != 'remove'
    if creating_zone:
        # The zone does not exist yet: work below the default DNS container.
        position.setDn(univention.admin.config.getDefaultContainer(lo, 'dns/'))
    else:
        zone = lookup_zone(zone_name)
        position.setDn(zone.dn)

    adders = {
        'srv': add_srv_record,
        'txt': add_txt_record,
        'ns': add_ns_record,
        'ptr': add_ptr_record,
        'cname': add_cname_record,
        'a': add_a_record,
        'zone': add_zone,
    }
    removers = {
        'srv': remove_srv_record,
        'ns': remove_ns_record,
    }

    try:
        if options.command == 'add':
            print('Adding %s record "%s" to zone %s...' % (record_type.upper(), ' '.join(record_args), zone_name))
            handler = adders.get(record_type)
            # PTR records are only handled together with --reverse.
            if handler is None or (record_type == 'ptr' and not options.reverse):
                sys.exit('Unknown type "%s"' % (record_type,))
            handler(*record_args)
            print('done')
        elif options.command == 'remove':
            handler = removers.get(record_type)
            if handler is None:
                sys.exit('Unknown type "%s"' % (record_type,))
            handler(*record_args)
    except univention.admin.uexceptions.objectExists as ex:
        if options.ignore_exists:
            return
        print('E: Object "%s" exists' % (ex.dn,), file=sys.stderr)
        raise
    except (ValueError, TypeError, univention.admin.uexceptions.valueInvalidSyntax, univention.admin.uexceptions.valueRequired) as ex:
        print('E: failed %s' % (ex,), file=sys.stderr)
        raise


# Run the tool only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()
