Source code for rwskit.sqlalchemy

"""Utilities for working with SqlAlchemy."""

# Future Library
from __future__ import annotations

import abc
import copy
import dataclasses
import importlib
import logging

from collections import defaultdict, deque
from contextlib import asynccontextmanager, contextmanager
from functools import cache
from types import ModuleType
from typing import (
    Any,
    AsyncIterator,
    Callable,
    ClassVar,
    Iterable,
    Iterator,
    Literal,
    Optional,
    Self,
    Type,
    TypeVar,
    cast,
    dataclass_transform,
    get_type_hints,
)

import pydantic
import sqlalchemy as sa

from icontract import ensure, require
from pydantic.dataclasses import dataclass
from sqlalchemy import (
    Engine,
    Index,
    Insert,
    MetaData,
    Table,
    create_engine,
    event,
    orm,
)
from sqlalchemy import insert as default_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as postgres_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Session

from rwskit.collections_ import is_iterable, remove_none_from_dict
from rwskit.config import YamlConfig
from rwskit.hash import ObjectHasher
from rwskit.strings_ import camel_to_snake_case
from rwskit.types_ import get_qualified_name, import_all_modules_in_path, is_optional

# Module-level logger named after this module.
log = logging.getLogger(__name__)
B = TypeVar("B", bound=DeclarativeBase)
"""A type extending :class:`~sqlalchemy.orm.DeclarativeBase`."""
# A type variable bound (by forward reference) to ``BaseModel``.
M = TypeVar("M", bound="BaseModel")
TableArgs = dict[str, Any] | tuple[Any, ...]
"""The type of the ``__table_args__`` attribute on a :class:`~sqlalchemy.Table`."""
DtoModel = Type[pydantic.BaseModel]
"""The base type for a DTO model."""
SqlOperator = Literal[
    "==", "!=", ">", ">=", "<", "<=", "like", "in", "not_in", "is_null", "is_not_null"
]
"""The supported SQL operators for use in an :class:`SqlBinaryExpression`."""
# The names of the supported conflict resolution strategies.
ConflictResolutionStrategy = Literal["do_nothing", "update"]
# Either a declarative model or a mapper.
Walkable = orm.DeclarativeBase | orm.Mapper
# A callable that receives two declarative models and returns a bool.
WalkPredicate = Callable[[orm.DeclarativeBase, orm.DeclarativeBase], bool]
class AlchemyEngine:
    """A wrapper around a SqlAlchemy engine that provides additional functionality.

    The wrapper accepts either an existing :class:`sqlalchemy.engine.Engine` /
    :class:`sqlalchemy.ext.asyncio.AsyncEngine` or a connection configuration
    from which the engine is created. It exposes session factories and
    transactional context managers for both synchronous and asynchronous use.

    Parameters
    ----------
    engine_or_config : AsyncEngine | Engine | DatabaseConnectionConfig
        An existing engine, or a configuration used to create one.
    base_model : Type[orm.DeclarativeBase]
        The declarative base class of the models used with this engine.
    **engine_kwargs : Any
        Extra keyword arguments forwarded to the engine factory. They are
        only used when ``engine_or_config`` is a configuration object.
    """

    @require(
        lambda engine_or_config, base_model: isinstance(
            engine_or_config, (AsyncEngine, Engine, DatabaseConnectionConfig)
        )
        and issubclass(base_model, orm.DeclarativeBase),
    )
    def __init__(
        self,
        engine_or_config: AsyncEngine | Engine | DatabaseConnectionConfig,
        base_model: Type[orm.DeclarativeBase],
        **engine_kwargs: Any,
    ):
        self.base_model = base_model

        self._engine = self._configure_engine(engine_or_config, **engine_kwargs)
        # ``_is_async`` must be set before the session factories are built,
        # because both of them consult ``supports_async``.
        self._is_async = isinstance(self._engine, AsyncEngine)
        self._async_session_factory = self._configure_async_session_factory()
        self._sync_session_factory = self._configure_sync_session_factory()
        self._dialect = self._engine.dialect.name

    @property
    def dialect(self) -> str:
        """Returns the dialect name of the engine."""
        return self._dialect

    @property
    def supports_async(self) -> bool:
        """Returns true if the engine supports async mode."""
        return self._is_async

    @property
    def sync_engine(self) -> Engine:
        """Get a reference to a synchronous :class:`sqlalchemy.engine.Engine`."""
        return self._engine.sync_engine if self.supports_async else self._engine

    @property
    def async_engine(self) -> AsyncEngine:
        """
        Get a reference to an asynchronous
        :class:`sqlalchemy.ext.asyncio.AsyncEngine`.

        Returns
        -------
        AsyncEngine
            The engine.

        Raises
        ------
        ValueError
            If the engine does not support async mode.
        """
        if self.supports_async:
            return self._engine

        raise ValueError("Async engine is not supported.")

    def make_session(self) -> Session:
        """
        Create a session from the synchronous session factory.

        .. note::
            You probably want to use :meth:`AlchemyEngine.session_scope`, but
            there may be cases where this is more appropriate.

        Returns
        -------
        Session
            A :class:`~sqlalchemy.orm.Session` object.
        """
        return self._sync_session_factory()

    def make_connection(self) -> sa.Connection:
        """
        Get a connection to the engine.

        Returns
        -------
        Connection
            An :class:`sqlalchemy.Connection`.
        """
        return self.sync_engine.connect()

    def make_raw_connection(self) -> sa.PoolProxiedConnection:
        """
        Get a raw connection to the engine.

        .. note::
            See `Working with Driver SQL and Raw DBAPI Connections
            <https://docs.sqlalchemy.org/en/20/core/connections.html#working-with-driver-sql-and-raw-dbapi-connections>`_
            in the SqlAlchemy documentation for the difference between
            regular and raw connections.

        Returns
        -------
        PoolProxiedConnection
            A :class:`sqlalchemy.PoolProxiedConnection`.
        """
        return self.sync_engine.raw_connection()

    @contextmanager
    def session_scope(self, expire_on_commit: bool = False) -> Iterator[Session]:
        """
        A context manager for committing successful transactions when the
        session is complete or rolling back if there was an exception.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects will be marked as stale after the next
            commit. This will invalidate all relationships and raise an
            exception if they are accessed outside the session.

        Yields
        ------
        Session
            An :class:`sqlalchemy.orm.Session`.
        """
        session = self._sync_session_factory(expire_on_commit=expire_on_commit)

        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            # A bare ``raise`` keeps the original traceback untouched.
            raise
        finally:
            session.close()

    @asynccontextmanager
    async def async_session_scope(
        self, expire_on_commit: bool = False
    ) -> AsyncIterator[AsyncSession]:
        """
        An async context manager for committing successful transactions when
        the session is complete or rolling back if there was an exception.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects will be marked as stale after the next
            commit. This will invalidate all relationships and raise an
            exception if they are accessed outside the session.

        Yields
        ------
        AsyncSession
            An :class:`sqlalchemy.ext.asyncio.AsyncSession`.

        Raises
        ------
        ValueError
            If the engine does not support async mode.
        """
        session = self._make_async_session(expire_on_commit)

        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    @contextmanager
    def test_scope(self, expire_on_commit: bool = False) -> Iterator[Session]:
        """
        A session scope for testing.

        This session will always roll back after exiting the context manager
        and should not persist any changes to the database.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects will be marked as stale after the next
            commit. This will invalidate all relationships and raise an
            exception if they are accessed outside the session.

        Yields
        ------
        Session
            An :class:`sqlalchemy.orm.Session`.

        Raises
        ------
        RuntimeError
            If the user tries to commit changes during the session.
        """

        def raise_on_commit():
            """Raise an exception if the session is committed."""
            raise RuntimeError(
                "Session.commit() is not allowed inside a test scope session."
            )

        session = self._sync_session_factory(expire_on_commit=expire_on_commit)
        # Shadow ``commit`` on this instance so nothing can be persisted.
        session.commit = raise_on_commit

        try:
            yield session
        finally:
            session.rollback()
            session.close()

    @asynccontextmanager
    async def async_test_scope(
        self, expire_on_commit: bool = False
    ) -> AsyncIterator[AsyncSession]:
        """
        An async session scope for testing.

        This session will always roll back after exiting the context manager
        and should not persist any changes to the database.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects will be marked as stale after the next
            commit. This will invalidate all relationships and raise an
            exception if they are accessed outside the session.

        Yields
        ------
        AsyncSession
            An :class:`sqlalchemy.ext.asyncio.AsyncSession`.

        Raises
        ------
        RuntimeError
            If the user tries to commit changes during the session.
        ValueError
            If the engine does not support async mode.
        """

        def raise_on_commit():
            """Raise an exception if the session is committed."""
            raise RuntimeError(
                "Session.commit() is not allowed inside a test scope session."
            )

        session = self._make_async_session(expire_on_commit)
        # The replacement raises as soon as it is called, i.e. before the
        # caller's ``await`` receives anything to wait on.
        session.commit = raise_on_commit

        try:
            yield session
        finally:
            await session.rollback()
            await session.close()

    def _make_async_session(self, expire_on_commit: bool) -> AsyncSession:
        """Create an async session, failing clearly if async is unavailable."""
        if self._async_session_factory is None:
            raise ValueError("Async engine is not supported.")

        # ``expire_on_commit`` is handed to the factory rather than assigned
        # on the returned object: ``AsyncSession`` does not proxy that
        # attribute to the ``Session`` it wraps, so an assignment is ignored.
        return self._async_session_factory(expire_on_commit=expire_on_commit)

    @classmethod
    def _configure_engine(
        cls,
        engine_or_config: Engine | AsyncEngine | DatabaseConnectionConfig,
        **engine_kwargs: Any,
    ) -> AsyncEngine | Engine:
        """Return the given engine, or create one from the configuration."""
        if isinstance(engine_or_config, (Engine, AsyncEngine)):
            return engine_or_config

        create = create_async_engine if engine_or_config.use_async else create_engine

        return create(engine_or_config.url, **engine_kwargs)

    def _configure_sync_session_factory(self) -> orm.sessionmaker:
        """Build the session factory bound to the synchronous engine."""
        return orm.sessionmaker(bind=self.sync_engine)

    def _configure_async_session_factory(self) -> Optional[async_sessionmaker]:
        """Build the async session factory, or ``None`` for a sync engine."""
        if self.supports_async:
            return async_sessionmaker(bind=self.async_engine)

        return None
