# Source code for graviti.utility.attr

#!/usr/bin/env python3
#
# Copyright 2022 Graviti. Licensed under MIT License.
#
"""Attr related class."""

from itertools import chain
from typing import (
    Any,
    Dict,
    Generic,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    TypeVar,
    Union,
    overload,
)

from graviti.utility.repr import ReprMixin, ReprType

# Type of the values stored in the attr dict (the leaf values).
_T = TypeVar("_T")
# Type of the caller-supplied fallback in the ``get`` overloads.
_D = TypeVar("_D")


class _AttrDict(Generic[_T], ReprMixin):
    """Nested, string-keyed storage addressed by dot-separated keys.

    A key such as ``"a.b.c"`` is split on its first dot: the prefix ``"a"``
    names a child node (an "internal" node) and the suffix ``"b.c"`` is
    resolved recursively inside that child. A key without a dot names a value
    stored directly on this node (an "external" node). A name cannot be both
    an internal and an external node of the same parent.

    Values and child nodes are also reachable with attribute access.

    """

    _repr_type = ReprType.MAPPING

    def __init__(self) -> None:
        # An external node is one without child branches,
        # while an internal node has at least one child branch.
        self._internals: Dict[str, _AttrDict[_T]] = {}
        self._externals: Dict[str, _T] = {}

    def __len__(self) -> int:
        """Return the number of leaf values stored in this node and its children."""
        # Sum over the child *nodes*; iterating the dict itself would yield the
        # prefix strings and add up their character counts instead.
        nested = sum(map(len, self._internals.values()))
        return nested + len(self._externals)

    def __getitem__(self, key: str) -> _T:
        """Return the value stored under the dot-separated ``key``.

        Arguments:
            key: The dot-separated key of the value.

        Returns:
            The stored value, or whatever :meth:`__missing__` returns when the
            key is absent.

        """
        try:
            return self._getitem(key)
        except KeyError:
            return self.__missing__(key)

    def __setitem__(self, key: str, value: _T) -> None:
        """Store ``value`` under the dot-separated ``key``.

        Arguments:
            key: The dot-separated key of the value.
            value: The value to store.

        Raises:
            KeyError: When a name on the path is already used as the other
                kind of node (value versus child node).

        """
        try:
            prefix, suffix = key.split(".", 1)
        except ValueError:
            # No dot in the key: it names a value stored directly on this node.
            if key in self._internals:
                raise KeyError("Key prefix conflict") from None

            self._externals[key] = value
            return

        if prefix in self._externals:
            raise KeyError("Key prefix conflict")

        module = self._internals.get(prefix)
        # Compare against None explicitly: the truthiness of a node is its
        # length, which says nothing about whether the node exists.
        if module is None:
            module = _AttrDict()
            self._internals[prefix] = module

        module[suffix] = value

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` is a string that resolves to a stored value."""
        if not isinstance(key, str):
            return False

        try:
            self._getitem(key)
        except KeyError:
            return False

        return True

    def __missing__(self, key: str) -> _T:
        """Handle a failed lookup in :meth:`__getitem__`; raises ``KeyError``."""
        raise KeyError(key)

    def __iter__(self) -> Iterator[str]:
        """Yield the full dot-separated key of every stored value.

        Keys reached through child nodes are yielded first, followed by the
        keys stored directly on this node.

        """
        for key, value in self._internals.items():
            for suffix in value:
                yield ".".join((key, suffix))

        yield from self._externals

    def __getattr__(self, key: str) -> Any:
        """Resolve ``key`` as a directly stored value, then as a child node.

        Arguments:
            key: The attribute name that normal attribute lookup did not find.

        Returns:
            The value stored under ``key`` on this node, or the child node
            named ``key``.

        Raises:
            AttributeError: When ``key`` is neither a value nor a child node.

        """
        # __getattr__ only runs after normal lookup has failed. If the storage
        # dicts themselves are missing (instance built without __init__, e.g.
        # while being copied or unpickled), reading them through ``self`` would
        # re-enter this method forever, so fail fast instead.
        if key in ("_externals", "_internals"):
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{key}'"
            )

        try:
            return self._externals[key]
        except KeyError:
            pass

        try:
            return self._internals[key]
        except KeyError:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{key}'"
            ) from None

    def __dir__(self) -> Iterable[str]:
        """List the stored names together with the regular attributes."""
        return chain(self._externals, self._internals, super().__dir__())

    def _getitem(self, key: str) -> _T:
        """Look up ``key`` without consulting this node's :meth:`__missing__`.

        Arguments:
            key: The dot-separated key of the value.

        Returns:
            The stored value.

        Raises:
            KeyError: When the key does not resolve to a stored value.

        """
        try:
            prefix, suffix = key.split(".", 1)
        except ValueError:
            return self._externals[key]

        return self._internals[prefix][suffix]

    def _repr_head(self) -> str:
        # Nodes of this private class report the public class name;
        # presumably ReprMixin uses it as the header of the repr -- TODO confirm.
        return AttrDict.__name__


class AttrDict(_AttrDict[_T], Mapping[str, _T]):
    """A dict which allows for attr-style access of values."""

    def _repr_head(self) -> str:
        # Report the concrete class name so subclasses show their own name.
        return self.__class__.__name__

    @overload
    def get(self, key: str) -> Optional[_T]:  # pylint: disable=arguments-differ
        ...

    @overload
    def get(self, key: str, default: _D = ...) -> Union[_D, _T]:
        ...

    def get(self, key: str, default: Any = None) -> Any:
        """Return the value for the key if it is in the dict, else default.

        The lookup goes through ``_getitem`` directly, so ``__missing__`` is
        not consulted for an absent key.

        Arguments:
            key: The dot-separated string key of the value.
            default: The value to be returned if key is not in the dict.

        Returns:
            The value for the key if it is in the dict, else default.

        """
        try:
            return self._getitem(key)
        except KeyError:
            return default