#!/usr/bin/python3
# SPDX-FileCopyrightText: 2021-2025 Univention GmbH
# SPDX-License-Identifier: AGPL-3.0-only

from __future__ import annotations

import os
from contextlib import contextmanager
from pwd import getpwnam
from typing import TYPE_CHECKING, Any

import lmdb

from univention.ldap_cache.cache.backend import Caches, LdapCache, Shard


if TYPE_CHECKING:
    from collections.abc import Iterator


class LmdbCaches(Caches):
    """Collection of caches stored as named sub-databases of one LMDB environment."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # One environment holds every sub cache (up to 128 named databases);
        # map_size is the upper bound the memory map / data file may grow to.
        self.env = lmdb.open(self._directory, map_size=2 ** 32 - 1, max_dbs=128)
        self._fix_permissions(self._directory)

    def _fix_permissions(self, db_directory: str) -> None:
        """Give the ``listener`` user ownership of the LMDB files and set mode 0640.

        :param db_directory: directory containing ``data.mdb`` and ``lock.mdb``.
        :raises KeyError: if there is no ``listener`` user.
        :raises OSError: if a file is missing or cannot be modified.
        """
        listener_uid = getpwnam('listener').pw_uid
        paths = [os.path.join(db_directory, filename) for filename in ('data.mdb', 'lock.mdb')]
        # Ownership of both files is changed before any mode change.
        for path in paths:
            os.chown(path, listener_uid, -1)  # -1 keeps the current group
        for path in paths:
            os.chmod(path, 0o640)

    def _add_sub_cache(self, name: str, single_value: bool, reverse: bool) -> LmdbCache:
        """Open (creating if needed) the sub-database *name* and register a cache for it.

        Multi-valued caches (``single_value`` false) are opened with ``dupsort`` so a
        key may hold several values.

        :returns: the new :class:`LmdbCache`, also stored in ``self._caches[name]``.
        """
        # py-lmdb only accepts bytestring database names (a str raises TypeError),
        # so encode str names; bytes names are passed through unchanged.
        db_name = name.encode('utf-8') if isinstance(name, str) else name
        sub_db = self.env.open_db(db_name, dupsort=not single_value)
        cache = LmdbCache(name, single_value, reverse)
        cache.env = self.env
        cache.sub_db = sub_db
        self._caches[name] = cache
        return cache


class LmdbCache(LdapCache):
    """One cache stored in a named LMDB sub-database.

    ``env`` (the shared environment) and ``sub_db`` (this cache's database handle)
    are assigned from outside after construction, e.g. by
    ``LmdbCaches._add_sub_cache``. A multi-valued cache's sub-database is opened
    with ``dupsort``, so one key may map to several values.
    """

    @contextmanager
    def writing(self, writer: Any | None = None) -> Iterator[Any]:
        """Yield a write transaction on the sub-database.

        If *writer* is given it is reused as-is and the caller stays responsible
        for committing it. Otherwise a new write transaction is started; it commits
        when the block exits normally and aborts on an exception.
        """
        if writer is not None:
            yield writer
        else:
            with self.env.begin(self.sub_db, write=True) as writer:
                yield writer

    def save(self, key: str, values: list[str]) -> None:
        """Replace everything stored under *key* with *values* in one transaction.

        Existing values are deleted first, so an empty *values* removes the key.
        For a single-valued cache each ``put`` overwrites, leaving the last value.
        """
        with self.writing() as writer:
            self.delete(key, writer)
            for value in values:
                writer.put(key, value)

    def clear(self) -> None:
        """Remove all entries from the sub-database while keeping the database itself."""
        with self.env.begin(write=True) as writer:
            writer.drop(self.sub_db, delete=False)

    def cleanup(self) -> None:
        """Nothing to clean up for the LMDB backend."""
        pass

    def delete(self, key: str, writer: Any = None) -> None:
        """Delete *key* (all of its values in a multi-valued cache).

        :param writer: optional already-open write transaction to reuse.
        """
        with self.writing(writer) as writer:
            writer.delete(key)

    @contextmanager
    def reading(self) -> Iterator[Any]:
        """Yield a cursor on the sub-database inside a fresh read transaction."""
        with self.env.begin(self.sub_db) as txn, txn.cursor() as cursor:
            yield cursor

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Iterate over all ``(key, value)`` records, including duplicate values."""
        with self.reading() as reader:
            yield from reader

    def get(self, key: str) -> Any:
        """Return what is stored under *key*.

        :returns: the value (or ``None`` if absent) for a single-valued cache;
            a possibly empty list of values for a multi-valued cache.
        """
        with self.reading() as reader:
            return self._lookup(reader, key)

    def _lookup(self, cursor: Any, key: str) -> Any:
        """Look up *key* using an already-open *cursor*; same contract as :meth:`get`.

        The cursor is repositioned, so it must not be one that is currently being
        iterated.
        """
        if self.single_value:
            return cursor.get(key)
        if not cursor.set_key(key):
            # Key absent: the cursor is unpositioned and there is nothing to collect.
            return []
        return list(cursor.iternext_dup())

    def load(self) -> dict[str, Any]:
        """Return the whole cache keyed by the translation of each stored key.

        Every distinct key is mapped through the ``EntryUUID`` sub-database; keys
        without a translation are skipped. Keys and values are read within a single
        read transaction, so the result reflects one consistent snapshot.
        """
        ret: dict[str, Any] = {}
        with self._load_key_translations() as translations, self.env.begin(self.sub_db) as txn:
            # One cursor walks the distinct keys; the second is repositioned for each
            # value lookup so the walk is not disturbed.
            with txn.cursor() as key_cursor, txn.cursor() as value_cursor:
                for key in key_cursor.iternext_nodup():
                    translated = translations.get(key)
                    if translated is None:
                        continue
                    ret[translated] = self._lookup(value_cursor, key)
        return ret

    @contextmanager
    def _load_key_translations(self) -> Iterator[Any]:
        """Yield a read transaction whose default database is ``EntryUUID``."""
        # py-lmdb requires bytestring database names; a str name raises TypeError.
        entry_uuid_db = self.env.open_db(b'EntryUUID', dupsort=False)
        with self.env.begin(entry_uuid_db) as txn:
            yield txn


class LmdbShard(Shard):
    """Shard definition used with the LMDB backend."""

    # Name of the attribute this shard is keyed on (presumably consumed by the
    # Shard base class — confirm there).
    key = 'entryUUID'