# Note, you can't use 'once' here, because it will literally only run the
# listener once, not once per mapper, which is required.
@event.listens_for(orm.Mapper, "mapper_configured")
def handle_after_mapper_configured(mapper: orm.Mapper, cls: Type):
    """Run ``validate_mapper`` on ``BaseModel`` classes when they are configured.

    Classes that do not derive from ``BaseModel`` are ignored.

    To facilitate testing the validation code, a Mapper or Table can include
    a ``validate=False`` entry in its ``info`` dictionary to opt out of the
    validation performed by this handler.
    """
    if not issubclass(cls, BaseModel):
        return

    table_info = cls.__table__.info  # noqa

    if isinstance(table_info, dict) and table_info.get("validate", True):
        cls.validate_mapper()
@dataclass_transform(kw_only_default=True, eq_default=True, unsafe_hash_default=True)
class BaseModel(
    orm.DeclarativeBase,
    orm.MappedAsDataclass,
    kw_only=True,
    eq=False,
    unsafe_hash=False,
):
    """A base class for creating declarative SqlAlchemy models.

    Features
    --------
    Table Lookup
    ~~~~~~~~~~~~
    Find any model derived from this base class by their table name.

    Merging Table Args
    ~~~~~~~~~~~~~~~~~~
    SqlAlchemy does not merge ``__table_args__`` during inheritance. For
    example, if you have a base class that will set the schema for all
    child classes, it will not work if the child class defines its own
    ``__table_args__`` (e.g., to create a multi-column index).

    This base class provides a function to merge the ``__table_args__`` of
    parents with their children. This functionality is enabled by defining
    ``__table_args__`` as a ``@declared_attr.directive`` on the class and
    returning the value of :meth:`merge_table_args`. ``merge_table_args``
    accepts one optional parameter, which can be a tuple or dictionary (the
    expected types of ``__table_args__``) and will merge these with the
    table args of its ancestors. In addition to, or alternatively, the
    ``merge_table_args`` method will also look for table arguments in a
    class attribute named ``__custom_table_args__``.

    >>> class Parent(BaseModel):
    >>>     @orm.declared_attr.directive
    >>>     @classmethod
    >>>     def __table_args__(cls):
    >>>         return {"schema": "my_schema"}
    >>>
    >>> class Child(Parent):
    >>>     __custom_table_args__ = {"schema": "my_schema"}
    >>>
    >>>     @orm.declared_attr.directive
    >>>     @classmethod
    >>>     def __table_args__(cls):
    >>>         child_table_args = (
    >>>             Index("child_index_name", "column_1", "column_2"),
    >>>         )
    >>>         return cls.merge_table_args(child_table_args)
    >>>
    >>>     column_1: orm.Mapped[int]
    >>>     column_2: orm.Mapped[int]

    Natural Keys
    ~~~~~~~~~~~~
    Classes derived from this model must define a natural key. Natural keys
    are specified by explicitly setting ``hash=True`` on a ``mapped_column``.
    A natural key is intended to identify a set of attributes that uniquely
    identify a row in the table. The key will be used to define ``__hash__``
    and ``__eq__`` for the class. Additionally, a non-unique index will be
    created on the natural key columns.

    Serializable to a Dictionary
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    There is currently a bug representing models using ``dataclasses.asdict``
    when the model inherits from ``MappedAsDataclass`` and contains a
    relationship with ``back_populates`` defined.

    See: https://github.com/sqlalchemy/sqlalchemy/issues/9785

    This class provides methods for converting the model to and from
    dictionaries. This has been tested for several common use cases, but may
    not be robust for more complex models.

    Data Transfer Objects
    ~~~~~~~~~~~~~~~~~~~~~
    Any class derived from this base class can build a corresponding DTO
    class type using the :meth:`to_dto_class` method. The created DTO class
    derives from ``pydantic.BaseModel``, which is convenient for offline use
    and data transfers, for example with FastAPI. In addition to mirroring
    the columns, composites, hybrid_properties, and relationships, the DTO
    object also provides a ``pretty_print`` method to format the string
    representation of the object. It takes one optional parameter
    ``line_length``.

    A DTO instance can be constructed from a :class:`BaseModel` instance
    using either the ``model_validate`` classmethod on the DTO or the
    :meth:`to_dto` method of the ``BaseModel`` instance.

    To exclude an attribute (column, composite, hybrid, relationship, etc.)
    add ``dto=False`` to the ``info`` dictionary of the attribute.

    Known Limitations
    -----------------
    When converting to DTOs the following limitations apply:

    * A composite column must be a dataclass.
    * Only regular columns (e.g., no composite columns) can be used as
      natural keys.
    * Only the parent side of a relationship is added to the DTO. Namely,
      a reference to the child and children will be included in the parent,
      but a reference to the parent will not be included in the child or
      children. The parent is determined by the presence of foreign keys.
      For OtO and OtM relationships the parent is the model that does not
      contain a foreign key. For MtM relationships the parent is determined
      by looking at the first column of the association table. If all the
      foreign keys of that column are in the relationship's local columns
      then that model is considered the parent. This may cause problems if
      you directly select the child objects, because their parents will not
      be loaded into the DTO object.
    * In general you cannot convert a model to a dictionary or DTO and then
      back to the exact same model. The :meth:`from_dict` method does not
      handle cyclic relationships and will typically not be able to
      associate a parent instance from the child.
    """

    # Note: In general SqlAlchemy collections like '__table__.columns' are
    # usually more like 'Mapping' or 'dict' like objects. However, when
    # iterated over, they iterate over the **values** not the keys. This is
    # why my previous code tended to work, but the linter would complain.

    # Declares this class abstract (sets SqlAlchemy's ``__abstract__`` flag).
    __abstract__ = True

    # Metadata carrying the constraint/index naming convention templates
    # ("pk", "ix", "uq", "ck", "fk") listed below.
    metadata: MetaData = MetaData(  # noqa SqlAlchemy
        naming_convention={
            "pk": "pk_%(table_name)s",
            "ix": "ix_%(table_name)s_%(column_0_N_name)s",
            "uq": "uq_%(table_name)s_%(column_0_N_name)s",
            "ck": "ck_%(table_name)s_%(constraint_name)s",
            "fk": "fk_%(table_name)s_%(column_0_N_name)s_R_%(referred_table_name)s",
        }
    )

    # Hasher used by ``__hash__`` to hash the natural key.
    hasher = ObjectHasher(hash_size=64, signed=True)
    """The hasher used to calculate the has value of objects."""

    # Registry of DTO classes; starts empty.
    _registered_dto_classes = dict()
    """A dictionary of all DTO classes that have been registered. A DTO is registered any time :meth:`to_dto_class` is called, which may recursively call :meth:`to_dto_class` for any relationships. """
def __hash__(self) -> int:
    """Hash the instance by its natural key using the class-level hasher."""
    # A custom hash function can only be defined on a dataclass when
    # `eq=False` and `unsafe_hash=False`.
    compute_hash = self.hasher.hash
    return compute_hash(self.natural_key)
def __eq__(self, other: Any) -> bool:
    """Two instances are equal when ``other`` is an instance of this class
    and both have the same natural key."""
    if isinstance(other, self.__class__):
        return self.natural_key == other.natural_key

    return False
@orm.declared_attr.directive
def __tablename__(cls):
    """Derive the table name from the class name via ``camel_to_snake_case``."""
    class_name = cls.__name__
    return camel_to_snake_case(class_name)
@classmethod
def merge_table_args(cls, new_args: TableArgs = ()) -> TableArgs:
    """
    This method is intended to be called from ``__table_args__`` when
    used as a ``@declared_attr``. It will merge the ``new_args``
    with the arguments of its ancestors. You can also specify additional
    table arguments in the class variable ``__custom_table_args__``, which
    will also be merged.

    Parameters
    ----------
    new_args : TableArgs
        Additional table arguments to be merged with the arguments of our
        ancestors.

    Returns
    -------
    TableArgs
        The merged table arguments as a tuple. The leading elements are the
        de-duplicated positional arguments in first-seen order (parents
        first, then ``__custom_table_args__``, then ``new_args``). The last
        element is always the dictionary of merged keyword arguments.

    Raises
    ------
    ValueError
        If any non-empty set of table arguments is neither a ``dict`` nor a
        ``tuple``.

    Examples
    --------
    >>> class Parent(BaseModel):
    >>>     @orm.declared_attr.directive
    >>>     @classmethod
    >>>     def __table_args__(cls):
    >>>         return {"schema": "my_schema"}
    >>>
    >>> class Child(Parent):
    >>>     __custom_table_args__ = {"schema": "my_schema"}
    >>>
    >>>     @orm.declared_attr.directive
    >>>     @classmethod
    >>>     def __table_args__(cls):
    >>>         child_table_args = (
    >>>             Index("child_index_name", "column_1", "column_2"),
    >>>         )
    >>>         return cls.merge_table_args(child_table_args)
    >>>
    >>>     column_1: orm.Mapped[int]
    >>>     column_2: orm.Mapped[int]
    """
    # Adapted from another personal project and from:
    # https://github.com/sqlalchemy/sqlalchemy/discussions/8911#discussioncomment-6763269

    # Collect the candidates from highest to lowest priority: 'new_args',
    # then '__custom_table_args__', then the '__table_args__' of the
    # immediate parents (right-most base first).
    # Note, this will recurse if a base class calls 'merge_table_args'
    # in its '__table_args__' definition.
    by_priority = [new_args, getattr(cls, "__custom_table_args__", ())]
    by_priority += [getattr(b, "__table_args__", ()) for b in reversed(cls.__bases__)]

    # Drop the empty entries and process from lowest to highest priority so
    # that later keyword arguments overwrite earlier ones.
    merge_order = [args for args in reversed(by_priority) if args]

    # A dict is used as an ordered set: duplicates are removed while the
    # first-seen order is kept, which makes the result deterministic.
    positional_args: dict[Any, None] = {}
    kwargs: dict[str, Any] = {}

    for current_args in merge_order:
        if isinstance(current_args, dict):
            kwargs |= current_args
        elif isinstance(current_args, tuple):
            # By SqlAlchemy convention a trailing dict holds the keyword
            # arguments; everything else is positional.
            *leading, last_arg = current_args
            if isinstance(last_arg, dict):
                kwargs |= last_arg
            else:
                leading.append(last_arg)
            positional_args.update(dict.fromkeys(leading))
        else:
            raise ValueError(
                f"Table args must be a dict or tuple, not '{current_args}'"
            )

    return tuple(positional_args) + (kwargs,)
