from __future__ import annotations
# Python Modules
import abc
import dataclasses
import datetime
import json
import logging
import os
import re
from contextlib import contextmanager
from io import IOBase, StringIO
from pathlib import Path
from typing import (
Any,
Callable,
ClassVar,
Iterable,
Literal,
Mapping,
Optional,
Type,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
# Project Modules
from rwskit.dataclasses_ import DataclassRegistry, immutable_dataclass
from rwskit.types_ import get_qualified_name
# 3rd Party Modules
import yaml
import yaml_include
from dateutil.parser import parse as parse_date
from icontract import ensure, require
from yamlable import YamlAble
# Module-level logger.
log = logging.getLogger(__name__)

# Type variable bound to ``YamlConfig`` (used for alternate-constructor return types).
T = TypeVar("T", bound="YamlConfig")

# Type variable bound to ``EnvironmentConfig``.
E = TypeVar("E", bound="EnvironmentConfig")

# Unconstrained type variable for "the type a parser produces". It must exist
# at import time because ``TypeParser`` below is evaluated eagerly (it is a
# runtime expression, not a lazy annotation).
I = TypeVar("I")  # noqa: E741

# A function that parses a string into a value of type ``I``.
TypeParser = Callable[[str], I]
# Enable include directives inside yaml files
@contextmanager
def _enable_yaml_include(path: Path):
    """Temporarily register ``!inc`` / ``!include`` constructors with PyYAML.

    While the context is active, the include tags are registered on
    ``yaml.Loader``, ``yaml.BaseLoader``, ``yaml.SafeLoader`` and
    ``yaml.FullLoader`` with ``base_dir`` set to the parent directory of
    ``path``. The constructors are removed again on exit, including when the
    body of the ``with`` statement raises.

    Parameters
    ----------
    path : Path
        The path of the YAML file being loaded. A plain string is accepted
        as well and is converted to a :class:`pathlib.Path`.
    """
    tags = ("!inc", "!include")
    loaders = (yaml.Loader, yaml.BaseLoader, yaml.SafeLoader, yaml.FullLoader)
    # Callers may hand over a ``str``; normalise so ``.parent`` is available.
    base_dir = Path(path).parent

    try:
        for tag in tags:
            for loader in loaders:
                yaml.add_constructor(
                    tag, yaml_include.Constructor(base_dir=base_dir), loader
                )
        yield
    finally:
        # Always undo the registration; the loader classes are process-global,
        # so a leaked constructor would affect every later ``yaml`` call.
        # ``pop`` with a default tolerates a registration that failed half way.
        for tag in tags:
            for loader in loaders:
                loader.yaml_constructors.pop(tag, None)
# According to the python documentation the following (commented) code should
# allow any class that uses the 'YamlConfigMeta' to automatically be viewed
# as a dataclass from the perspective of the type checker. Unless I am doing
# something wrong (very possible), it does not work for me and PyCharm
# does not detect that subclasses should be considered as dataclasses. The
# code functions correctly, but PyCharm will issue warnings about unresolved
# attributes for all subclasses and cannot provide autocompletion.
#
# However, if you annotate each individual class with @dataclass_transform
# PyCharm will recognize the class as a dataclass. This is obviously not ideal
# but better than nothing.
# @dataclass_transform(kw_only_default=True, frozen_default=True)
# class DataclassMeta(type):
# def __new__(mcs, name, bases, namespace, **kwargs):
# new_cls = super().__new__(name, bases, namespace, **kwargs)
# return pydantic_dataclass(new_cls, frozen=True, kw_only=True)
#
#
# class YamlConfigMeta(DataclassMeta, type(YamlAble)):
# pass
class YamlConfig(YamlAble):
    """A base class for serializable configuration objects.

    Classes that inherit from this class can easily be serialized to and from
    YAML files. Given a YAML file, the class can be reconstructed as long
    as the YAML attributes can be uniquely mapped to a subclass of
    :class:`YamlConfig`.

    Additionally, the configuration can be split across multiple files for
    better modularity using the ``!include`` directive.

    .. note::
        Instances are required to be dataclasses (see the contract on
        ``__init__``).

    Examples
    --------
    >>> class ChildConfig(YamlConfig):
    ...     id: int
    ...     name: str
    ...     timestamp: datetime.datetime

    >>> class ParentConfig(YamlConfig):
    ...     id: int
    ...     child_config: ChildConfig

    >>> expected_config = ParentConfig(
    ...     id=1,
    ...     child_config=ChildConfig(
    ...         id=2,
    ...         name="child_config",
    ...         timestamp=datetime.datetime(2024, 11, 19, 13, 55, 34, 64388),
    ...     ),
    ... )

    >>> plain_yaml = '''
    ... child_config:
    ...   id: 2
    ...   name: child_config
    ...   timestamp: 2024-11-19 13:55:34.064388
    ... id: 1
    ... '''
    >>> from_plain_yaml = YamlConfig.loads_yaml(plain_yaml)
    >>> assert from_plain_yaml == expected_config

    The ``!yamlable`` tag can be used to explicitly tell the YAML parser
    which class to construct. The syntax is
    ``!yamlable/<fully_qualified_class_name>``.

    >>> tagged_yaml = '''
    ... !yamlable/my_package.my_module.ParentConfig
    ... child_config: !yamlable/my_package.my_module.ChildConfig
    ...   id: 2
    ...   name: child_config
    ...   timestamp: 2024-11-19 13:55:34.064388
    ... id: 1
    ... '''
    >>> assert YamlConfig.loads_yaml(tagged_yaml) == expected_config

    You can use the ``!include`` directive to include other YAML files.
    For example, assume you have the following two YAML files:

    .. code-block:: yaml

        # child_config.yaml
        id: 2
        name: "child_config"
        timestamp: 2024-11-19 13:55:34.064388

    .. code-block:: yaml

        # parent_config.yaml
        id: 1
        child_config: !include child_config.yaml

    You can load the parent config using ``YamlConfig.load_yaml`` as follows:

    >>> YamlConfig.load_yaml("parent_config.yaml")
    """

    # Maintain a registry of all classes that inherit from this base class.
    # Although, I don't plan to use this in a multithreaded context, it is
    # pretty easy to make it thread safe with a lock.
    __registry: ClassVar[DataclassRegistry] = DataclassRegistry()

    # String parsers keyed by the python type they produce.
    default_type_parsers = {
        int: int,
        float: float,
        bool: lambda x: x.lower() in ("true", "1", "t", "y", "yes"),
        str: lambda x: x,
        datetime.datetime: parse_date,
        datetime.date: lambda x: parse_date(x).date(),
    }

    @require(
        lambda self: dataclasses.is_dataclass(self), "The class must be a dataclass."
    )
    def __init__(self):
        pass

    def __init_subclass__(cls: Type[YamlConfig], **kwargs: Any):
        """Initialize subclasses to make them suitable configuration objects.

        * Automatically assign the ``__yaml_tag_suffix__`` using the fully
          qualified class name.
        * Register the subclass in the shared class registry.

        NOTE(review): an earlier revision of this docstring stated that the
        subclass is also converted to a frozen, keyword-only
        ``pydantic.dataclasses.dataclass``. No such conversion is performed
        in this method; confirm whether it happens elsewhere.

        Parameters
        ----------
        kwargs : Any
            Forwarded unchanged to ``super().__init_subclass__``.
        """
        super().__init_subclass__(**kwargs)

        cls.__yaml_tag_suffix__ = get_qualified_name(cls)

        # Add the subclass to the registry so we can dynamically load the class
        # without having to import it first.
        YamlConfig.__registry.register(cls)  # noqa

    @classmethod
    def get_registered_classes(cls) -> set[Type[YamlConfig]]:
        """Get the set of classes currently registered as configuration objects.

        Returns
        -------
        set[Type[YamlConfig]]
            The set of registered yaml config classes.
        """
        return set(cls.__registry)

    def dumps_plain_yaml(self) -> str:
        """
        Represent the class as plain YAML without any tags.

        .. note::
            It may not be possible to reconstruct the python object from this
            string.

        Returns
        -------
        str
            The object as plain YAML without any tags.
        """

        def _to_plain(obj: Any):
            # Reduce ``obj`` to dicts, lists and leaf values. Strings are
            # checked before the generic ``Iterable`` case so they are not
            # exploded into characters. Every other iterable (sets, tuples,
            # ...) becomes a list, so no separate set-to-list pass is needed.
            if dataclasses.is_dataclass(obj):
                return {
                    k: _to_plain(v) for k, v in dataclasses.asdict(obj).items()
                }
            if isinstance(obj, str):
                return obj
            if isinstance(obj, Mapping):
                return {k: _to_plain(v) for k, v in obj.items()}
            if isinstance(obj, Iterable):
                return [_to_plain(v) for v in obj]
            return obj

        return yaml.safe_dump(_to_plain(self))

    @classmethod
    @require(
        lambda file_path_or_stream: isinstance(
            file_path_or_stream, (str, Path, IOBase, StringIO)
        ),
        "'file_path_or_stream' must be a string, pathlib.Path, IOBase, or StringIO object.",
    )
    def load_yaml(
        cls: Type[T],
        file_path_or_stream: str | Path | IOBase | StringIO,
        safe: bool = True,
    ) -> T:
        """Load a configuration object from a YAML file or stream.

        Parameters
        ----------
        file_path_or_stream : str | Path | IOBase | StringIO
            The path of the YAML file, or an open stream containing YAML.
        safe : bool, default = True
            Whether to use PyYAML's safe loader.

        Returns
        -------
        T
            The deserialized object if the YAML carried the tags needed to
            build an instance of ``cls`` directly; otherwise the object the
            class registry constructs from the top-level mapping.

        Raises
        ------
        TypeError
            If the YAML is untagged and its top-level object is not a mapping.
        """
        raw_config = cls._load_raw_yaml(file_path_or_stream, safe)

        # If the deserialized object is an instance of the loader class then
        # we can simply return the object (i.e., it contained all the tags
        # necessary to reconstruct the config).
        if isinstance(raw_config, cls):
            return raw_config

        # Since a config file represents a data class the top level
        # object of a plain yaml config must be a dictionary (i.e., the
        # attributes of the dataclass).
        if not isinstance(raw_config, Mapping):
            raise TypeError("Invalid YAML config file.")

        return YamlConfig.__registry.construct_registered_dataclass(raw_config)

    @classmethod
    def _load_raw_yaml(
        cls: Type[T],
        file_path_or_stream: str | Path | IOBase | StringIO,
        safe: bool = True,
    ) -> Any:
        """Deserialize YAML into whatever object PyYAML produces for it.

        If the yaml contains known tags the result is the corresponding
        python object, otherwise primitive python types (e.g., ints, floats,
        lists, dicts, etc.).
        """
        # ``yaml.load`` requires an explicit ``Loader`` on PyYAML >= 6, so the
        # non-safe path uses ``yaml.full_load`` (``FullLoader``), which is what
        # a bare ``yaml.load(stream)`` resolved to on PyYAML 5.x.
        # NOTE(security): only pass ``safe=False`` for trusted input.
        yaml_loader = yaml.safe_load if safe else yaml.full_load

        if isinstance(file_path_or_stream, (str, Path)):
            # Normalise to ``Path``: the include helper needs ``.parent``.
            config_path = Path(file_path_or_stream)
            with open(config_path, "rt", encoding="utf-8") as fh:
                # This allows using paths relative to the input config file
                # rather than paths relative to the current working directory.
                with _enable_yaml_include(config_path):
                    return yaml_loader(fh)

        # A stream was supplied; it is closed once loading finishes.
        with file_path_or_stream as fh:
            return yaml_loader(fh)
@immutable_dataclass
class EnvironmentConfig(abc.ABC):
    """A mixin class for configuration objects to help constructing them from environment variables.

    This mixin adds a method :meth:`from_environment` that will try to parse
    environment variables into the correct type for the dataclass. All you
    have to do is implement the :meth:`environment_mapping` method to provide
    a mapping between environment variable names and the dataclass field names.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # Deliberately left unannotated so these are plain class attributes and
    # not picked up as dataclass fields.
    _default_type_parsers = {
        bool: lambda x: x.lower() in ("true", "1", "t", "y", "yes"),
        int: int,
        float: float,
        str: lambda x: x,
        Literal: lambda x: x,
        datetime.datetime: parse_date,
        datetime.date: lambda x: parse_date(x).date(),
        re.Pattern: lambda x: re.compile(x),
    }

    # Per-class cache for :meth:`get_dataclass_field_map`.
    _dataclass_fields = None

    @classmethod
    def get_default_type_parsers(cls: Type[E]) -> dict[Type[I], TypeParser]:
        """Get the default type parsers for this class."""
        return cls._default_type_parsers.copy()

    @classmethod
    @abc.abstractmethod
    @ensure(
        lambda cls, result: all(cls._is_valid_field(v) for v in result.values()),
        "All values in the mapping must be a valid field name.",
    )
    def environment_mapping(cls: Type[E]) -> Mapping[str, str]:
        """
        Returns a mapping from environment variable names to their
        corresponding field name.

        Returns
        -------
        Mapping[str, str]

        Raises
        ------
        ViolationError
            If any of the returned values are not a field name of this
            class. Note, not all fields need to be included in the mapping, but
            any entry that is included must correspond to a field name.
        """
        pass

    @classmethod
    def _is_valid_field(cls, field_name: str) -> bool:
        # A valid field is any field that can be used in the constructor.
        field_names = {f.name for f in dataclasses.fields(cls) if f.init}  # noqa

        return field_name in field_names

    @classmethod
    def from_environment(
        cls: Type[E], type_parsers: Optional[dict[Type[I], TypeParser]] = None, **kwargs
    ) -> E:
        """
        Create an instance from environment variables.

        Each mapped environment variable is parsed with
        :meth:`parse_value_from_string`: the raw string is first tried as
        JSON (which is how lists, sets, tuples and dicts are expressed) and
        otherwise handed to the parser registered for the field's type.

        Parameters
        ----------
        type_parsers: dict[Type[I], TypeParser]
            A mapping from python types to functions that parse a string
            into that type. These will be combined and override the default
            parsing rules described above.
        kwargs
            Additional keyword arguments to override the values from the
            environment or composite fields (e.g., other dataclasses),
            which are constructed externally.

        Returns
        -------
        E
            An instance built from the environment values, with ``kwargs``
            taking precedence.
        """
        type_parsers = cls.get_default_type_parsers() | (type_parsers or {})
        environment_kwargs = cls.get_environment_kwargs(type_parsers)
        init_kwargs = environment_kwargs | kwargs

        return cls(**init_kwargs)

    @classmethod
    def get_environment_kwargs(
        cls, type_parsers: dict[Type[I], TypeParser]
    ) -> dict[str, I]:
        """
        Get the keyword arguments for this dataclass that are available
        from the environment.

        Parameters
        ----------
        type_parsers : dict[Type, TypeParser]
            The type parsers to use when parsing environment variables.

        Returns
        -------
        dict[str, Any]
            The keyword arguments as a dictionary.
        """
        type_parsers = cls.get_default_type_parsers() | (type_parsers or {})

        # Map the fields to the environment variable (string) values, for all
        # variables that are set.
        value_map = {
            fn: ev
            for vn, fn in cls.environment_mapping().items()
            if (ev := os.environ.get(vn)) is not None  # noqa
        }

        # For each field, attempt to parse the value.
        resolved_types = get_type_hints(cls)

        return {
            fn: cls.parse_value_from_string(
                ev, cls._get_actual_type(resolved_types[fn]), type_parsers
            )
            for fn, ev in value_map.items()
        }

    @staticmethod
    def _is_union_origin(origin: Any) -> bool:
        # True for ``typing.Union`` and, where the interpreter has it, for the
        # PEP 604 ``X | Y`` form whose origin is ``types.UnionType``.
        import types as _types  # local import keeps this block self-contained

        pep604_union = getattr(_types, "UnionType", None)
        return origin is Union or (
            pep604_union is not None and origin is pep604_union
        )

    # A utility method to make sure we get the actual field type in case
    # it is marked optional.
    @classmethod
    def _get_actual_type(cls, t: Type[I]) -> Type[I]:
        if not cls._is_union_origin(get_origin(t)):
            return t

        args = get_args(t)
        if len(args) == 2 and type(None) in args:
            # ``Optional[X]`` / ``X | None``: unwrap to ``X``.
            return next(a for a in args if a is not type(None))

        raise NotImplementedError("Unions of more than one type are not supported")

    @classmethod
    def get_dataclass_field_map(cls) -> dict[str, dataclasses.Field]:
        """Get a mapping from field names to their dataclass fields."""
        # Look in this class's own ``__dict__`` rather than via attribute
        # lookup; otherwise a subclass would inherit (and return) the map
        # cached by its parent and miss its own fields.
        field_map = cls.__dict__.get("_dataclass_fields")
        if field_map is None:
            field_map = {f.name: f for f in dataclasses.fields(cls)}  # noqa
            cls._dataclass_fields = field_map

        return field_map

    @classmethod
    def get_field_from_name(cls, field_name: str) -> dataclasses.Field:
        """Get the dataclass field from its name."""
        return cls.get_dataclass_field_map()[field_name]

    @classmethod
    def parse_value_from_string(
        cls,
        input_string: str,
        target_type: Type[I],
        type_parsers: Optional[dict[Type[I], TypeParser]] = None,
    ):
        """
        Try to parse an input string into the given python ``target_type``.

        The input string can be a single value that can be parsed by
        the default parsing rules or any of the rules implemented by the
        given ``type_parsers``. By default, the parseable values are: str, int,
        float, bool, datetime.datetime, datetime.date, re.Pattern, and any
        class that can be constructed from a single string argument (e.g.,
        ``pathlib.Path``).

        To handle more complex ``target_types``, the input string can also
        be a JSON string. The string will be parsed as JSON and the resulting
        object will be traversed to try to convert the leaf values using
        all the available parsing rules.

        Parameters
        ----------
        input_string : str
            The string to parse.
        target_type : Type
            The python type to parse the value into.
        type_parsers : dict[Type, TypeParser], optional
            An optional dictionary that maps a type to a function that
            parses a string to that type. These rules will augment the default
            parsing rules given by :meth:`get_default_type_parsers` and will
            take precedent over the default rules if there is a conflict.

        Returns
        -------
        An instance of the target type parsed from the environment value.

        Raises
        ------
        ValueError
            If the value cannot be parsed as ``target_type``.
        """
        type_parsers = cls.get_default_type_parsers() | (type_parsers or {})

        def parse_base_type(value: str, hint: Type[I]):
            """Parse a base type (no origin) from a string."""
            # Parsers are keyed by the bare type, e.g. ``Literal`` rather
            # than ``Literal["a", "b"]``.
            hint = get_origin(hint) or hint

            def constructor_parser(x: str) -> I:
                """Try parsing ``x`` assuming ``hint`` is callable"""
                return hint(x)

            try:
                return type_parsers.get(hint, constructor_parser)(value)
            except Exception as e:
                raise ValueError(f"Unable to parse '{value}' as {hint}.") from e

        def parse_value(value: Any, hint: Type[I]):
            """Recursively try to parse an arbitrary value to the given type."""
            origin = get_origin(hint)

            if origin is Literal:
                return parse_value(value, str)

            if origin is list:
                inner_type = get_args(hint)[0]
                return [parse_value(e, inner_type) for e in value]

            if origin is set:
                inner_type = get_args(hint)[0]
                return {parse_value(e, inner_type) for e in value}

            if origin is tuple:
                inner_types = get_args(hint)
                # ``tuple[T, ...]`` is homogeneous and of any length; zipping
                # against ``(T, Ellipsis)`` would truncate and then fail.
                if len(inner_types) == 2 and inner_types[1] is Ellipsis:
                    return tuple(parse_value(e, inner_types[0]) for e in value)
                return tuple(parse_value(e, t) for e, t in zip(value, inner_types))

            if origin is dict:
                key_type, value_type = get_args(hint)
                return {
                    parse_value(k, key_type): parse_value(v, value_type)
                    for k, v in value.items()
                }

            if cls._is_union_origin(origin):
                inner_types = [t for t in get_args(hint) if t is not type(None)]
                # Best effort: the first member type that parses wins, so a
                # failure on one candidate is deliberately not an error.
                for inner_type in inner_types:
                    try:
                        return parse_value(value, inner_type)
                    except Exception:  # noqa
                        continue

                raise ValueError(f"Unable to parse '{value}' as any of {inner_types}.")

            if origin is None:
                value_type = type(value)
                if value_type in (int, float, bool):
                    # Note, I shouldn't have to also check that the target_type
                    # matches the value_type, because pydantic will do that
                    # for us when it constructs the object.
                    return value
                if value_type is str:
                    return parse_base_type(str(value), hint)
                raise ValueError(f"Unable to parse '{value}' as {hint}.")

            raise ValueError(f"Unsupported type: {hint}")

        try:
            return parse_value(json.loads(input_string), target_type)
        except json.JSONDecodeError:
            # Not valid JSON: treat the raw string as a single scalar value.
            return parse_base_type(input_string, target_type)