@classmethod
@cache
@require(lambda table_name: table_name.count(".") < 2)
def find_by_table_name(cls, table_name: str) -> Optional[Type[BaseModel]]:
    """
    Look up a model registered with this class by the name of its table.

    Parameters
    ----------
    table_name : str
        The table name, optionally qualified as ``schema.table``. When no
        schema is given ``"public"`` is used. Tables without an explicit
        schema are also treated as belonging to ``"public"``.

    Returns
    -------
    Type[BaseModel], optional
        The first matching model class, or ``None`` if no model matches.
    """
    # See: https://stackoverflow.com/a/68862329
    registry = getattr(cls, "registry")

    if "." in table_name:
        wanted_schema, wanted_table = table_name.split(".", 1)
    else:
        wanted_schema, wanted_table = "public", table_name

    for mapper in registry.mappers:
        model = mapper.class_
        schema = model.__table__.schema or "public"
        name = model.__tablename__

        if schema != wanted_schema or name != wanted_table:
            continue

        return model

    return None
@property
def primary_key(self) -> tuple[tuple[str, Any], ...]:
    """
    The primary key of this instance as ``(name, value)`` pairs.

    The names are the keys of :meth:`primary_key_columns`. Each value is
    read from the instance, falling back to ``None`` when the attribute is
    not set. A tuple of tuples is returned for efficiency.
    """
    pairs = []

    for column_name in self.primary_key_columns().keys():
        pairs.append((column_name, getattr(self, column_name, None)))

    return tuple(pairs)
@property
def natural_key(self) -> tuple[tuple[str, Optional[Any]], ...]:
    """
    The natural key of this instance as ``(name, value)`` pairs.

    The ``natural key`` is the set of ``key/value`` pairs that uniquely
    identify this instance. The names are the keys of
    :meth:`natural_key_columns`. Each value is read from the instance,
    falling back to ``None`` when the attribute is not set. A tuple of
    tuples is returned so the result can be used by the hash function.
    """
    pairs = []

    for column_name in self.natural_key_columns().keys():
        pairs.append((column_name, getattr(self, column_name, None)))

    return tuple(pairs)
@require(
    lambda self: all(len(a.columns) == 1 for a in self.__mapper__.column_attrs),
    "One of the mapped columns is a composite column.",
)
def as_table_dict(self) -> dict[str, Any]:
    """
    Returns the table data of the instance as a dictionary.

    .. warning::
        This only works for simple models that have a one to one mapping
        from mapped columns to table columns. For example, it will break
        for models containing composite columns.

    Returns
    -------
    dict[str, Any]
        The keys of the dictionary are the table column names and the
        values are the corresponding values from the instance (``None``
        when the attribute is not set).
    """
    table_data = {}

    for attribute in self.__mapper__.column_attrs:  # noqa SqlAlchemy
        column = attribute.columns[0]

        # Only include columns that belong to a table.
        if column.table is None:
            continue

        table_data[column.name] = getattr(self, attribute.key, None)

    return table_data
def to_dict(
    self, exclude_attributes: Iterable[str] = (), drop_none: bool = False
) -> dict[str, Any]:
    """Convert the instance to a dictionary.

    This method represents the model as a dictionary as it is defined by
    the class mapping (as opposed to the table mapping).

    .. note::
        As of 01/01/2025 calling `dataclasses.asdict` will raise a
        recursion error on models that have relationships with
        `back_populates` defined.
        See: https://github.com/sqlalchemy/sqlalchemy/issues/9785

    .. note::
        ``exclude_attributes`` does a simple string match on any attribute
        name in the model and any of its children.

    Parameters
    ----------
    exclude_attributes : Iterable[str], default=()
        The names of the attributes to leave out. Any iterable is accepted,
        including one-shot iterators. A single string is treated as one
        attribute name.
    drop_none : bool, default=False
        If ``True`` the result of each model is passed through
        ``remove_none_from_dict``.

    Returns
    -------
    dict[str, Any]
        The dictionary representation. A model that was already visited is
        represented only by its primary and natural key values, which
        keeps cyclic relationships from recursing forever.
    """
    # Materialize the names once: the membership test below runs for every
    # attribute of every visited model, which would exhaust an iterator and
    # would do substring matching on a bare string.
    if isinstance(exclude_attributes, str):
        excluded = frozenset((exclude_attributes,))
    else:
        excluded = frozenset(exclude_attributes)

    def _process_seen(model_: BaseModel) -> dict[str, Any]:
        result = {}
        result |= dict(model_.primary_key)
        result |= dict(model_.natural_key)

        return dict(sorted(result.items()))

    def _get_non_excluded_attributes(attrs: dict[str, Any]) -> list[str]:
        return [n for n in attrs if n not in excluded]

    def _process_regular_columns(model_: BaseModel) -> dict[str, Any]:
        attrs = _get_non_excluded_attributes(model_.regular_columns())

        return {n: getattr(model_, n, None) for n in attrs}

    def _process_composite_columns(model_: BaseModel) -> dict[str, Any]:
        attrs = _get_non_excluded_attributes(model_.composite_columns())
        result = {}

        for name in attrs:
            value = getattr(model_, name, None)
            result[name] = None if value is None else dataclasses.asdict(value)

        return result

    def _process_column_properties(model_: BaseModel) -> dict[str, Any]:
        attrs = _get_non_excluded_attributes(model_.column_properties())

        return {n: getattr(model_, n, None) for n in attrs}

    def _process_hybrid_properties(model_: BaseModel) -> dict[str, Any]:
        hybrids = model_.hybrid_properties()
        attrs = _get_non_excluded_attributes(hybrids)

        return {n: hybrids[n].fget(model_) for n in attrs}

    def _process_relationships(
        model_: BaseModel, seen_: set[BaseModel]
    ) -> dict[str, Any]:
        relationships = model_.relationships()
        attrs = _get_non_excluded_attributes(relationships)
        result = {}

        for name in attrs:
            value = getattr(model_, name, None)

            if value is None:
                result[name] = None
            elif relationships[name].uselist:
                result[name] = [_do_to_dict(c, seen_) for c in value]
            else:
                result[name] = _do_to_dict(value, seen_)

        return result

    def _do_to_dict(model_: BaseModel, seen_: set[BaseModel]) -> dict[str, Any]:
        if model_ in seen_:
            return _process_seen(model_)

        # Register the model before descending so that cycles terminate.
        seen_.add(model_)

        result = {}
        result |= _process_regular_columns(model_)
        result |= _process_composite_columns(model_)
        result |= _process_column_properties(model_)
        result |= _process_hybrid_properties(model_)
        result |= _process_relationships(model_, seen_)

        return remove_none_from_dict(result) if drop_none else result

    return _do_to_dict(self, set())
@classmethod
def from_dict(cls, data: dict) -> BaseModel:
    """
    Build an instance of this class from its dictionary representation.

    Nested structures and relationships between models are supported, and
    instances that share a natural key are not duplicated.

    Parameters
    ----------
    data : dict
        A dictionary containing the attributes and nested relationships of
        the model. The keys should correspond to the field names of the
        model, and values should represent their corresponding data. For
        nested relationships, values are also expected to be dictionaries
        (for single relationships) or lists of dictionaries (for
        collections).

    Returns
    -------
    BaseModel
        An instance of the BaseModel subclass created from the dictionary
        input.
    """
    # Start with an empty registry of already constructed instances.
    already_built: dict = {}

    return cls._do_from_dict(data, already_built)
@classmethod
def _do_from_dict(
    cls: Type[Self],
    data: dict[str, Any],
    seen: dict[frozenset[tuple[str, Any]], BaseModel],
) -> Self:
    """
    Recursively build an instance of ``cls`` from ``data``.

    Parameters
    ----------
    data : dict[str, Any]
        The dictionary representation of one model. It must contain a
        value for every natural key column.
    seen : dict[frozenset[tuple[str, Any]], BaseModel]
        The instances built so far, keyed by class name and natural key.
        It is updated in place and shared with the recursive calls so that
        models with the same natural key resolve to the same instance.

    Returns
    -------
    Self
        The new instance, or the previously built one with the same key.

    Raises
    ------
    KeyError
        If ``data`` is missing a natural key column.
    """
    fields = cls.fields()

    def _select(attrs, in_init: Optional[bool] = None) -> dict[str, tuple]:
        """Pair each field found in ``attrs`` with its mapped attribute.

        ``in_init=True`` keeps the fields accepted by ``__init__``,
        ``in_init=False`` keeps those excluded from it (``init=False``),
        and ``None`` keeps all of them.
        """
        selected = {}
        for name, field in fields.items():
            attr = attrs.get(name)
            if attr is None:
                continue
            if in_init is None or (field.init is not False) == in_init:
                selected[name] = (field, attr)
        return selected

    def _build_composite(composite, raw_value):
        """Instantiate a composite value; ``None`` stays ``None``."""
        if raw_value is None:
            return None
        return composite.composite_class(**copy.deepcopy(raw_value))

    def _new_collection(relationship, items: list):
        """Wrap ``items`` in the relationship's collection type."""
        # ``collection_class`` is ``None`` when the relationship uses the
        # default collection, which is a plain list.
        factory = relationship.collection_class or list
        return factory(items)

    hash_key = frozenset(
        [("__class__", get_qualified_name(cls))]
        + [(n, data[n]) for n in cls.natural_key_columns()]
    )

    if hash_key in seen:
        return seen[hash_key]

    regular_columns = cls.regular_columns()
    composite_columns = cls.composite_columns()
    relationships = cls.relationships()

    # We want to distinguish between a value explicitly set to `None` and a
    # missing value. Missing values should use their corresponding
    # `default` or `default_factory` during construction, while an explicit
    # `None` should be preserved. Hence the `name in data` checks below.
    init_kwargs: dict[str, Any] = {}

    for name in _select(regular_columns, in_init=True):
        if name in data:
            init_kwargs[name] = copy.deepcopy(data[name])

    for name, (_, composite) in _select(composite_columns, in_init=True).items():
        if name in data:
            init_kwargs[name] = _build_composite(composite, data[name])

    # For relationships, we are just creating placeholder values during
    # initialization to prevent raising exceptions. It is very difficult to
    # prevent recursion errors if we try to recurse here. Instead, the
    # object is constructed without its relationships and registered in
    # `seen`; the relationships are filled in (recursing if necessary)
    # afterwards.
    for name, (field, relationship) in _select(relationships, in_init=True).items():
        if field.default is not None and field.default is not dataclasses.MISSING:
            placeholder = field.default
        elif (
            field.default_factory is not None
            and field.default_factory is not dataclasses.MISSING
        ):
            placeholder = field.default_factory()
        elif relationship.uselist:
            placeholder = _new_collection(relationship, [])
        else:
            placeholder = None

        init_kwargs[name] = placeholder

    # We should have all the info we need to construct an instance.
    # However, there may still be attributes not included in the __init__
    # (e.g., autogenerated IDs) that need to be set below.
    seen[hash_key] = instance = cls(**init_kwargs)

    for name in _select(regular_columns, in_init=False):
        if name in data:
            setattr(instance, name, copy.deepcopy(data[name]))

    # Column properties are not evaluated until the model is queried or
    # associated with a session, so they will be `None` if we don't set
    # them here.
    for name in _select(cls.column_properties(), in_init=False):
        if name in data:
            setattr(instance, name, copy.deepcopy(data[name]))

    for name, (_, composite) in _select(composite_columns, in_init=False).items():
        if name in data:
            setattr(instance, name, _build_composite(composite, data[name]))

    # Currently, this does not handle back populating parent relationships.
    # It should be possible, but is not trivial. In addition to keeping
    # track of entities seen by their natural key, you'd also have to keep
    # track of instances by foreign key columns. Then, for a relationship
    # missing from 'data', you could try to find the entity using the
    # foreign key value from the 'data' dictionary.
    for name, (_, relationship) in _select(relationships).items():
        if name not in data:
            continue

        values = data[name]
        rel_class = cast(Type[BaseModel], relationship.mapper.class_)

        if relationship.uselist:
            # `to_dict` emits `None` for an unset relationship; treat it as
            # an empty collection.
            children = [rel_class._do_from_dict(v, seen) for v in values or ()]
            setattr(instance, name, _new_collection(relationship, children))
        else:
            child = None if values is None else rel_class._do_from_dict(values, seen)
            setattr(instance, name, child)

    return instance

@classmethod
def dto_module(cls) -> ModuleType:
    """Get the module where the DTO class is defined."""
    return importlib.import_module(cls.__module__)
@classmethod
def dto_module_name(cls) -> str:
    """Return the name of the module the DTO class belongs to, which is the
    module of the model class itself."""
    module_name = cls.__module__
    return f"{module_name}"
@classmethod
def dto_class_name(cls) -> str:
    """Return the DTO class name: the model class name followed by ``BaseDto``."""
    model_name = cls.__name__
    return f"{model_name}BaseDto"
@classmethod
def dto_import_path(cls) -> str:
    """Return the dotted import path of the DTO class, i.e. the DTO module
    name and the DTO class name joined by a dot."""
    module_name = cls.dto_module_name()
    class_name = cls.dto_class_name()
    return f"{module_name}.{class_name}"
@classmethod
def dto_exclude_attributes(cls) -> set[str]:
    """Collect the names of the fields that opted out of the DTO.

    A field is excluded when its ORM descriptor carries ``dto`` or ``DTO``
    set to ``False`` in its ``info`` dictionary.
    """
    descriptors = cls.__mapper__.all_orm_descriptors
    fields = cls.fields().items()
    opt_out_keys = ("dto", "DTO")

    excluded = set()

    for name, _field in fields:
        descriptor = descriptors.get(name)

        if not descriptor:
            continue

        if any(descriptor.info.get(key) is False for key in opt_out_keys):
            excluded.add(name)

    return excluded
@classmethod
@cache
def to_dto_class(cls) -> Type[pydantic.BaseModel]:
    """
    Create a Data Transfer Object (DTO) class that corresponds with this model.

    The DTO class is named after the model class with the suffix
    ``BaseDto`` (see ``dto_class_name``). So, if the model name is
    ``MyModel`` the DTO class will be named ``MyModelBaseDto``.

    The DTO class inherits from the ``pydantic.BaseModel`` class and
    contains fields for all regular columns, composites, hybrid properties,
    and relationships of this model.

    The result is memoized per model class by ``functools.cache``.

    .. note::
        Regardless of how the SqlAlchemy model class is configured, all
        primary and foreign key columns are treated as optional and
        included in the ``__init__`` method with a default value of
        ``None``. The same applies to database generated integer columns.

    .. note::
        To exclude an attribute from the DTO, set ``dto=False`` in the
        ``info`` dictionary of the attribute when defining the model.

    Returns
    -------
    Type[pydantic.BaseModel]
        The DTO type.
    """
    # In order to build the DTO, all models referenced in relationships
    # must also be available as DTOs. It is hard (maybe not possible) to do
    # this recursively. Instead, we use a queue of unprocessed models
    # (seeded with this one) and add the relationship classes to the queue
    # as we build. Once all referenced models have corresponding DTOs we
    # rebuild the top level DTO to resolve the forward references.

    # Pydantic appears to require each segment of a module path (i.e., each
    # package in the hierarchy) to be available in a namespace it searches.
    # Adding the module or package to the global namespace does not appear
    # to work. As a workaround, the values are collected in a custom
    # namespace that is passed to `model_rebuild`.

    # The packages and modules pydantic needs to initialize our DTO
    # classes, but that are not available globally.
    types_namespace = {}

    # The DTO of this class that we'll return.
    dto = None

    # The queue of unprocessed models.
    queue = deque([cls])

    while queue:
        model = queue.pop()
        types_namespace |= import_all_modules_in_path(model.dto_module_name())
        result = model._do_dto_conversion(children := set())
        # Only schedule related models that have no registered DTO yet.
        queue.extend([c for c in children if c not in cls._registered_dto_classes])

        # The first DTO built is the one for `cls`, which is the one we
        # want to return.
        dto = dto or result

    # After creating all the referenced DTOs we need to rebuild the
    # pydantic model to resolve the forward references.
    dto.model_rebuild(_types_namespace=types_namespace)

    # `cast` is a runtime no-op; the target must be the class type, not
    # `type(pydantic.BaseModel)` (which is pydantic's metaclass).
    return cast(Type[pydantic.BaseModel], dto)
def to_dto(self) -> pydantic.BaseModel:
    """
    Build the DTO instance that mirrors this model instance.

    Returns
    -------
    pydantic.BaseModel
        The equivalent DTO instance.

    Raises
    ------
    pydantic.ValidationError
        If any of the required fields are not present. **Note**: this
        includes database generated fields, such as ``ids``.
    """
    dto_class = self.to_dto_class()
    return dto_class.model_validate(self)
@classmethod
def from_dto(cls: Type[BaseModel], dto: pydantic.BaseModel) -> M:
    """Build a model instance from the dictionary dump of a DTO instance."""
    payload = dto.model_dump()
    return cls.from_dict(payload)
def walk_children(
    self,
    callback: Callable[[B], None],
    traverse_viewonly: bool = True,
):
    """
    Visit this instance and every instance reachable through its
    relationships, invoking a callback once per visited node.

    Nodes are taken from the end of the work list (last in, first out) and
    a node that was already visited is skipped, so cycles terminate.

    Parameters
    ----------
    callback : Callable[[DeclarativeBase], None]
        The function to call on each traversed node.
    traverse_viewonly: bool, default=True
        Whether to traverse viewonly relationships.
    """
    pending = deque([self])
    visited: set[BaseModel] = set()

    while pending:
        node = pending.pop()

        # Prevent cycles
        if node in visited:
            continue
        visited.add(node)

        callback(node)

        for rel in node.relationships().values():
            # The attribute is always read, even for relationships that end
            # up being skipped.
            target = getattr(node, rel.key)
            if target is None:
                continue
            if rel.viewonly and not traverse_viewonly:
                continue
            pending.extend(target if rel.uselist else [target])
def copy(self: M) -> M:
    """
    Return a deep copy of the instance that is not associated in any way
    with this instance.

    For example, the new instance is not added to a session when the
    original item is (which appears to happen if you use ``copy.deepcopy``
    or ``dataclasses.replace``).

    Returns
    -------
    BaseModel
        A copy of this instance.
    """
    # Copying via the dictionary form avoids carrying over the SqlAlchemy
    # metadata that `copy.deepcopy(self)` and `dataclasses.replace(self)`
    # would copy, which causes unexpected behavior.
    snapshot = copy.deepcopy(self.to_dict())
    return self.from_dict(snapshot)
@classmethod
@cache
def class_name(cls) -> str:
    """Return the qualified name of the class as computed by ``get_qualified_name``."""
    qualified_name = get_qualified_name(cls)
    return qualified_name
@classmethod
@cache
def fields(cls: Type[Self]) -> dict[str, dataclasses.Field]:
    """
    Map each dataclass field name of this class to its ``Field`` object.

    The dictionary is ordered alphabetically by field name.
    """
    by_name = {field.name: field for field in dataclasses.fields(cls)}
    return {name: by_name[name] for name in sorted(by_name)}
@classmethod
@cache
def primary_key_columns(cls) -> dict[str, orm.ColumnProperty]:
    """
    Map attribute names to their ``ColumnProperty`` for every regular column
    that has at least one primary key column.
    """
    column_attrs = cls.regular_columns()

    result = {}
    for name, field in cls.fields().items():
        attr = column_attrs.get(name)
        if attr is None:
            continue
        if any(column.primary_key for column in attr.columns):
            result[field.name] = attr

    return result
@classmethod
@cache
def natural_key_columns(cls) -> dict[str, orm.ColumnProperty]:
    """
    Map attribute names to their ``ColumnProperty`` for every regular column
    whose dataclass field has a truthy ``hash`` setting.
    """
    column_attrs = cls.regular_columns()

    result = {}
    for name, field in cls.fields().items():
        attr = column_attrs.get(name)
        if attr is not None and field.hash:
            result[field.name] = attr

    return result
# The methods below iterate over the dataclass fields (rather than the
# mapper collections directly) because the `mapper.attrs` collections map
# from **column** names, not attribute names. This is less efficient, but
# needed.
@classmethod
@cache
def regular_columns(cls: Type[Self]) -> dict[str, orm.ColumnProperty]:
    """
    Map attribute names to their ``ColumnProperty`` for all regular columns
    of this class.

    A regular column is an entry of ``Mapper.column_attrs`` that maps to
    exactly one column, which is an instance of ``sa.Column``.
    """
    model_fields = cls.fields().values()
    column_attrs = cls.__mapper__.column_attrs
    is_regular = cls._is_regular_column

    result = {}
    for field in model_fields:
        attr = column_attrs.get(field.name)
        if is_regular(attr):
            result[field.name] = attr

    return result
@classmethod
@cache
def _is_regular_column(cls, attr: orm.MapperProperty) -> bool:
    # True for a ColumnProperty backed by exactly one `sa.Column`.
    if isinstance(attr, orm.ColumnProperty):
        mapped = attr.columns
        if len(mapped) == 1:
            return all(isinstance(col, sa.Column) for col in mapped)
    return False


@classmethod
@cache
def column_properties(cls) -> dict[str, orm.ColumnProperty]:
    """
    Map attribute names to their ``ColumnProperty`` for all
    ``column properties`` of this class.

    A column property is a mapped attribute configured using
    ``column_property``.
    """
    field_names = cls.fields()
    column_attrs = cls.__mapper__.column_attrs
    is_column_property = cls._is_column_property

    result = {}
    for name in field_names:
        attr = column_attrs.get(name)
        if is_column_property(attr):
            result[name] = attr

    return result
@classmethod
@cache
def _is_column_property(cls, attr: orm.MapperProperty) -> bool:
    # True when the attribute is an instance of `orm.MappedSQLExpression`.
    expected_type = orm.MappedSQLExpression
    return isinstance(attr, expected_type)


@classmethod
@cache
def composite_columns(cls: Type[Self]) -> dict[str, orm.CompositeProperty]:
    """
    Map attribute names to their ``CompositeProperty`` for all ``composite``
    attributes of this class.

    A ``composite`` attribute is an attribute configured using
    ``sa.orm.composite``.
    """
    field_names = cls.fields()
    mapper_attrs = cls.__mapper__.attrs
    is_composite = cls._is_composite

    result = {}
    for name in field_names:
        attr = mapper_attrs.get(name)
        if is_composite(attr):
            result[name] = attr

    return result
@classmethod
@cache
def _is_composite(cls, attr: orm.MapperProperty) -> bool:
    # True when the attribute is an instance of `orm.CompositeProperty`.
    expected_type = orm.CompositeProperty
    return isinstance(attr, expected_type)


@classmethod
@cache
def hybrid_properties(cls: Type[Self]) -> dict[str, hybrid_property]:
    """
    Map attribute names to their ``hybrid_property`` instance.

    A ``hybrid_property`` attribute is a getter style method annotated with
    the ``@hybrid_property`` decorator.
    """
    descriptors = cls.__mapper__.all_orm_descriptors
    is_hybrid = cls._is_hybrid_property

    result = {}
    for name, descriptor in descriptors.items():
        if is_hybrid(descriptor):
            result[name] = descriptor

    return result
@classmethod
@cache
def _is_hybrid_property(cls, attr: orm.MapperProperty) -> bool:
    # Compare the descriptor's extension type with the one declared on
    # `hybrid_property`.
    hybrid_extension_type = hybrid_property.extension_type
    return attr.extension_type is hybrid_extension_type


@classmethod
@cache
def relationships(cls) -> dict[str, orm.Relationship]:
    """
    Map attribute names to their ``Relationship`` for all relationships
    defined on this class.
    """
    field_names = cls.fields()
    mapper_relationships = cls.__mapper__.relationships
    is_relationship = cls._is_relationship

    result = {}
    for name in field_names:
        attr = mapper_relationships.get(name)
        if is_relationship(attr):
            result[name] = attr

    return result
@classmethod
@cache
def _is_relationship(cls, attr: orm.MapperProperty) -> bool:
    # True when the attribute is an instance of `orm.Relationship`.
    expected_type = orm.Relationship
    return isinstance(attr, expected_type)


@classmethod
@cache
def table_columns(cls) -> list[sa.Column]:
    """Return the column objects of this model's table as a list."""
    table = cls.__table__
    return list(table.columns.values())
@classmethod
@cache
@ensure(lambda result: result is not None)
def table_insertion_order(cls) -> dict[sa.Table, int]:
    """
    The insertion priority for each table.

    A lower number has a higher priority and should be inserted before a
    table with a higher number. The priority of a table is its position in
    ``BaseModel.metadata.sorted_tables``.

    Returns
    -------
    dict[Table, int]
        A mapping from an SqlAlchemy table to its insertion priority
        (lower numbers indicate higher priority).
    """
    ordered_tables = BaseModel.metadata.sorted_tables

    priorities = {}
    for position, table in enumerate(ordered_tables):
        priorities[table] = position

    return priorities
@classmethod
def validate_mapper(cls):
    """
    Validate that this model is configured correctly and register an index
    over its natural key columns.

    Three steps run in order:

    1. Natural keys: a concrete model must have at least one natural key
       column and no natural key column may be nullable.
    2. DTO configuration: every attribute excluded from the DTO must be
       safe to omit (see the rules inside ``_validate_dto_configuration``).
    3. An index named ``nk_<tablename>`` over the natural key columns is
       added to the model's table.

    Raises
    ------
    ValueError
        If the natural key or the DTO exclusion rules are violated.
    """
    missing = dataclasses.MISSING

    def _is_abstract(model_class: Type[BaseModel]) -> bool:
        return model_class.__abstract__ and not model_class.__tablename__

    def _is_concrete(model_class: Type[BaseModel]) -> bool:
        return not _is_abstract(model_class)

    def _is_required(type_: Type[Any]) -> bool:
        return not is_optional(type_)

    def _has_any_default(field: dataclasses.Field) -> bool:
        # The sentinel is compared by identity: using `==` would invoke the
        # default value's own `__eq__`, which may not return a plain bool.
        default_missing = field.default is missing
        factory_missing = field.default_factory is missing

        return not (default_missing and factory_missing)

    def _has_value_default(field: dataclasses.Field) -> bool:
        default_value = field.default
        factory_value = field.default_factory

        default_is_value = default_value is not missing and default_value is not None
        factory_is_value = factory_value is not missing and factory_value is not None

        return default_is_value or factory_is_value

    def _validate_natural_keys(model_class: Type[BaseModel]):
        model_name = model_class.__name__
        attrs = model_class.natural_key_columns().values()

        if _is_concrete(model_class) and len(attrs) == 0:
            raise ValueError(f"'{model_name}' has no natural key columns.")

        cols = [c for a in attrs for c in a.columns]

        if any(c.nullable for c in cols):
            raise ValueError("Natural keys cannot be optional")

    def _validate_dto_configuration(model_class: Type[BaseModel]):
        dto_excludes = model_class.dto_exclude_attributes()
        fields = model_class.fields().values()

        # A list keeps the iteration order of `fields()`, so the first
        # reported violation is deterministic.
        excluded_fields = [f for f in fields if f.name in dto_excludes]

        for field in excluded_fields:
            # A field can be excluded if:
            # * It is optional, and default or default_factory is not
            #   missing, or init is False.
            # * It is not optional, and `default` or default_factory is
            #   not missing.

            if field.hash is True:
                raise ValueError("A DTO cannot exclude a natural key column.")

            if _is_required(field.type) and not _has_value_default(field):
                raise ValueError(
                    "A DTO cannot exclude a required field that does not provide "
                    "a non-None default or default_factory value."
                )

            if is_optional(field.type):
                if _has_any_default(field):
                    continue
                elif field.init is False:
                    continue
                else:
                    raise ValueError(
                        "A DTO cannot exclude an optional field unless it is "
                        "excluded from the 'init' method or has a default value or "
                        "default_factory (even if it is None)."
                    )

            # We should be good otherwise

    def _add_natural_key_index(model_class: Type[BaseModel]):
        index_name = f"nk_{model_class.__tablename__}"
        index = Index(index_name, *model_class.natural_key_columns().values())
        table = cast(Table, model_class.__table__)
        table.indexes.add(index)

    _validate_natural_keys(cls)
    _validate_dto_configuration(cls)
    _add_natural_key_index(cls)
# region Utility Methods
@classmethod
def _do_dto_conversion(
    cls: Type[Self],
    forward_declarations: set[Type[BaseModel]],
):
    """
    Convert the SqlAlchemy model to a DTO implemented as a ``pydantic.BaseModel``.

    The generated class is registered in ``BaseModel._registered_dto_classes``
    and attached as an attribute to the module returned by ``dto_module()``.

    Parameters
    ----------
    forward_declarations : set[Type[BaseModel]]
        An output parameter. The SqlAlchemy model classes referenced by the
        processed relationships are added to this set while the function
        runs. Their DTOs are referenced by import path only, so they may not
        be defined yet.

    Returns
    -------
    Type[pydantic.BaseModel]
        The DTO class.
    """

    # region _do_dto_conversion utility functions
    class Config:
        """Pydantic DTO configuration."""

        from_attributes = True

    def _is_auto_generated(attr: orm.MapperProperty) -> bool:
        # An attribute counts as auto generated when it has an integer
        # primary key column and none of its columns declares a default.
        # Attributes without a `columns` collection never qualify.
        columns = getattr(attr, "columns", [])
        is_int = any(c.type.python_type is int for c in columns)
        is_primary = any(c.primary_key for c in columns)

        try:
            has_default = any(c.default is not None for c in columns)
        except AttributeError:
            has_default = False

        return is_int and is_primary and not has_default

    def _is_foreign_key(attr: orm.MapperProperty) -> bool:
        # True if any column of the attribute has at least one foreign key.
        return any(bool(c.foreign_keys) for c in getattr(attr, "columns", []))

    def _to_dto_field(
        model_field: dataclasses.Field, model_attr: orm.MapperProperty
    ) -> pydantic.Field:
        """Create the pydantic field from the model field information."""
        field_kwargs = {}

        # An explicit default (or factory) on the model field wins. Without
        # one, auto generated and foreign key attributes default to None.
        if model_field.default is not dataclasses.MISSING:
            field_kwargs["default"] = model_field.default
        elif model_field.default_factory is not dataclasses.MISSING:
            field_kwargs["default_factory"] = model_field.default_factory
        elif _is_auto_generated(model_attr):
            field_kwargs["default"] = None
        elif _is_foreign_key(model_attr):
            field_kwargs["default"] = None

        if model_field.repr is not dataclasses.MISSING:
            field_kwargs["repr"] = model_field.repr

        return pydantic.Field(**field_kwargs)

    def _to_dto_type(
        model_field: dataclasses.Field, model_attr: orm.MapperProperty
    ) -> Type:
        """Create the pydantic type from the model field information."""
        # The following attribute categories should always be optional in
        # the DTO class:
        # * autogenerated keys
        # * foreign keys
        # * column_property attributes
        # * hybrid_property attributes (these are made optional in
        #   `_process_hybrid_properties`, not here)
        if _is_auto_generated(model_attr):
            return Optional[model_field.type]

        if _is_foreign_key(model_attr):
            return Optional[model_field.type]

        if cls._is_column_property(model_attr):
            return Optional[model_field.type]

        return model_field.type

    def _filter_excluded(
        attrs: dict[str, orm.MapperProperty | orm.Relationship],
    ) -> dict[str, orm.MapperProperty | orm.Relationship]:
        # Drop the attributes listed by `dto_exclude_attributes()`.
        excluded = cls.dto_exclude_attributes()

        return {n: a for n, a in attrs.items() if n not in excluded}

    def _process_direct_attributes(
        attrs: dict[str, orm.MapperProperty]
    ) -> dict[str, tuple[Type, pydantic.Field]]:
        # Pair each non-excluded attribute with its dataclass field and
        # derive the DTO (type, field) tuple from that pair.
        attrs = _filter_excluded(attrs)
        items = {n: (f, attrs[n]) for n, f in cls.fields().items() if n in attrs}

        return {
            name: (_to_dto_type(*args), _to_dto_field(*args))
            for name, args in items.items()
        }

    def _process_regular_columns() -> dict[str, tuple[Type, pydantic.Field]]:
        return _process_direct_attributes(cls.regular_columns())

    def _process_column_properties() -> dict[str, tuple[Type, pydantic.Field]]:
        return _process_direct_attributes(cls.column_properties())

    def _process_composites() -> dict[str, tuple[Type, pydantic.Field]]:
        return _process_direct_attributes(cls.composite_columns())

    def _process_hybrid_properties() -> dict[str, tuple[Type, pydantic.Field]]:
        # The DTO type of a hybrid property is the (optional) return
        # annotation of its getter; the DTO field defaults to None.
        attrs = _filter_excluded(cls.hybrid_properties())
        result = {}
        for name, attr in attrs.items():
            inner_type = get_type_hints(attr.fget).get("return", None)
            dto_type = Optional[inner_type]
            dto_field = pydantic.Field(init=True, default=None)

            result[name] = (dto_type, dto_field)

        return result

    def _is_parent_relationship(rel: orm.Relationship) -> bool:
        # A relationship is kept in the DTO only when none of its local
        # columns has a foreign key and it passes the association check.
        owns_fk = any(c.foreign_keys for c in rel.local_columns)
        is_primary_association = _is_primary_association(rel)

        return not owns_fk and is_primary_association

    def _is_primary_association(rel: orm.Relationship) -> bool:
        """True if the local column is the first entry in the association table."""
        secondary = rel.secondary

        # This is only applicable to MtM relationships with an
        # association table.
        if secondary is None:
            return True

        local_columns = rel.local_columns
        foreign_keys = secondary.columns[0].foreign_keys

        return all(fk.column in local_columns for fk in foreign_keys)

    def _process_relationships() -> dict[str, tuple[Type, pydantic.Field]]:
        fields = cls.fields()
        relationships = _filter_excluded(cls.relationships())
        pairs = {n: (fields[n], r) for n, r in relationships.items()}
        pairs = {n: p for n, p in pairs.items() if _is_parent_relationship(p[1])}

        result = {}
        for name, (model_field, relationship) in pairs.items():
            collection_class = relationship.collection_class
            dto_field = _to_dto_field(model_field, relationship)
            model_type = relationship.entity.class_

            # Add the relationship class to the set of additional classes
            # we may need to process later.
            forward_declarations.add(model_type)

            # We need to use a forward declaration of the relationship
            # DTO type, otherwise we'll run into cyclic recursion issues.
            dto_entity_type = model_type.dto_import_path()

            if relationship.uselist:
                dto_type = collection_class[dto_entity_type]  # noqa
            else:
                dto_type = dto_entity_type

            result[name] = (dto_type, dto_field)

        return result

    # region DTO methods
    # These methods will be attached to the DTO class
    def _dto_pretty_print(self, line_length=1):
        # Pretty print the DTO
        # See: https://github.com/pydantic/pydantic/discussions/7787#discussioncomment-9658140
        # `black` is imported lazily, only when pretty printing is requested.
        import black

        s = repr(self)

        return black.format_str(s, mode=black.FileMode(line_length=line_length))

    def _dto_natural_keys(self) -> list[str]:
        # Return the list of natural key attribute names
        return list(cls.natural_key_columns().keys())

    def _dto_hash(self) -> int:
        # Hash the (name, value) pairs of the natural key attributes.
        value = tuple((a, getattr(self, a, None)) for a in self.natural_keys())

        return self.hasher.hash(value)

    # endregion DTO methods

    def _make_dto_class(attributes: dict[str, tuple[Type, pydantic.Field]]) -> type:
        exclude = cls.dto_exclude_attributes()
        attributes = {k: v for k, v in attributes.items() if k not in exclude}
        dto_types = {k: v[0] for k, v in attributes.items()}
        dto_fields = {k: v[1] for k, v in attributes.items()}

        # Enable creating a DTO from a dict
        dto_types["Config"] = ClassVar[Type]

        # Add a hasher
        dto_types["hasher"] = ClassVar[ObjectHasher]

        # Build the class dynamically: annotations carry the types and the
        # class attributes carry the pydantic fields and helper methods.
        return type(
            cls.dto_class_name(),
            (pydantic.BaseModel,),
            {
                "__module__": cls.dto_module_name(),
                "__annotations__": dto_types,
                "__hash__": _dto_hash,
                "Config": Config,
                "hasher": cls.hasher,
                "natural_keys": _dto_natural_keys,
                "pretty_print": _dto_pretty_print,
                **dto_fields,
            },
        )

    # endregion _do_dto_conversion utility functions

    # Later groups overwrite earlier ones if they share an attribute name.
    dto_attributes = {}
    dto_attributes |= _process_regular_columns()
    dto_attributes |= _process_column_properties()
    dto_attributes |= _process_composites()
    dto_attributes |= _process_hybrid_properties()
    dto_attributes |= _process_relationships()

    dto_class = _make_dto_class(dto_attributes)
    dto_class = cast(DtoModel, dto_class)

    # Register the DTO class with the BaseModel
    BaseModel._registered_dto_classes[cls] = dto_class

    # Add the class to the appropriate module
    setattr(cls.dto_module(), cls.dto_class_name(), dto_class)

    return dto_class
# endregion Utility Methods @dataclass(frozen=True, kw_only=False)
class SqlBinaryExpression(YamlConfig):
    """A class that represents the basic binary expression for an SQL column."""

    column: str
    """
    The column name.
    """

    operator: SqlOperator
    """
    The operator to compare the ``column`` and ``value`` with.
    """

    value: Any
    """
    The value used as a comparison.
    """

    def __call__(self, model_or_table: Type[B] | sa.Table) -> sa.BinaryExpression:
        """Shorthand for :meth:`to_expression`."""
        return self.to_expression(model_or_table)

    @require(
        lambda self, model_or_table: (
            self.column in model_or_table.columns
            if isinstance(model_or_table, sa.Table)
            else hasattr(model_or_table, self.column)
        ),
        "Invalid column",
    )
    @require(
        lambda model_or_table: isinstance(model_or_table, sa.Table)
        or issubclass(model_or_table, orm.DeclarativeBase),
        (
            "The 'model_or_table' must either be an SqlAlchemy Table or an SqlAlchemy "
            "ORM model (subclass of DeclarativeBase)."
        ),
    )
    def to_expression(self, model_or_table: Type[B] | sa.Table) -> sa.BinaryExpression:
        """Return a clause that can be used with an SqlAlchemy ``where`` statement.

        Parameters
        ----------
        model_or_table : Type[DeclarativeBase] | sqlalchemy.Table
            The table, or ORM model class, that contains the column.

        Returns
        -------
        BinaryExpression
            The corresponding SqlAlchemy binary expression.

        Raises
        ------
        ValueError
            If ``operator`` is not one of the supported operators.
        """
        column = (
            model_or_table.c[self.column]
            if isinstance(model_or_table, sa.Table)
            else getattr(model_or_table, self.column)
        )
        operator = self.operator
        value = self.value

        # Comparisons against ``None`` are expressed with IS / IS NOT. These
        # checks must come before the generic ``==`` / ``!=`` handling,
        # otherwise they can never be reached.
        if value is None:
            if operator == "==":
                return column.is_(None)
            if operator == "!=":
                return column.isnot(None)

        if operator == "==":
            return column == value  # noqa
        if operator == "!=":
            return column != value  # noqa
        if operator == ">":
            return column > value  # noqa
        if operator == ">=":
            return column >= value  # noqa
        if operator == "<":
            return column < value  # noqa
        if operator == "<=":
            return column <= value  # noqa
        if operator == "like":
            return column.like(value)
        if operator == "in":
            return column.in_(value)
        if operator == "not_in":
            return column.not_in(value)
        if operator == "is_null":
            return column.is_(None)
        if operator == "is_not_null":
            return column.isnot(None)

        raise ValueError(f"Invalid operator: {operator}")
@dataclass(kw_only=False, frozen=True)
class SqlSelectionCriteria(YamlConfig):
    """A class that represents a conjunction of SqlBinaryExpression."""

    expressions: list[SqlBinaryExpression]
    """
    The list of binary expressions that will be used to filter the query.
    """

    @require(
        lambda model_or_table: isinstance(model_or_table, sa.Table)
        or issubclass(model_or_table, orm.DeclarativeBase),
        (
            "The 'model_or_table' must either be an SqlAlchemy Table or an SqlAlchemy "
            "ORM model (subclass of DeclarativeBase)."
        ),
    )
    def to_conjunction(
        self, model_or_table: Type[B] | sa.Table
    ) -> sa.ColumnElement[bool]:
        """
        Return a conjunction of binary expressions that can be used with an
        SqlAlchemy ``where`` statement.

        Parameters
        ----------
        model_or_table : Type[DeclarativeBase] | sqlalchemy.Table
            The table, or ORM model class, the expressions refer to. A model
            class is resolved to its ``__table__`` before the expressions are
            built.

        Returns
        -------
        ColumnElement[bool]
            The single expression when there is exactly one, the ``AND`` of
            all expressions when there are several, and a constant ``true``
            when ``expressions`` is empty.
        """
        table = (
            model_or_table
            if isinstance(model_or_table, sa.Table)
            else model_or_table.__table__
        )
        clauses = [e.to_expression(table) for e in self.expressions]

        # Calling `and_()` without arguments is deprecated in SqlAlchemy.
        # Seeding the conjunction with `true()` keeps an empty criteria list
        # valid; `true()` is dropped again when real clauses are present.
        return sa.and_(sa.true(), *clauses)
@dataclass(kw_only=False, frozen=True)
class SqlOrderExpression(YamlConfig):
    """A single ``ORDER BY`` term: a column name and a sort direction."""

    column: str
    """The name of the column to sort by."""

    ascending: bool = True
    """Whether to sort the column in ascending order."""

    def to_expression(self, model_or_table: Type[B] | sa.Table) -> sa.UnaryExpression:
        """Build the SqlAlchemy ``UnaryExpression`` for this order term."""
        # Tables expose columns through `.c`, ORM models as attributes.
        if isinstance(model_or_table, sa.Table):
            target = model_or_table.c[self.column]
        else:
            target = getattr(model_or_table, self.column)

        if self.ascending:
            return target.asc()
        return target.desc()
@dataclass(kw_only=False, frozen=True)
class SqlOrderCriteria(YamlConfig):
    """An ordered collection of ``SqlOrderExpression`` terms."""

    expressions: list[SqlOrderExpression]
    """The list of order expressions."""

    def to_criteria(
        self, model_or_table: Type[B] | sa.Table
    ) -> list[sa.UnaryExpression]:
        """
        Build the list of SqlAlchemy ``UnaryExpressions`` for these terms,
        suitable for an SqlAlchemy ``order_by`` statement.
        """
        criteria = []
        for order_expression in self.expressions:
            criteria.append(order_expression.to_expression(model_or_table))
        return criteria
@dataclass(kw_only=True)
class DatabaseConnectionConfig(abc.ABC):
    """The settings needed to connect to a database through SqlAlchemy."""

    database: str
    """The name of the database or the path to the database file if using sqlite3."""

    drivername: str = "sqlite"
    """The name of the driver used to connect to the database."""

    username: Optional[str] = None
    """The user name to use when connecting to the database, if needed."""

    password: Optional[str] = None
    """The password to use when connecting to the database, if needed."""

    host: Optional[str] = None
    """The database host, if applicable."""

    port: Optional[int] = None
    """The database port, if applicable."""

    use_async: bool = False
    """Whether to use async mode."""

    @property
    def url(self) -> sa.URL:
        """Build the :class:`sqlalchemy.URL` described by these settings."""
        url_parts = {
            "drivername": self.drivername,
            "username": self.username,
            "password": self.password,
            "host": self.host,
            "port": self.port,
            "database": self.database,
        }
        return sa.URL.create(**url_parts)
class Repository:
    """Data access layer exposing the basic CRUD operations for one model class."""

    def __init__(self, engine: AlchemyEngine, model_class: Type[M]):
        # The model class the queries of this repository select from.
        self.model_class = model_class
        # The engine used to open a session when the caller passes none.
        self.engine = engine
def insert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    session: Optional[Session] = None,
):
    """
    Add one or more model instances to a session.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        The data to insert. A single instance is wrapped by
        ``_normalize_data`` first.
    session : Session, optional
        A session to use. When omitted, one is obtained through
        ``_get_sync_session``.

    Raises
    ------
    Exception
        If there is a problem adding the data to the database.
    """
    rows = self._normalize_data(instances)

    with self._get_sync_session(session) as active_session:
        active_session.add_all(rows)
async def async_insert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    session: Optional[AsyncSession] = None,
):
    """
    Asynchronously add one or more model instances to a session.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        The data to insert. A single instance is wrapped by
        ``_normalize_data`` first.
    session : AsyncSession, optional
        A session to use. When omitted, one is obtained through
        ``_get_async_session``.

    Raises
    ------
    Exception
        If there is a problem adding the data to the database.
    """
    rows = self._normalize_data(instances)

    async with self._get_async_session(session) as active_session:
        active_session.add_all(rows)
def upsert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    on_conflict: ConflictResolutionStrategy = "do_nothing",
    session: Optional[Session] = None,
):
    """
    Upsert one or more model instances into the database.

    One statement is built per model class (see ``_make_upsert_statements``)
    and each statement is executed inside its own session context.

    .. note::
        This only works with dialects that support INSERT ... ON CONFLICT.
        This should include ``postgresql``, ``mysql/mariadb``, and
        ``sqlite``. However, currently only ``postgresql`` is supported.

    .. note::
        This method creates several copies of the data, so you should be
        careful about memory management if you are inserting a large number
        of objects.

    .. warning::
        This only works if all non-null fields are provided **including
        primary and foreign keys.**

    .. warning::
        This is not well tested with complex model configurations such as
        hybrid properties and column properties.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        The data to upsert.
    on_conflict : ConflictResolutionStrategy, default="do_nothing"
        What to do when a row conflicts with an existing one.
    session : Session, optional
        A session to use. When omitted, one is obtained through
        ``_get_sync_session`` for every statement.
    """
    for statement in self._make_upsert_statements(instances, on_conflict):
        with self._get_sync_session(session) as active_session:
            active_session.execute(statement)
async def async_upsert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    on_conflict: ConflictResolutionStrategy = "do_nothing",
    session: Optional[AsyncSession] = None,
):
    """
    Asynchronously upsert one or more model instances into the database.

    One statement is built per model class (see ``_make_upsert_statements``)
    and each statement is executed inside its own session context.

    .. note::
        This only works with dialects that support INSERT ... ON CONFLICT.
        This should include ``postgresql``, ``mysql/mariadb``, and
        ``sqlite``. However, currently only ``postgresql`` is supported.

    .. note::
        This method creates several copies of the data, so you should be
        careful about memory management if you are inserting a large number
        of objects.

    .. warning::
        This only works if all non-null fields are provided **including
        primary and foreign keys.**

    .. warning::
        This is not well tested with complex model configurations such as
        hybrid properties and column properties.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        The data to upsert.
    on_conflict : ConflictResolutionStrategy, default="do_nothing"
        What to do when a row conflicts with an existing one.
    session : AsyncSession, optional
        A session to use. When omitted, one is obtained through
        ``_get_async_session`` for every statement.
    """
    statements = self._make_upsert_statements(instances, on_conflict)

    for stmt in statements:
        # `_get_async_session` is an *async* context manager, so it must be
        # entered with `async with`; a plain `with` fails at runtime.
        async with self._get_async_session(session) as local_session:
            await local_session.execute(stmt)
@require(
    lambda filter_by: all(isinstance(f, SqlBinaryExpression) for f in filter_by)
)
@require(lambda order_by: all(isinstance(o, SqlOrderExpression) for o in order_by))
@require(lambda limit: limit is None or isinstance(limit, int))
def find_all(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    order_by: Iterable[SqlOrderExpression] = (),
    limit: Optional[int] = None,
    session: Optional[Session] = None,
) -> list[DtoModel]:
    """
    Select the models matching the given filters, ordering and limit, and
    return them converted to their DTO representation.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        The conditions the selected rows must satisfy. Default is an empty
        iterable.
    order_by : Iterable[SqlOrderExpression], optional
        The sort order of the result. Default is an empty iterable.
    limit : int, optional
        Maximum number of records to retrieve. If None, no limit is
        applied. Default is None.
    session : Session, optional
        A database session to use. If not provided, one is obtained through
        ``_get_sync_session``.

    Returns
    -------
    list[DtoModel]
        The DTO instances for the retrieved models.
    """
    statement = self._select_stmt(self.model_class, filter_by, order_by, limit)

    with self._get_sync_session(session) as active_session:
        found = active_session.scalars(statement).all()
        # The DTOs are built while the session is still open.
        converted = []
        for model in found:
            converted.append(model.to_dto())

    return converted
async def async_find_all(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    order_by: Iterable[SqlOrderExpression] = (),
    limit: Optional[int] = None,
    session: Optional[AsyncSession] = None,
) -> list[DtoModel]:
    """
    Asynchronously select the models matching the given filters, ordering
    and limit, and return them converted to their DTO representation.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        The conditions the selected rows must satisfy. Default is an empty
        iterable.
    order_by : Iterable[SqlOrderExpression], optional
        The sort order of the result. Default is an empty iterable.
    limit : int, optional
        Maximum number of records to retrieve. If None, no limit is
        applied. Default is None.
    session : AsyncSession, optional
        An asynchronous database session to use (its ``scalars`` method is
        awaited). If not provided, one is obtained through
        ``_get_async_session``.

    Returns
    -------
    list[DtoModel]
        The DTO instances for the retrieved models.
    """
    stmt = self._select_stmt(self.model_class, filter_by, order_by, limit)

    async with self._get_async_session(session) as local_session:
        models = (await local_session.scalars(stmt)).all()
        # The DTOs are built while the session is still open.
        dtos = [m.to_dto() for m in models]

    return dtos
def find_one(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    session: Optional[Session] = None,
    raise_on_none: bool = True,
) -> DtoModel:
    """
    Select the single record matching the filter criteria and return it as
    a data transfer object (DTO).

    With ``raise_on_none`` set, the result's ``one()`` is used, otherwise
    ``one_or_none()``; either raises when more than one row matches.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        Filter conditions to apply when querying the database. Defaults to
        an empty tuple.
    session : Optional[Session], optional
        SQLAlchemy session to use for querying. If not provided, one is
        obtained through ``_get_sync_session``.
    raise_on_none : bool, optional
        Whether to raise an exception if no record matches the filter
        criteria. Defaults to True.

    Returns
    -------
    DtoModel
        The matching record in DTO form, or None when nothing matched and
        ``raise_on_none`` is False.
    """
    # No ordering is needed: more than one matching row is an error anyway.
    statement = self._select_stmt(self.model_class, filter_by, ())

    with self._get_sync_session(session) as active_session:
        rows = active_session.scalars(statement)

        if raise_on_none:
            found = rows.one()
        else:
            found = rows.one_or_none()

        if found is None:
            return None

        return found.to_dto()
async def async_find_one(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    session: Optional[AsyncSession] = None,
    raise_on_none: bool = True,
) -> DtoModel:
    """
    Asynchronously select the single record matching the filter criteria
    and return it as a data transfer object (DTO).

    With ``raise_on_none`` set, the result's ``one()`` is used, otherwise
    ``one_or_none()``; either raises when more than one row matches.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        Filter conditions to apply when querying the database. Defaults to
        an empty tuple.
    session : Optional[AsyncSession], optional
        An asynchronous SQLAlchemy session to use (its ``scalars`` method is
        awaited). If not provided, one is obtained through
        ``_get_async_session``.
    raise_on_none : bool, optional
        Whether to raise an exception if no record matches the filter
        criteria. Defaults to True.

    Returns
    -------
    DtoModel
        The matching record in DTO form, or None when nothing matched and
        ``raise_on_none`` is False.
    """
    # No ordering is needed: more than one matching row is an error anyway.
    stmt = self._select_stmt(self.model_class, filter_by, ())

    async with self._get_async_session(session) as local_session:
        result = await local_session.scalars(stmt)
        model = result.one() if raise_on_none else result.one_or_none()
        # The DTO is built while the session is still open.
        dto = model.to_dto() if model is not None else None

    return dto
@classmethod
def _select_stmt(
    cls,
    model_class: Type[M],
    filter_by: Iterable[SqlBinaryExpression],
    order_by: Iterable[SqlOrderExpression],
    limit: Optional[int] = None,
):
    """
    Build a ``SELECT`` statement for ``model_class``.

    Parameters
    ----------
    model_class : Type[M]
        The model to select from.
    filter_by : Iterable[SqlBinaryExpression]
        Filter expressions; they are combined through
        ``SqlSelectionCriteria.to_conjunction``.
    order_by : Iterable[SqlOrderExpression]
        Ordering expressions; they are expanded through
        ``SqlOrderCriteria.to_criteria``.
    limit : Optional[int], optional
        Value handed unchanged to ``Select.limit``.

    Returns
    -------
    The constructed select statement.
    """
    filter_criteria = SqlSelectionCriteria(list(filter_by))
    order_criteria = SqlOrderCriteria(list(order_by))

    return (
        sa.select(model_class)
        .where(filter_criteria.to_conjunction(model_class))
        .order_by(*order_criteria.to_criteria(model_class))
        .limit(limit)
    )


@classmethod
def _normalize_data(
    cls, data: BaseModel | Iterable[BaseModel]
) -> Iterable[BaseModel]:
    """Return ``data`` as an iterable, wrapping a single model in a list."""
    if not is_iterable(data):
        data = [data]

    return data


@contextmanager
def _get_sync_session(self, session: Optional[Session]) -> Session:
    """
    Yield the session to use for a synchronous operation.

    A caller-supplied ``session`` is yielded as is. Otherwise a session is
    obtained from ``self.engine.session_scope()`` and yielded inside that
    scope.
    """
    if session is None:
        with self.engine.session_scope() as engine_session:
            yield engine_session
    else:
        yield session


@asynccontextmanager
async def _get_async_session(self, session: Optional[AsyncSession]) -> AsyncSession:
    """
    Yield the session to use for an asynchronous operation.

    A caller-supplied ``session`` is yielded as is. Otherwise a session is
    obtained from ``self.engine.async_session_scope()`` and yielded inside
    that scope.
    """
    if session is None:
        async with self.engine.async_session_scope() as engine_session:
            yield engine_session
    else:
        yield session


# region Upsert Utilities
@classmethod
def _group_as_dict(
    cls, data: Iterable[BaseModel]
) -> dict[Type[B], list[dict[str, Any]]]:
    """
    Group the column values of every visited model by model class.

    Each item in ``data`` is walked with ``walk_children`` (view-only
    relationships are not traversed) and every visited node contributes one
    ``{column key: value}`` dictionary to the list of its class.

    Returns
    -------
    dict[Type[B], list[dict[str, Any]]]
        A mapping from model class to the rows collected for it.
    """
    result: dict[Type[B], list[dict[str, Any]]] = defaultdict(list)

    def add_to_result(n: BaseModel):
        """Convert the node to a dictionary and add it to the result."""
        result[n.__class__].append(
            {c.key: getattr(n, c.key, None) for c in n.table_columns()}
        )

    for item in data:
        item.walk_children(add_to_result, traverse_viewonly=False)

    return result


@classmethod
def _filter_duplicate_values(
    cls, data: dict[Type[B], list[dict[str, Any]]]
) -> dict[Type[B], list[dict[str, Any]]]:
    """
    Remove duplicate rows from each model's list of rows.

    Two rows are duplicates when they hold the same key/value pairs. The
    first occurrence of each row is kept and the relative order of the
    remaining rows is preserved.

    Parameters
    ----------
    data : dict[Type[B], list[dict[str, Any]]]
        Rows grouped by model class.

    Returns
    -------
    dict[Type[B], list[dict[str, Any]]]
        A new mapping with the same keys and de-duplicated row lists.

    Raises
    ------
    TypeError
        If a row contains an unhashable value (for example a ``list``).
    """
    deduplicated: dict[Type[B], list[dict[str, Any]]] = {}

    for model, rows in data.items():
        seen: set[frozenset] = set()
        unique_rows: list[dict[str, Any]] = []

        for row in rows:
            # A frozenset of the items is hashable (a plain ``set`` is not)
            # and ignores the order in which the keys were inserted.
            fingerprint = frozenset(row.items())
            if fingerprint not in seen:
                seen.add(fingerprint)
                unique_rows.append(row)

        deduplicated[model] = unique_rows

    return deduplicated


@classmethod
def _sort_by_insertion_order(
    cls, data: dict[Type[B], list[dict[str, Any]]]
) -> dict[Type[B], list[dict[str, Any]]]:
    """
    Return ``data`` re-keyed in table insertion order.

    The position of each model is looked up by its ``__table__`` in
    ``BaseModel.table_insertion_order()``.
    """
    order_lookup = BaseModel.table_insertion_order()

    return dict(sorted(data.items(), key=lambda e: order_lookup[e[0].__table__]))


def _make_upsert_statement(
    self,
    model_class: Type[B],
    model_values: list[dict[str, Any]],
    on_conflict: Optional[ConflictResolutionStrategy],
) -> Insert:
    """
    Build the insert (or upsert) statement for one model class.

    Parameters
    ----------
    model_class : Type[B]
        The model the rows belong to.
    model_values : list[dict[str, Any]]
        The rows to insert.
    on_conflict : Optional[ConflictResolutionStrategy]
        ``"do_nothing"`` skips rows that conflict, ``"do_update"``
        overwrites the non primary key columns of conflicting rows. Any
        other value produces a plain insert.

    Returns
    -------
    Insert
        The statement to execute. For dialects other than PostgreSQL,
        MySQL and SQLite a warning is logged and a regular insert is
        returned, whatever ``on_conflict`` is.
    """
    dialect = self.engine.dialect

    if dialect == "postgresql":
        insert = postgres_insert
    elif dialect == "mysql":
        insert = mysql_insert
    elif dialect == "sqlite":
        insert = sqlite_insert
    else:
        log.warning(
            f"The dialect '{self.engine.dialect}' does not support upserts. "
            f"Performing a regular insert instead."
        )
        # The generic ``Insert`` has no conflict handling methods, so the
        # conflict strategy cannot be applied to it.
        return default_insert(model_class).values(model_values)

    is_mysql = dialect == "mysql"
    pk_columns = model_class.primary_key_column_names()

    # The base insert
    stmt = insert(model_class).values(model_values)

    if on_conflict == "do_update":
        # MySQL exposes the incoming row as ``inserted``; PostgreSQL and
        # SQLite expose it as ``excluded``.
        incoming = stmt.inserted if is_mysql else stmt.excluded  # noqa SqlAlchemy
        updates = {c.name: c for c in incoming if c.name not in pk_columns}

        if updates:
            if is_mysql:
                return stmt.on_duplicate_key_update(updates)

            return stmt.on_conflict_do_update(
                index_elements=pk_columns, set_=updates
            )

        # Nothing but primary key columns: there is nothing to update and
        # SQLAlchemy rejects an empty update mapping, so skip the row.
        on_conflict = "do_nothing"

    if on_conflict == "do_nothing":
        if is_mysql:
            # MySQL has no ``ON CONFLICT DO NOTHING``; ``INSERT IGNORE`` is
            # the closest equivalent. Be aware that it also downgrades
            # other row level errors to warnings.
            return stmt.prefix_with("IGNORE", dialect="mysql")

        return stmt.on_conflict_do_nothing(index_elements=pk_columns)

    return stmt


def _make_upsert_statements(
    self,
    data: BaseModel | Iterable[BaseModel],
    on_conflict: ConflictResolutionStrategy,
) -> list[Insert]:
    """
    Build one insert statement per model class found in ``data``.

    The models are normalized to an iterable, flattened into rows grouped
    by class, de-duplicated and ordered by table insertion order before
    the statements are created.

    Returns
    -------
    list[Insert]
        The statements, in the order they should be executed.
    """
    data = self._normalize_data(data)
    grouped_as_dict = self._group_as_dict(data)
    grouped_as_dict = self._filter_duplicate_values(grouped_as_dict)
    grouped_as_dict = self._sort_by_insertion_order(grouped_as_dict)

    return [
        self._make_upsert_statement(m, v, on_conflict)
        for m, v in grouped_as_dict.items()
    ]