"""Utilities for working with SqlAlchemy."""
# Future Library
from __future__ import annotations

# Standard Library
import abc
import copy
import dataclasses
import importlib
import logging
from collections import defaultdict, deque
from contextlib import asynccontextmanager, contextmanager
from functools import cache
from types import ModuleType
from typing import (
    Any,
    AsyncIterator,
    Callable,
    ClassVar,
    Iterable,
    Iterator,
    Literal,
    Optional,
    Self,
    Type,
    TypeVar,
    cast,
    dataclass_transform,
    get_type_hints,
)

# 3rd Party Library
import pydantic
import sqlalchemy as sa
from icontract import ensure, require
from pydantic.dataclasses import dataclass
from sqlalchemy import (
    Engine,
    Index,
    Insert,
    MetaData,
    Table,
    create_engine,
    event,
    orm,
)
from sqlalchemy import insert as default_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as postgres_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import DeclarativeBase, Session

# Project Library
from rwskit.collections_ import is_iterable, remove_none_from_dict
from rwskit.config import YamlConfig
from rwskit.hash import ObjectHasher
from rwskit.strings_ import camel_to_snake_case
from rwskit.types_ import get_qualified_name, import_all_modules_in_path, is_optional
log = logging.getLogger(__name__)

B = TypeVar("B", bound=DeclarativeBase)
"""A type extending :class:`~sqlalchemy.orm.DeclarativeBase`."""

#: A type variable bound to :class:`BaseModel` (or one of its subclasses).
M = TypeVar("M", bound="BaseModel")

TableArgs = dict[str, Any] | tuple[Any, ...]
"""The type of a model's ``__table_args__`` (extra arguments for :class:`~sqlalchemy.Table`)."""

DtoModel = Type[pydantic.BaseModel]
"""The base type for a DTO model."""

SqlOperator = Literal[
    "==",
    "!=",
    ">",
    ">=",
    "<",
    "<=",
    "like",
    "in",
    "not_in",
    "is_null",
    "is_not_null",
]
"""The supported SQL operators for use in an :class:`SqlBinaryExpression`."""

#: The names of the supported conflict resolution strategies.
ConflictResolutionStrategy = Literal["do_nothing", "update"]

#: Either a declarative model or an ORM mapper.
Walkable = orm.DeclarativeBase | orm.Mapper

#: A callable that takes two declarative models and returns a ``bool``.
WalkPredicate = Callable[[orm.DeclarativeBase, orm.DeclarativeBase], bool]
class AlchemyEngine:
    """A wrapper around :class:`sqlalchemy.engine.Engine` that provides additional functionality.

    The wrapper accepts an existing synchronous or asynchronous engine, or a
    ``DatabaseConnectionConfig`` from which an engine is created. It exposes
    session factories and transactional context managers for the
    synchronous engine and, when the wrapped engine is asynchronous, for the
    asynchronous engine as well.
    """  # noqa

    @require(
        lambda engine_or_config, base_model: isinstance(
            engine_or_config, (AsyncEngine, Engine, DatabaseConnectionConfig)
        )
        and issubclass(base_model, orm.DeclarativeBase),
    )
    def __init__(
        self,
        engine_or_config: AsyncEngine | Engine | DatabaseConnectionConfig,
        base_model: Type[orm.DeclarativeBase],
        **engine_kwargs: Any,
    ):
        """
        Parameters
        ----------
        engine_or_config : AsyncEngine | Engine | DatabaseConnectionConfig
            An existing engine to wrap, or a configuration used to create one.
        base_model : Type[orm.DeclarativeBase]
            The declarative base class of the models used with this engine.
        **engine_kwargs : Any
            Keyword arguments forwarded to ``create_engine`` /
            ``create_async_engine`` when a configuration is given. They are
            ignored when an engine instance is given.
        """
        self.base_model = base_model
        self._engine = self._configure_engine(engine_or_config, **engine_kwargs)
        # '_is_async' must be set before the session factories are built,
        # because both of them consult 'supports_async'.
        self._is_async = isinstance(self._engine, AsyncEngine)
        self._async_session_factory = self._configure_async_session_factory()
        self._sync_session_factory = self._configure_sync_session_factory()
        self._dialect = self._engine.dialect.name

    @property
    def dialect(self) -> str:
        """Returns the dialect name of the engine."""
        return self._dialect

    @property
    def supports_async(self) -> bool:
        """Returns true if the engine supports async mode."""
        return self._is_async

    @property
    def sync_engine(self) -> Engine:
        """Get a reference to a synchronous :class:`sqlalchemy.engine.Engine`."""
        return self._engine.sync_engine if self.supports_async else self._engine

    @property
    def async_engine(self) -> AsyncEngine:
        """
        Get a reference to an asynchronous :class:`sqlalchemy.ext.asyncio.AsyncEngine`.

        Returns
        -------
        AsyncEngine
            The engine.

        Raises
        ------
        ValueError
            If the engine does not support async mode.
        """
        if self.supports_async:
            return self._engine
        raise ValueError("Async engine is not supported.")

    def make_session(self) -> Session:
        """
        Create a synchronous session from the engine's session factory.

        .. note::
            You probably want to use :meth:`AlchemyEngine.session_scope`,
            but there may be cases where this is more appropriate. The
            caller is responsible for committing and closing the session.

        Returns
        -------
        Session
            A :class:`~sqlalchemy.orm.Session` object.
        """
        return self._sync_session_factory()

    def make_connection(self) -> sa.Connection:
        """
        Get a connection to the (synchronous) engine.

        Returns
        -------
        Connection
            An :class:`sqlalchemy.Connection`.
        """
        return self.sync_engine.connect()

    def make_raw_connection(self) -> sa.PoolProxiedConnection:
        """
        Get a raw connection to the (synchronous) engine.

        .. note::
            See `Working with Driver SQL and Raw DBAPI Connections <https://docs.sqlalchemy.org/en/20/core/connections.html#working-with-driver-sql-and-raw-dbapi-connections>`_
            in the SqlAlchemy documentation for the difference between regular
            and raw connections.

        Returns
        -------
        PoolProxiedConnection
            A :class:`sqlalchemy.PoolProxiedConnection`.
        """
        return self.sync_engine.raw_connection()

    @contextmanager
    def session_scope(self, expire_on_commit: bool = False) -> Iterator[Session]:
        """
        A context manager for committing successful transactions when the
        session is complete or rolling back if there was an exception.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects are marked as stale after each
            commit. Their attributes (including relationships) are then
            reloaded on next access, which raises an exception if the
            session is no longer available.

        Yields
        ------
        Session
            An :class:`sqlalchemy.orm.Session`.
        """
        session = self._sync_session_factory(expire_on_commit=expire_on_commit)
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
        finally:
            session.close()

    @asynccontextmanager
    async def async_session_scope(
        self, expire_on_commit: bool = False
    ) -> AsyncIterator[AsyncSession]:
        """
        An async context manager for committing successful transactions when
        the session is complete or rolling back if there was an exception.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects are marked as stale after each
            commit. Their attributes (including relationships) are then
            reloaded on next access, which raises an exception if the
            session is no longer available.

        Yields
        ------
        AsyncSession
            An :class:`sqlalchemy.ext.asyncio.AsyncSession`.

        Raises
        ------
        ValueError
            If the engine does not support async mode.
        """
        session = self._make_async_session(expire_on_commit)
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()

    @contextmanager
    def test_scope(self, expire_on_commit: bool = False) -> Iterator[Session]:
        """
        A session scope for testing.

        This session will always roll back after exiting the context manager
        and should not persist any changes to the database.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects are marked as stale after each
            commit. Their attributes (including relationships) are then
            reloaded on next access, which raises an exception if the
            session is no longer available.

        Yields
        ------
        Session
            An :class:`sqlalchemy.orm.Session`.

        Raises
        ------
        RuntimeError
            If the user tries to commit changes during the session.
        """

        def raise_on_commit():
            """Raise an exception if the session is committed."""
            raise RuntimeError(
                "Session.commit() is not allowed inside a test scope session."
            )

        session = self._sync_session_factory(expire_on_commit=expire_on_commit)
        session.commit = raise_on_commit
        try:
            yield session
        finally:
            session.rollback()
            session.close()

    @asynccontextmanager
    async def async_test_scope(
        self, expire_on_commit: bool = False
    ) -> AsyncIterator[AsyncSession]:
        """
        An async session scope for testing.

        This session will always roll back after exiting the context manager
        and should not persist any changes to the database.

        Parameters
        ----------
        expire_on_commit : bool, default=False
            If ``True`` model objects are marked as stale after each
            commit. Their attributes (including relationships) are then
            reloaded on next access, which raises an exception if the
            session is no longer available.

        Yields
        ------
        AsyncSession
            An :class:`sqlalchemy.ext.asyncio.AsyncSession`.

        Raises
        ------
        RuntimeError
            If the user tries to commit changes during the session.
        ValueError
            If the engine does not support async mode.
        """

        def raise_on_commit():
            """Raise an exception if the session is committed."""
            raise RuntimeError(
                "Session.commit() is not allowed inside a test scope session."
            )

        session = self._make_async_session(expire_on_commit)
        # ``await session.commit()`` calls this synchronously, so the
        # exception is raised before anything is awaited.
        session.commit = raise_on_commit
        try:
            yield session
        finally:
            await session.rollback()
            await session.close()

    @classmethod
    def _configure_engine(
        cls,
        engine_or_config: Engine | AsyncEngine | DatabaseConnectionConfig,
        **engine_kwargs: Any,
    ) -> AsyncEngine | Engine:
        # An existing engine is used as-is; 'engine_kwargs' only apply when
        # the engine is created from a configuration.
        if isinstance(engine_or_config, (Engine, AsyncEngine)):
            return engine_or_config
        create = create_async_engine if engine_or_config.use_async else create_engine
        return create(engine_or_config.url, **engine_kwargs)

    def _configure_sync_session_factory(self) -> orm.sessionmaker:
        return orm.sessionmaker(bind=self.sync_engine)

    def _configure_async_session_factory(self) -> Optional[async_sessionmaker]:
        if self.supports_async:
            return async_sessionmaker(bind=self.async_engine)
        return None

    def _make_async_session(self, expire_on_commit: bool) -> AsyncSession:
        """Create an async session, raising ``ValueError`` for a synchronous engine."""
        if self._async_session_factory is None:
            raise ValueError("Async engine is not supported.")
        # 'expire_on_commit' must be given to the factory: assigning it on an
        # AsyncSession is not forwarded to the underlying sync Session.
        return self._async_session_factory(expire_on_commit=expire_on_commit)
# Note, you can't use 'once' here, because it will literally only run the
# listener once, not once per mapper, which is required.
@event.listens_for(orm.Mapper, "mapper_configured")
@dataclass_transform(kw_only_default=True, eq_default=True, unsafe_hash_default=True)
[docs]
class BaseModel(
orm.DeclarativeBase,
orm.MappedAsDataclass,
kw_only=True,
eq=False,
unsafe_hash=False,
):
"""A base class for creating declarative SqlAlchemy models.
Features
--------
Table Lookup
~~~~~~~~~~~~
Find any model derived from this base class by their table name.
Merging Table Args
~~~~~~~~~~~~~~~~~~
SqlAlchemy does not merge ``__table_args__`` during inheritance. For
example, if you have a base class that will set the schema for all
child classes, it will not work if the child class defines its own
``__table_args__`` (e.g., to create a multi-column index). This base
class provides a function to merge the ``__table_args__`` of parents
with their children.
This functionality is enabled by defining ``__table_args__`` as a
``@declared_attr.directive`` on the class and returning the value of
:meth:`merge_table_args`. ``merge_table_args`` accepts one optional
parameter, which can be a tuple or dictionary (the expected types of
``__table_args__``) and will merge these with the table args of its
ancestors. In addition to, or alternatively, the ``merge_table_args``
method will also look for table arguments in a class attribute named
``__custom_table_args__``.
>>> class Parent(BaseModel):
>>> @orm.declared_attr.directive
>>> @classmethod
>>> def __table_args__(cls):
>>> return {"schema": "my_schema"}
>>>
>>> class Child(Parent):
>>> __custom_table_args__ = {"schema": "my_schema"}
>>>
>>> @orm.declared_attr.directive
>>> @classmethod
>>> def __table_args__(cls):
>>> child_table_args = (
>>> Index("child_index_name", "column_1", "column_2"),
>>> )
>>> return cls.merge_table_args(child_table_args)
>>>
>>> column_1: orm.Mapped[int]
>>> column_2: orm.Mapped[int]
Natural Keys
~~~~~~~~~~~~
Classes derived from this model must define a natural key.
Natural keys are specified by explicitly setting ``hash=True`` on a
``mapped_column``. A natural key is intended to
identify a set of attributes that uniquely identify a row in the table.
The key will be used to define ``__hash__`` and ``__eq__`` for the class.
Additionally, a non-unique index will be created on the natural key columns.
Serializable to a Dictionary
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
There is currently a bug when converting models with ``dataclasses.asdict``
if the model inherits from ``MappedAsDataclass`` and contains a
relationship with ``back_populates`` defined.
See: https://github.com/sqlalchemy/sqlalchemy/issues/9785
This class provides methods for converting the model to and from
dictionaries. This has been tested for several common use cases, but
may not be robust for more complex models.
Data Transfer Objects
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Any class derived from this base class can build a corresponding DTO
class type using the :meth:`to_dto_class` method. The created DTO class
derives from ``pydantic.BaseModel``, which is convenient for offline use
and data transfers, for example with FastAPI.
In addition to mirroring the columns, composites, hybrid_properties, and
relationships, the DTO object also provides a ``pretty_print`` method to
format the string representation of the object. It takes one optional
parameter ``line_length``.
A DTO instance can be constructed from a :class:`BaseModel` instance
using either the ``model_validate`` classmethod on the DTO or
the :meth:`to_dto` method of the ``BaseModel`` instance.
To exclude an attribute (column, composite, hybrid, relationship, etc.)
add ``dto=False`` to the ``info`` dictionary of the attribute.
Known Limitations
-----------------
When converting to DTOs the following limitations apply:
* A composite column must be a dataclass.
* Only regular columns (e.g., no composite columns) can be used as natural
keys.
* Only the parent side of a relationship is added to the DTO. Namely, a
reference to the child and children will be included in the parent, but
a reference to the parent will not be included in the child or children.
The parent is determined by the presence of foreign keys. For OtO and OtM
relationships the parent is the model that does not contain a foreign
key. For MtM relationships the parent is determined by looking at the
first column of the association table. If all the foreign keys of
that column are in the relationship's local columns then that model
is considered the parent. This may cause problems if you directly select
the child objects, because their parents will not be loaded into the
DTO object.
* In general you cannot convert a model to a dictionary or DTO and then
back to the exact same model. The :meth:`from_dict` method does not
handle cyclic relationships and will typically not be able to associate
a parent instance from the child.
"""
# Note: SqlAlchemy collections such as '__table__.columns' behave like
# mappings, but iterating over them yields their **values**, not their
# keys. Keep this in mind when iterating over them directly.
hasher = ObjectHasher(hash_size=64, signed=True)
"""The hasher used to calculate the hash value of objects."""
_registered_dto_classes = {}
"""A dictionary of all DTO classes that have been registered.
A DTO is registered any time :meth:`to_dto_class` is called, which may
recursively call :meth:`to_dto_class` for any relationships.
"""
def __hash__(self) -> int:
    # A dataclass only keeps a user-defined ``__hash__`` when it is
    # configured with ``eq=False`` and ``unsafe_hash=False`` (as this class
    # is), so the hash is derived from the natural key here.
    key = self.natural_key
    return self.hasher.hash(key)
def __eq__(self, other: Any) -> bool:
    """Compare two models by their natural keys.

    ``NotImplemented`` is returned when ``other`` is not an instance of this
    model's class so that Python can try the reflected comparison. This
    keeps ``parent == child`` and ``child == parent`` consistent when one
    model class derives from the other. For unrelated types the overall
    result is still ``False``.
    """
    if not isinstance(other, self.__class__):
        return NotImplemented
    return self.natural_key == other.natural_key
@orm.declared_attr.directive
def __tablename__(cls):
    """Derive the table name by converting the class name to snake case."""
    class_name = cls.__name__
    return camel_to_snake_case(class_name)
@classmethod
def merge_table_args(cls, new_args: TableArgs = ()) -> TableArgs:
    """
    Merge ``new_args`` with the table arguments of this class's ancestors.

    This method is intended to be called from ``__table_args__`` when used
    as a ``@declared_attr``. You can also specify additional table
    arguments in the class variable ``__custom_table_args__``, which will
    also be merged.

    Priority, from lowest to highest, is: the ``__table_args__`` of the
    immediate base classes (left to right), then ``__custom_table_args__``,
    then ``new_args``. Keyword arguments with a higher priority override
    those with a lower priority. Duplicate positional arguments are kept
    only once, in the order they are first seen.

    Parameters
    ----------
    new_args : TableArgs
        Additional table arguments to be merged with the arguments of our ancestors.

    Returns
    -------
    TableArgs
        The merged table arguments as a tuple. All elements except the last
        are positional arguments passed to the constructor of
        :class:`sa.Table`. The last element is always a (possibly empty)
        dictionary of keyword arguments passed to the ``Table`` constructor.

    Raises
    ------
    ValueError
        If any collected table arguments are neither a ``dict`` nor a ``tuple``.

    Examples
    --------
    >>> class Parent(BaseModel):
    >>>     @orm.declared_attr.directive
    >>>     @classmethod
    >>>     def __table_args__(cls):
    >>>         return {"schema": "my_schema"}
    >>>
    >>> class Child(Parent):
    >>>     __custom_table_args__ = {"schema": "my_schema"}
    >>>
    >>>     @orm.declared_attr.directive
    >>>     @classmethod
    >>>     def __table_args__(cls):
    >>>         child_table_args = (
    >>>             Index("child_index_name", "column_1", "column_2"),
    >>>         )
    >>>         return cls.merge_table_args(child_table_args)
    >>>
    >>>     column_1: orm.Mapped[int]
    >>>     column_2: orm.Mapped[int]
    """
    # Adapted from another personal project and from:
    # https://github.com/sqlalchemy/sqlalchemy/discussions/8911#discussioncomment-6763269

    # Collect candidates from highest to lowest priority: 'new_args', then
    # '__custom_table_args__' (getattr also finds an inherited one), then
    # the '__table_args__' of the immediate bases from right to left.
    # Note, this will recurse if a base class calls 'merge_table_args'
    # in its '__table_args__' definition.
    candidates: list[Any] = [new_args, getattr(cls, "__custom_table_args__", ())]
    candidates += [getattr(b, "__table_args__", ()) for b in reversed(cls.__bases__)]

    # Drop empty entries and process from lowest to highest priority, so
    # that keyword arguments seen later override those seen earlier.
    ordered_args = reversed([a for a in candidates if a])

    # A dict (used as an ordered set) de-duplicates positional arguments
    # while keeping a deterministic, first-seen output order.
    positional_args: dict[Any, None] = {}
    kwargs: dict[str, Any] = {}

    for current_args in ordered_args:
        if isinstance(current_args, dict):
            kwargs |= current_args
        elif isinstance(current_args, tuple):
            # A trailing dict in a tuple holds the keyword arguments.
            if isinstance(current_args[-1], dict):
                kwargs |= current_args[-1]
                current_args = current_args[:-1]
            positional_args |= dict.fromkeys(current_args)
        else:
            raise ValueError(
                f"Table args must be a dict or tuple, not '{current_args}'"
            )

    return tuple(positional_args) + (kwargs,)
@classmethod
@cache
@require(lambda table_name: table_name.count(".") < 2)
def find_by_table_name(cls, table_name: str) -> Optional[Type[BaseModel]]:
    """
    Find a model derived from this class by its table name.

    Parameters
    ----------
    table_name : str
        The name of the table whose model class you want to find. It may be
        qualified with a schema (``"schema.table"``). An unqualified name is
        looked up in the ``"public"`` schema, and models whose table has no
        explicit schema are treated as belonging to ``"public"``.

    Returns
    -------
    Type[BaseModel], optional
        Returns the model class if the table is found, otherwise ``None``.

    .. note::
        Results are memoized with :func:`functools.cache`, including
        ``None`` results, so a lookup is not re-evaluated if more models
        are registered afterwards.
    """
    # See: https://stackoverflow.com/a/68862329
    if "." in table_name:
        find_schema, find_table_name = table_name.split(".", 1)
    else:
        find_schema, find_table_name = "public", table_name

    for mapper in cls.registry.mappers:
        model = mapper.class_
        candidate_schema = model.__table__.schema or "public"
        candidate_table_name = model.__tablename__
        if (
            candidate_schema == find_schema
            and candidate_table_name == find_table_name
        ):
            return model

    return None
@property
def primary_key(self) -> tuple[tuple[str, Any], ...]:
    """
    The primary key of this instance as ``(name, value)`` pairs.

    The names are the keys of :meth:`primary_key_columns` and each value
    is the instance's current attribute value (``None`` when the attribute
    is missing). A tuple of tuples is returned for efficiency.
    """
    column_names = self.primary_key_columns().keys()
    return tuple(
        (column_name, getattr(self, column_name, None))
        for column_name in column_names
    )
@property
def natural_key(self) -> tuple[tuple[str, Optional[Any]], ...]:
    """
    The natural key of this instance as ``(name, value)`` pairs.

    The ``natural key`` is the set of ``key/value`` pairs that uniquely
    identify this instance. The names are the keys of
    :meth:`natural_key_columns` and each value is the instance's current
    attribute value (``None`` when the attribute is missing). The pairs are
    returned as a tuple of tuples so the hash function can use them
    efficiently.
    """
    column_names = self.natural_key_columns().keys()
    return tuple(
        (column_name, getattr(self, column_name, None))
        for column_name in column_names
    )
@require(
    lambda self: all(len(a.columns) == 1 for a in self.__mapper__.column_attrs),
    "One of the mapped columns is a composite column.",
)
def as_table_dict(self) -> dict[str, Any]:
    """
    Return the table data of the instance as a dictionary.

    .. warning::
        This only works for simple models that have a one to one
        mapping from mapped columns to table columns. For example,
        it will break for models containing composite columns.

    Returns
    -------
    dict[str, Any]
        The keys of the dictionary are the table column names and the
        values are the corresponding values from the instance.
    """
    table_data: dict[str, Any] = {}
    for attr in self.__mapper__.column_attrs:  # noqa SqlAlchemy
        column = attr.columns[0]
        # Only include attributes that are backed by a table column.
        if column.table is None:
            continue
        table_data[column.name] = getattr(self, attr.key, None)
    return table_data
def to_dict(
    self, exclude_attributes: Iterable[str] = (), drop_none: bool = False
) -> dict[str, Any]:
    """Convert the instance to a dictionary.

    This method represents the model as a dictionary as it is defined
    by the class mapping (as opposed to the table mapping). Regular
    columns, composite columns (converted with ``dataclasses.asdict``),
    column properties, hybrid properties, and relationships are included.
    Relationships are converted recursively. A model that was already
    converted earlier in the same call is represented only by its primary
    and natural key values, which breaks reference cycles.

    .. note::
        As of 01/01/2025 calling ``dataclasses.asdict`` will raise a
        recursion error on models that have relationships with
        ``back_populates`` defined.
        See: https://github.com/sqlalchemy/sqlalchemy/issues/9785

    Parameters
    ----------
    exclude_attributes : Iterable[str], default=()
        Attribute names to omit. Names are matched exactly against the
        attribute names of this model and of every related model reached
        through its relationships. A single string is treated as one
        attribute name.
    drop_none : bool, default=False
        If ``True``, ``None`` values are removed (using
        ``remove_none_from_dict``) from the dictionary built for each
        converted model.

    Returns
    -------
    dict[str, Any]
        The dictionary representation of this instance.
    """
    # Materialize the exclusions once. Membership is tested many times
    # below, and a one-shot iterable (e.g., a generator) would otherwise be
    # exhausted by the first test. A bare string is a single name, not an
    # iterable of characters.
    if isinstance(exclude_attributes, str):
        excluded = frozenset((exclude_attributes,))
    else:
        excluded = frozenset(exclude_attributes)

    def _process_seen(model_: BaseModel) -> dict[str, Any]:
        # Already-converted models are identified only by their keys.
        result = {}
        result |= dict(model_.primary_key)
        result |= dict(model_.natural_key)
        return dict(sorted(result.items()))

    def _get_non_excluded_attributes(attrs: dict[str, Any]) -> list[str]:
        return [n for n in attrs if n not in excluded]

    def _process_regular_columns(model_: BaseModel) -> dict[str, Any]:
        attrs = _get_non_excluded_attributes(model_.regular_columns())
        return {n: getattr(model_, n, None) for n in attrs}

    def _process_composite_columns(model_: BaseModel) -> dict[str, Any]:
        attrs = _get_non_excluded_attributes(model_.composite_columns())
        result = {}
        for name in attrs:
            value = getattr(model_, name, None)
            result[name] = None if value is None else dataclasses.asdict(value)
        return result

    def _process_column_properties(model_: BaseModel) -> dict[str, Any]:
        attrs = _get_non_excluded_attributes(model_.column_properties())
        return {n: getattr(model_, n, None) for n in attrs}

    def _process_hybrid_properties(model_: BaseModel) -> dict[str, Any]:
        hybrids = model_.hybrid_properties()
        attrs = _get_non_excluded_attributes(hybrids)
        return {n: hybrids[n].fget(model_) for n in attrs}

    def _process_relationships(
        model_: BaseModel, seen_: set[BaseModel]
    ) -> dict[str, Any]:
        relationships = model_.relationships()
        attrs = _get_non_excluded_attributes(relationships)
        result = {}
        for name in attrs:
            value = getattr(model_, name, None)
            if value is None:
                result[name] = None
            elif relationships[name].uselist:
                result[name] = [_do_to_dict(c, seen_) for c in value]
            else:
                result[name] = _do_to_dict(value, seen_)
        return result

    def _do_to_dict(model_: BaseModel, seen_: set[BaseModel]) -> dict[str, Any]:
        # 'seen_' uses the models' natural-key based __hash__/__eq__.
        if model_ in seen_:
            return _process_seen(model_)
        seen_.add(model_)
        result = {}
        result |= _process_regular_columns(model_)
        result |= _process_composite_columns(model_)
        result |= _process_column_properties(model_)
        result |= _process_hybrid_properties(model_)
        result |= _process_relationships(model_, seen_)
        return remove_none_from_dict(result) if drop_none else result

    return _do_to_dict(self, set())
@classmethod
def from_dict(cls, data: dict) -> BaseModel:
    """
    Build an instance of this class from its dictionary representation.

    Nested relationships are supported: a single related model is given as
    a dictionary and a collection of related models as a list of
    dictionaries. Instances that share a natural key are constructed only
    once and reused.

    Parameters
    ----------
    data : dict
        The attributes and nested relationships of the model, keyed by the
        model's field names.

    Returns
    -------
    BaseModel
        The instance of this ``BaseModel`` subclass built from ``data``.
    """
    # Instances already built during this conversion, keyed by class name
    # and natural key.
    instances_by_key: dict[frozenset[tuple[str, Any]], BaseModel] = {}
    return cls._do_from_dict(data, instances_by_key)
@classmethod
def _do_from_dict(
    cls: Type[Self],
    data: dict[str, Any],
    seen: dict[frozenset[tuple[str, Any]], BaseModel],
) -> Self:
    """
    Recursive implementation of :meth:`from_dict`.

    Parameters
    ----------
    data : dict[str, Any]
        The dictionary representation of a single instance of ``cls``.
    seen : dict[frozenset[tuple[str, Any]], BaseModel]
        Instances already constructed during this conversion, keyed by
        their qualified class name and natural key. If ``data`` matches an
        entry, that instance is returned without further processing.

    Returns
    -------
    Self
        The newly constructed (or previously seen) instance.

    Raises
    ------
    KeyError
        If ``data`` is missing one of the natural key columns.
    """

    # region _do_from_dict Utility Methods
    def _make_hash_key() -> frozenset[tuple[str, Any]]:
        # The class name is part of the key so that different models with
        # equal natural key values are not confused with each other.
        key = [("__class__", get_qualified_name(cls))]
        key += [(n, data[n]) for n in cls.natural_key_columns().keys()]
        return frozenset(key)

    def _get_regular_column_init() -> dict[str, Any]:
        attrs = cls.regular_columns()
        columns = {
            k: a
            for k, f in cls.fields().items()
            if (a := attrs.get(k)) is not None and f.init is not False
        }
        kwargs = {}
        for name in columns:
            # We want to distinguish between a value explicitly set to
            # `None` and a missing value. Missing values should use their
            # corresponding `default` or `default_factory` during
            # construction, while an explicit `None` should be preserved.
            try:
                kwargs[name] = copy.deepcopy(data[name])
            except KeyError:
                pass
        return kwargs

    def _make_composite(composite: Any, composite_kwargs: Any) -> Any:
        # 'to_dict' stores a missing composite value as None.
        if composite_kwargs is None:
            return None
        return composite.composite_class(**composite_kwargs)

    def _get_composite_column_init() -> dict[str, Any]:
        attrs = cls.composite_columns()
        composites = {
            k: a
            for k, f in cls.fields().items()
            if (a := attrs.get(k)) is not None and f.init is not False
        }
        kwargs = {}
        for name, composite in composites.items():
            # See note in _get_regular_column_init
            try:
                composite_kwargs = copy.deepcopy(data[name])
            except KeyError:
                continue
            kwargs[name] = _make_composite(composite, composite_kwargs)
        return kwargs

    def _get_relationship_init() -> dict[str, Any]:
        # For relationships, we are just creating dummy values during
        # initialization to prevent raising exceptions. It is very
        # difficult to prevent recursion errors if we try to recurse
        # here. Instead, we'll first construct the object without the
        # relationships so we can check if we've already seen an instance
        # with the same natural key. If so, we don't have to process any
        # further. Otherwise, we will fill in the relationship info (and
        # recurse if necessary) after the existence check.
        attrs = cls.relationships()
        relationships = {
            k: (f, r)
            for k, f in cls.fields().items()
            if (r := attrs.get(k)) is not None and f.init is not False
        }
        kwargs = {}
        for name, (field, relationship) in relationships.items():
            if (
                field.default is not None
                and field.default is not dataclasses.MISSING
            ):
                default_value = field.default
            elif (
                field.default_factory is not None
                and field.default_factory is not dataclasses.MISSING
            ):
                default_value = field.default_factory()
            elif relationship.uselist:
                # 'collection_class' is None unless configured; SqlAlchemy
                # then uses a list.
                default_value = (relationship.collection_class or list)([])
            else:
                default_value = None
            kwargs[name] = default_value
        return kwargs

    def _update_regular_columns(instance_: BaseModel):
        attrs = cls.regular_columns()
        columns = {
            k: a
            for k, f in cls.fields().items()
            if (a := attrs.get(k)) is not None and f.init is False
        }
        for name in columns:
            try:
                setattr(instance_, name, copy.deepcopy(data[name]))
            except KeyError:
                pass

    def _update_column_properties(instance_: BaseModel):
        # Column properties are not evaluated until the model is queried
        # or associated with a session, so it will be `None` if we don't
        # set it here.
        attrs = cls.column_properties()
        columns = {
            k: a
            for k, f in cls.fields().items()
            if (a := attrs.get(k)) is not None and f.init is False
        }
        for name in columns:
            # Missing keys are skipped, consistent with the other columns.
            try:
                value = data[name]
            except KeyError:
                continue
            setattr(instance_, name, copy.deepcopy(value))

    def _update_composite_columns(instance_: BaseModel):
        attrs = cls.composite_columns()
        composites = {
            k: a
            for k, f in cls.fields().items()
            if (a := attrs.get(k)) is not None and f.init is False
        }
        for name, composite in composites.items():
            try:
                composite_kwargs = copy.deepcopy(data[name])
            except KeyError:
                continue
            setattr(instance_, name, _make_composite(composite, composite_kwargs))

    def _update_relationships(instance_: BaseModel):
        attrs = cls.relationships()
        fields = cls.fields().items()
        relationships = {k: r for k, f in fields if (r := attrs.get(k)) is not None}

        # Currently, this does not handle back populating parent relationships.
        # It should be possible, but is not trivial. In addition to keeping
        # track of entities seen by their natural key, you'd also have to
        # keep track of instances by foreign key columns. Then in the
        # 'KeyError' below, you could try to find the entity using the
        # foreign key value from the 'data' dictionary.
        for name, relationship in relationships.items():
            try:
                values = data[name]
            except KeyError:
                continue

            rel_class = cast(Type[BaseModel], relationship.mapper.class_)
            if relationship.uselist:
                children = [rel_class._do_from_dict(v, seen) for v in values]
                collection_class = relationship.collection_class or list
                setattr(instance_, name, collection_class(children))
            else:
                child = (
                    None
                    if values is None
                    else rel_class._do_from_dict(values, seen)
                )
                setattr(instance_, name, child)

    # endregion _do_from_dict Utility Methods

    hash_key = _make_hash_key()
    if hash_key in seen:
        return seen[hash_key]

    init_kwargs = {}
    init_kwargs |= _get_regular_column_init()
    init_kwargs |= _get_composite_column_init()
    init_kwargs |= _get_relationship_init()

    # We should have all the info we need to construct an instance.
    # However, there may still be attributes not included in the __init__
    # (e.g., autogenerated IDs) that need to be set below. The instance is
    # registered in 'seen' before recursing into relationships.
    seen[hash_key] = instance = cls(**init_kwargs)

    _update_regular_columns(instance)
    _update_column_properties(instance)
    _update_composite_columns(instance)
    _update_relationships(instance)

    return instance
@classmethod
def dto_module(cls) -> ModuleType:
    """Import and return the module where the DTO class is defined (this model's module)."""
    module_name = cls.__module__
    return importlib.import_module(module_name)
@classmethod
def dto_module_name(cls) -> str:
    """Return the name of the module where the DTO class is defined (this model's module)."""
    return str(cls.__module__)
@classmethod
def dto_class_name(cls) -> str:
    """Return the name of the DTO class: the model's class name plus ``BaseDto``."""
    return cls.__name__ + "BaseDto"
@classmethod
def dto_import_path(cls) -> str:
    """Return the dotted import path (``module.ClassName``) of the DTO class."""
    return ".".join((cls.dto_module_name(), cls.dto_class_name()))
@classmethod
def dto_exclude_attributes(cls) -> set[str]:
    """
    Return the names of the fields to leave out of the DTO representation.

    A field is excluded when its ORM descriptor's ``info`` dictionary maps
    ``"dto"`` or ``"DTO"`` to ``False``.
    """
    descriptors = cls.__mapper__.all_orm_descriptors
    excluded: set[str] = set()
    for name in cls.fields().keys():
        descriptor = descriptors.get(name)
        if not descriptor:
            continue
        if any(descriptor.info.get(flag) is False for flag in ("dto", "DTO")):
            excluded.add(name)
    return excluded
@classmethod
@cache
def to_dto_class(cls) -> Type[pydantic.BaseModel]:
    """
    Create a Data Transfer Object (DTO) class that corresponds with this model.

    The DTO class is named after the model with the suffix ``BaseDto``
    appended (see :meth:`dto_class_name`). So, if the model name is
    ``MyModel`` the DTO class will be named ``MyModelBaseDto``.

    The DTO class inherits from ``pydantic.BaseModel`` and contains fields
    for the regular columns, column properties, composites, hybrid
    properties, and parent-side relationships of this model (see
    ``_do_dto_conversion``). Every model referenced by those relationships
    is converted as well. The result of this method is cached.

    .. note::
       Regardless of how the SqlAlchemy model class is configured, integer
       primary key columns without a column default, foreign key columns,
       and column properties are typed as optional in the DTO. Unless the
       model field declares its own default, the two kinds of key columns
       default to ``None``. Hybrid properties are optional with a default of
       ``None``.

    .. note::
       To exclude an attribute from the DTO, set ``dto=False`` in the
       ``info`` dictionary of the attribute when defining the model.

    Returns
    -------
    Type[pydantic.BaseModel]
        The DTO type.
    """
    # In order to build the DTO, all models referenced in relationships must
    # also be available as DTOs. Rather than recursing, keep a work list of
    # unconverted models (seeded with this one) and add the relationship
    # classes discovered while converting. Once every referenced model has a
    # DTO, the top level DTO is rebuilt to resolve its forward references.

    # Pydantic appears to require each segment of a module path (i.e., each
    # package in the hierarchy) to be available in a namespace it searches.
    # Adding the modules to the global namespace does not appear to work,
    # but collecting them in `types_namespace` and passing that to
    # `model_rebuild` does.
    types_namespace = {}

    # The DTO of this class (the first one converted), which is returned.
    dto = None

    # Every model queued during this call, so a model referenced by several
    # others is converted only once.
    queued = {cls}
    queue = deque([cls])
    while queue:
        model = queue.pop()
        types_namespace |= import_all_modules_in_path(model.dto_module_name())
        result = model._do_dto_conversion(children := set())
        if dto is None:
            dto = result
        for child in children:
            if child in queued:
                continue
            queued.add(child)
            if child in cls._registered_dto_classes:
                # Converted by an earlier call. The forward references to it
                # still need its module path in the namespace.
                types_namespace |= import_all_modules_in_path(
                    child.dto_module_name()
                )
            else:
                queue.append(child)

    # With all referenced DTOs created, rebuild the pydantic model so its
    # forward references are resolved.
    dto.model_rebuild(_types_namespace=types_namespace)
    return cast(DtoModel, dto)
def to_dto(self) -> pydantic.BaseModel:
    """
    Build the DTO instance that mirrors this model.

    The DTO class comes from :meth:`to_dto_class`. Its ``model_validate``
    reads this model's attributes, because the DTO class is configured with
    ``from_attributes = True``.

    Returns
    -------
    pydantic.BaseModel
        The equivalent DTO instance.

    Raises
    ------
    pydantic.ValidationError
        If any of the required fields are not present. **Note**: this
        includes database generated fields, such as ``ids``.
    """
    dto_cls = self.to_dto_class()
    return dto_cls.model_validate(self)
@classmethod
def from_dto(cls: Type[BaseModel], dto: pydantic.BaseModel) -> M:
    """
    Create a model instance from a DTO instance.

    The DTO is dumped to a plain dictionary with ``model_dump`` and that
    dictionary is handed to :meth:`from_dict`.
    """
    payload = dto.model_dump()
    return cls.from_dict(payload)
def walk_children(
    self,
    callback: Callable[[B], None],
    traverse_viewonly: bool = True,
):
    """
    Traverse this instance and every instance reachable through its
    relationships, applying ``callback`` to each one exactly once.

    Nodes are taken from the end of a work list (depth-first order).
    Instances are tracked by identity (``id``), so the traversal does not
    depend on the models' ``__hash__``/``__eq__``. Two distinct instances
    that compare equal are therefore both visited.

    Parameters
    ----------
    callback : Callable[[DeclarativeBase], None]
        The function to call on each traversed node, including this one.
    traverse_viewonly : bool, default=True
        Whether to follow ``viewonly`` relationships.
    """
    # Local import keeps the module-level import block untouched.
    from collections.abc import Mapping

    stack = deque([self])
    visited_ids: set[int] = set()

    while stack:
        node = stack.pop()
        # Prevent cycles. Every node stays referenced by the object graph
        # during the walk, so its id cannot be reused.
        if id(node) in visited_ids:
            continue
        visited_ids.add(id(node))

        callback(node)

        # Queue all related nodes for traversal.
        for rel in node.relationships().values():
            related = getattr(node, rel.key)
            if related is None:
                continue
            if rel.viewonly and not traverse_viewonly:
                continue
            if not rel.uselist:
                stack.append(related)
            elif isinstance(related, Mapping):
                # Keyed collections hold the related instances as values;
                # iterating them directly would yield the keys.
                stack.extend(related.values())
            else:
                stack.extend(related)
def copy(self: M) -> M:
    """
    Return a deep copy of this instance that is not associated in any way
    with it.

    The copy is rebuilt with ``from_dict`` from a deep-copied ``to_dict()``
    snapshot. ``copy.deepcopy(self)`` and ``dataclasses.replace(self)`` are
    avoided because they also copy SqlAlchemy state. For example, the new
    instance then appears to be added to a session when the original is.

    Returns
    -------
    BaseModel
        A copy of this instance.
    """
    snapshot = self.to_dict()
    detached = copy.deepcopy(snapshot)
    return self.from_dict(detached)
@classmethod
@cache
def class_name(cls) -> str:
    """Return the class name qualified with its package and module path."""
    qualified = get_qualified_name(cls)
    return qualified
@classmethod
@cache
[docs]
def fields(cls: Type[Self]) -> dict[str, dataclasses.Field]:
"""
Return a mapping of the field names to their corresponding ``Field``
objects. Note, the dictionary is ordered by the field name.
"""
sorted_fields = sorted(dataclasses.fields(cls), key=lambda f: f.name)
return {f.name: f for f in sorted_fields}
@classmethod
@cache
def primary_key_columns(cls) -> dict[str, orm.ColumnProperty]:
    """
    Return a mapping from attribute name to ``ColumnProperty`` for every
    regular column of this class that maps to a primary key column.
    """
    regular = cls.regular_columns()
    selected = {}
    for name in cls.fields():
        prop = regular.get(name)
        if prop is None:
            continue
        if any(col.primary_key for col in prop.columns):
            selected[name] = prop
    return selected
@classmethod
@cache
def natural_key_columns(cls) -> dict[str, orm.ColumnProperty]:
    """
    Return a mapping from attribute name to ``ColumnProperty`` for every
    regular column whose dataclass field has a truthy ``hash`` flag, i.e.,
    the natural key of this class.
    """
    regular = cls.regular_columns()
    selected = {}
    for name, fld in cls.fields().items():
        prop = regular.get(name)
        if prop is not None and fld.hash:
            selected[name] = prop
    return selected
# It's less efficient, but we need to iterate through all the fields below
# because the `mapper.attrs` collections map from **column** names, not
# attribute names.
@classmethod
@cache
def regular_columns(cls: Type[Self]) -> dict[str, orm.ColumnProperty]:
    """
    Return a mapping from attribute name to ``ColumnProperty`` for every
    regular column of this class. A regular column is an entry of
    ``Mapper.column_attrs`` that maps to exactly one ``sa.Column``.
    """
    column_attrs = cls.__mapper__.column_attrs
    found = {}
    for name in cls.fields():
        prop = column_attrs.get(name)
        if cls._is_regular_column(prop):
            found[name] = prop
    return found
@classmethod
@cache
def _is_regular_column(cls, attr: orm.MapperProperty) -> bool:
    """True if ``attr`` is a ``ColumnProperty`` mapped to exactly one ``sa.Column``."""
    if not isinstance(attr, orm.ColumnProperty):
        return False
    mapped = attr.columns
    if len(mapped) != 1:
        return False
    return isinstance(mapped[0], sa.Column)
@classmethod
@cache
def column_properties(cls) -> dict[str, orm.ColumnProperty]:
    """
    Return a mapping from attribute name to ``ColumnProperty`` for every
    column property of this class, i.e., every mapped attribute configured
    with ``column_property``. Only dataclass fields are considered.
    """
    column_attrs = cls.__mapper__.column_attrs
    found = {}
    for name in cls.fields():
        prop = column_attrs.get(name)
        if cls._is_column_property(prop):
            found[name] = prop
    return found
@classmethod
@cache
def _is_column_property(cls, attr: orm.MapperProperty) -> bool:
    """True if ``attr`` was configured with ``column_property``."""
    expression_type = orm.MappedSQLExpression
    return isinstance(attr, expression_type)
@classmethod
@cache
def composite_columns(cls: Type[Self]) -> dict[str, orm.CompositeProperty]:
    """
    Return a mapping from attribute name to ``CompositeProperty`` for every
    ``composite`` attribute of this class, i.e., every attribute configured
    with ``sa.orm.composite``. Only dataclass fields are considered.
    """
    mapper_attrs = cls.__mapper__.attrs
    found = {}
    for name in cls.fields():
        prop = mapper_attrs.get(name)
        if cls._is_composite(prop):
            found[name] = prop
    return found
@classmethod
@cache
def _is_composite(cls, attr: orm.MapperProperty) -> bool:
    """True if ``attr`` is a ``CompositeProperty``."""
    composite_type = orm.CompositeProperty
    return isinstance(attr, composite_type)
@classmethod
@cache
def hybrid_properties(cls: Type[Self]) -> dict[str, hybrid_property]:
    """
    Return a mapping from attribute name to ``hybrid_property`` for every
    getter-style method decorated with ``@hybrid_property``. All entries of
    ``__mapper__.all_orm_descriptors`` are examined, not only dataclass
    fields.
    """
    found = {}
    for name, descriptor in cls.__mapper__.all_orm_descriptors.items():
        if cls._is_hybrid_property(descriptor):
            found[name] = descriptor
    return found
@classmethod
@cache
def _is_hybrid_property(cls, attr: orm.MapperProperty) -> bool:
    """True if ``attr`` has the extension type of a ``hybrid_property``."""
    actual_kind = attr.extension_type
    return actual_kind is hybrid_property.extension_type
@classmethod
@cache
def relationships(cls) -> dict[str, orm.Relationship]:
    """
    Return a mapping from attribute name to ``Relationship`` for every
    relationship of this class that is also a dataclass field.
    """
    mapper_relationships = cls.__mapper__.relationships
    found = {}
    for name in cls.fields():
        rel = mapper_relationships.get(name)
        if cls._is_relationship(rel):
            found[name] = rel
    return found
@classmethod
@cache
def _is_relationship(cls, attr: orm.MapperProperty) -> bool:
    """True if ``attr`` is a ``Relationship``."""
    relationship_type = orm.Relationship
    return isinstance(attr, relationship_type)
@classmethod
@cache
def table_columns(cls) -> list[sa.Column]:
    """Return the ``Column`` objects of this model's table, in table order."""
    return list(cls.__table__.columns.values())
@classmethod
@cache
@ensure(lambda result: result is not None)
def table_insertion_order(cls) -> dict[sa.Table, int]:
    """
    Return the insertion priority of every table in ``BaseModel.metadata``.

    Each table is numbered by its position in ``MetaData.sorted_tables``,
    which orders tables by their foreign key dependencies. A lower number
    has a higher priority and should be inserted before a table with a
    higher number.

    Returns
    -------
    dict[Table, int]
        A mapping from an SqlAlchemy table to its insertion priority (lower
        numbers indicate higher priority).
    """
    priorities: dict[sa.Table, int] = {}
    for position, table in enumerate(BaseModel.metadata.sorted_tables):
        priorities[table] = position
    return priorities
@classmethod
def validate_mapper(cls):
    """
    Validate that this model is configured correctly and index its natural key.

    The following checks are made:

    * A concrete model must have at least one natural key column, i.e., a
      regular column whose dataclass field has a truthy ``hash`` flag. A
      model is abstract only when ``__abstract__`` is truthy and it has no
      ``__tablename__``.
    * No natural key column may be nullable.
    * An attribute excluded from the DTO (``info={"dto": False}``) must not
      be a natural key, and it must be constructible without a value:
      - a non-optional field needs a non-``None`` ``default`` or
        ``default_factory``;
      - an optional field needs any ``default``/``default_factory`` (even
        ``None``) or ``init=False``.

    When the checks pass, an index named ``nk_<tablename>`` over the natural
    key columns is added to the model's table. If the table already has an
    index with that name, nothing is added, so repeated calls do not create
    duplicates.

    Raises
    ------
    ValueError
        If any of the checks fail.
    """
    missing = dataclasses.MISSING

    def _is_abstract(model_class: Type[BaseModel]) -> bool:
        # getattr: a model declaring neither attribute is treated as
        # concrete instead of raising AttributeError.
        is_abstract = getattr(model_class, "__abstract__", False)
        table_name = getattr(model_class, "__tablename__", None)
        return bool(is_abstract) and not table_name

    def _is_concrete(model_class: Type[BaseModel]) -> bool:
        return not _is_abstract(model_class)

    def _is_required(type_: Type[Any]) -> bool:
        return not is_optional(type_)

    def _has_any_default(field: dataclasses.Field) -> bool:
        # MISSING is a sentinel, so compare by identity, not equality.
        return field.default is not missing or field.default_factory is not missing

    def _has_value_default(field: dataclasses.Field) -> bool:
        default_value = field.default
        factory_value = field.default_factory
        default_is_value = default_value is not missing and default_value is not None
        factory_is_value = factory_value is not missing and factory_value is not None
        return default_is_value or factory_is_value

    def _validate_natural_keys(model_class: Type[BaseModel]):
        model_name = model_class.__name__
        attrs = model_class.natural_key_columns().values()
        if _is_concrete(model_class) and len(attrs) == 0:
            raise ValueError(f"'{model_name}' has no natural key columns.")
        cols = [c for a in attrs for c in a.columns]
        if any(c.nullable for c in cols):
            raise ValueError("Natural keys cannot be optional")

    def _validate_dto_configuration(model_class: Type[BaseModel]):
        dto_excludes = model_class.dto_exclude_attributes()
        excluded_fields = [
            f for f in model_class.fields().values() if f.name in dto_excludes
        ]
        for field in excluded_fields:
            if field.hash is True:
                raise ValueError("A DTO cannot exclude a natural key column.")
            if _is_required(field.type) and not _has_value_default(field):
                raise ValueError(
                    "A DTO cannot exclude a required field that does not provide "
                    "a non-None default or default_factory value."
                )
            if is_optional(field.type) and not (
                _has_any_default(field) or field.init is False
            ):
                raise ValueError(
                    "A DTO cannot exclude an optional field unless it is "
                    "excluded from the 'init' method or has a default value or "
                    "default_factory (even if it is None)."
                )

    def _add_natural_key_index(model_class: Type[BaseModel]):
        index_name = f"nk_{model_class.__tablename__}"
        table = cast(Table, model_class.__table__)
        # A repeated call must not register a second index with the same
        # name, which would emit a duplicate CREATE INDEX.
        if any(index.name == index_name for index in table.indexes):
            return
        # Index the mapped Column objects rather than their ORM
        # ColumnProperty wrappers.
        columns = [
            c for a in model_class.natural_key_columns().values() for c in a.columns
        ]
        table.indexes.add(Index(index_name, *columns))

    _validate_natural_keys(cls)
    _validate_dto_configuration(cls)
    _add_natural_key_index(cls)
# region Utility Methods
@classmethod
def _do_dto_conversion(
    cls: Type[Self],
    forward_declarations: set[Type[BaseModel]],
):
    """
    Convert the SqlAlchemy model to a DTO implemented as a
    ``pydantic.BaseModel``.

    The new DTO class is stored in ``BaseModel._registered_dto_classes`` and
    set as the attribute ``cls.dto_class_name()`` of the module returned by
    ``cls.dto_module()``.

    Parameters
    ----------
    forward_declarations : set[Type[BaseModel]]
        A set that this function populates with the model classes
        referenced by the DTO's relationship fields. Those fields are
        annotated with the string import path (``dto_import_path()``) of
        DTO classes that may not exist yet. The caller is expected to
        convert these models as well and then rebuild the DTO.

    Returns
    -------
    Type[pydantic.BaseModel]
        The DTO class.
    """

    # region _do_dto_conversion utility functions
    class Config:
        """Pydantic DTO configuration."""

        from_attributes = True

    def _python_type(column: sa.ColumnElement) -> Optional[type]:
        """Return the column's Python type, or ``None`` if it has none."""
        # ``TypeEngine.python_type`` raises NotImplementedError for types
        # without a Python equivalent (e.g. ``NullType``), which the SQL
        # expressions behind column properties can have.
        try:
            return column.type.python_type
        except NotImplementedError:
            return None

    def _is_auto_generated(attr: orm.MapperProperty) -> bool:
        """True for an integer primary key column without a column default."""
        columns = getattr(attr, "columns", [])
        # Check the primary key first so that non-key expressions never need
        # to report a Python type.
        if not any(c.primary_key for c in columns):
            return False
        if not any(_python_type(c) is int for c in columns):
            return False
        try:
            has_default = any(c.default is not None for c in columns)
        except AttributeError:
            has_default = False
        return not has_default

    def _is_foreign_key(attr: orm.MapperProperty) -> bool:
        return any(bool(c.foreign_keys) for c in getattr(attr, "columns", []))

    def _to_dto_field(
        model_field: dataclasses.Field, model_attr: orm.MapperProperty
    ) -> pydantic.Field:
        """Create the pydantic field from the model field information."""
        field_kwargs = {}
        if model_field.default is not dataclasses.MISSING:
            field_kwargs["default"] = model_field.default
        elif model_field.default_factory is not dataclasses.MISSING:
            field_kwargs["default_factory"] = model_field.default_factory
        elif _is_auto_generated(model_attr):
            field_kwargs["default"] = None
        elif _is_foreign_key(model_attr):
            field_kwargs["default"] = None
        if model_field.repr is not dataclasses.MISSING:
            field_kwargs["repr"] = model_field.repr
        return pydantic.Field(**field_kwargs)

    def _to_dto_type(
        model_field: dataclasses.Field, model_attr: orm.MapperProperty
    ) -> Type:
        """Create the pydantic type from the model field information."""
        # The following attribute categories are always optional in the DTO:
        # * autogenerated keys
        # * foreign keys
        # * column_property attributes
        # (hybrid_property attributes are handled in
        # `_process_hybrid_properties` and are optional too.)
        if _is_auto_generated(model_attr):
            return Optional[model_field.type]
        if _is_foreign_key(model_attr):
            return Optional[model_field.type]
        if cls._is_column_property(model_attr):
            return Optional[model_field.type]
        return model_field.type

    def _filter_excluded(
        attrs: dict[str, orm.MapperProperty | orm.Relationship],
    ) -> dict[str, orm.MapperProperty | orm.Relationship]:
        excluded = cls.dto_exclude_attributes()
        return {n: a for n, a in attrs.items() if n not in excluded}

    def _process_direct_attributes(
        attrs: dict[str, orm.MapperProperty]
    ) -> dict[str, tuple[Type, pydantic.Field]]:
        attrs = _filter_excluded(attrs)
        items = {n: (f, attrs[n]) for n, f in cls.fields().items() if n in attrs}
        return {
            name: (_to_dto_type(*args), _to_dto_field(*args))
            for name, args in items.items()
        }

    def _process_regular_columns() -> dict[str, tuple[Type, pydantic.Field]]:
        return _process_direct_attributes(cls.regular_columns())

    def _process_column_properties() -> dict[str, tuple[Type, pydantic.Field]]:
        return _process_direct_attributes(cls.column_properties())

    def _process_composites() -> dict[str, tuple[Type, pydantic.Field]]:
        return _process_direct_attributes(cls.composite_columns())

    def _process_hybrid_properties() -> dict[str, tuple[Type, pydantic.Field]]:
        attrs = _filter_excluded(cls.hybrid_properties())
        result = {}
        for name, attr in attrs.items():
            inner_type = get_type_hints(attr.fget).get("return", None)
            dto_type = Optional[inner_type]
            dto_field = pydantic.Field(init=True, default=None)
            result[name] = (dto_type, dto_field)
        return result

    def _is_parent_relationship(rel: orm.Relationship) -> bool:
        owns_fk = any(c.foreign_keys for c in rel.local_columns)
        is_primary_association = _is_primary_association(rel)
        return not owns_fk and is_primary_association

    def _is_primary_association(rel: orm.Relationship) -> bool:
        """True if the local column is the first entry in the association table."""
        secondary = rel.secondary
        # This is only applicable to MtM relationships with an
        # association table.
        if secondary is None:
            return True
        local_columns = rel.local_columns
        foreign_keys = secondary.columns[0].foreign_keys
        return all(fk.column in local_columns for fk in foreign_keys)

    def _process_relationships() -> dict[str, tuple[Type, pydantic.Field]]:
        fields = cls.fields()
        relationships = _filter_excluded(cls.relationships())
        pairs = {n: (fields[n], r) for n, r in relationships.items()}
        pairs = {n: p for n, p in pairs.items() if _is_parent_relationship(p[1])}
        result = {}
        for name, (model_field, relationship) in pairs.items():
            dto_field = _to_dto_field(model_field, relationship)
            model_type = relationship.entity.class_
            # Add the relationship class to the set of additional classes
            # the caller needs to process later.
            forward_declarations.add(model_type)
            # Use a forward declaration of the relationship DTO type,
            # otherwise we'd run into cyclic recursion issues.
            dto_entity_type = model_type.dto_import_path()
            if relationship.uselist:
                # `collection_class` may be None when no collection type was
                # given; SqlAlchemy then stores the elements in a plain list.
                collection_class = relationship.collection_class or list
                dto_type = collection_class[dto_entity_type]  # noqa
            else:
                dto_type = dto_entity_type
            result[name] = (dto_type, dto_field)
        return result

    # region DTO methods
    # These methods will be attached to the DTO class
    def _dto_pretty_print(self, line_length=1):
        # Pretty print the DTO
        # See: https://github.com/pydantic/pydantic/discussions/7787#discussioncomment-9658140
        import black

        s = repr(self)
        return black.format_str(s, mode=black.FileMode(line_length=line_length))

    def _dto_natural_keys(self) -> list[str]:
        # Return the list of natural key attribute names
        return list(cls.natural_key_columns().keys())

    def _dto_hash(self) -> int:
        value = tuple((a, getattr(self, a, None)) for a in self.natural_keys())
        return self.hasher.hash(value)

    # endregion DTO methods

    def _make_dto_class(attributes: dict[str, tuple[Type, pydantic.Field]]) -> type:
        exclude = cls.dto_exclude_attributes()
        attributes = {k: v for k, v in attributes.items() if k not in exclude}
        dto_types = {k: v[0] for k, v in attributes.items()}
        dto_fields = {k: v[1] for k, v in attributes.items()}
        # Enable creating a DTO from a dict
        dto_types["Config"] = ClassVar[Type]
        # Add a hasher
        dto_types["hasher"] = ClassVar[ObjectHasher]
        return type(
            cls.dto_class_name(),
            (pydantic.BaseModel,),
            {
                "__module__": cls.dto_module_name(),
                "__annotations__": dto_types,
                "__hash__": _dto_hash,
                "Config": Config,
                "hasher": cls.hasher,
                "natural_keys": _dto_natural_keys,
                "pretty_print": _dto_pretty_print,
                **dto_fields,
            },
        )

    # endregion _do_dto_conversion utility functions

    dto_attributes = {}
    dto_attributes |= _process_regular_columns()
    dto_attributes |= _process_column_properties()
    dto_attributes |= _process_composites()
    dto_attributes |= _process_hybrid_properties()
    dto_attributes |= _process_relationships()

    dto_class = _make_dto_class(dto_attributes)
    dto_class = cast(DtoModel, dto_class)

    # Register the DTO class with the BaseModel
    BaseModel._registered_dto_classes[cls] = dto_class

    # Add the class to the appropriate module
    setattr(cls.dto_module(), cls.dto_class_name(), dto_class)

    return dto_class
# endregion Utility Methods
@dataclass(frozen=True, kw_only=False)
class SqlBinaryExpression(YamlConfig):
    """A class that represents the basic binary expression for an SQL column."""

    # NOTE(review): the declarations of the ``column``, ``operator`` and
    # ``value`` fields (all read by ``to_expression``) are not visible in
    # this listing; only their docstrings are, and those are kept unchanged.
    """
    The column name.
    """
    """
    The operator to compare the ``column`` and ``value`` with.
    """
    """
    The value used as a comparison.
    """

    def __call__(self, model_or_table: Type[B] | sa.Table) -> sa.BinaryExpression:
        """Shorthand for :meth:`to_expression`."""
        return self.to_expression(model_or_table)

    @require(
        lambda self, model_or_table: (
            self.column in model_or_table.columns
            if isinstance(model_or_table, sa.Table)
            else hasattr(model_or_table, self.column)
        ),
        "Invalid column",
    )
    @require(
        lambda model_or_table: isinstance(model_or_table, sa.Table)
        or issubclass(model_or_table, orm.DeclarativeBase),
        (
            "The 'model_or_table' must either be an SqlAlchemy Table or an SqlAlchemy "
            "ORM model (subclass of DeclarativeBase)."
        ),
    )
    def to_expression(self, model_or_table: Type[B] | sa.Table) -> sa.BinaryExpression:
        """Return a clause that can be used with an SqlAlchemy ``where`` statement.

        Parameters
        ----------
        model_or_table : Type[DeclarativeBase] | sqlalchemy.Table
            The ORM model class or table that contains ``column``. For a
            table, ``column`` is looked up in ``Table.c``; for a model it is
            read as a class attribute.

        Returns
        -------
        BinaryExpression
            The corresponding SqlAlchemy binary expression.

        Raises
        ------
        ValueError
            If ``operator`` is not one of the supported operators.
        """
        column = (
            model_or_table.c[self.column]
            if isinstance(model_or_table, sa.Table)
            else getattr(model_or_table, self.column)
        )
        operator = self.operator
        # The None checks live inside the "==" / "!=" branches; as separate
        # branches placed after them they could never be reached.
        if operator == "==":
            if self.value is None:
                return column.is_(None)
            return column == self.value  # noqa
        if operator == "!=":
            if self.value is None:
                return column.isnot(None)
            return column != self.value  # noqa
        if operator == ">":
            return column > self.value  # noqa
        if operator == ">=":
            return column >= self.value  # noqa
        if operator == "<":
            return column < self.value  # noqa
        if operator == "<=":
            return column <= self.value  # noqa
        if operator == "like":
            return column.like(self.value)
        if operator == "in":
            return column.in_(self.value)
        if operator == "not_in":
            return column.not_in(self.value)
        if operator == "is_null":
            return column.is_(None)
        if operator == "is_not_null":
            return column.isnot(None)
        raise ValueError(f"Invalid operator: {operator}")
@dataclass(kw_only=False, frozen=True)
class SqlSelectionCriteria(YamlConfig):
    """A class that represents a conjunction of SqlBinaryExpression."""

    expressions: list[SqlBinaryExpression]
    """
    The list of binary expressions that will be used to filter the query.
    """

    @require(
        lambda model_or_table: isinstance(model_or_table, sa.Table)
        or issubclass(model_or_table, orm.DeclarativeBase),
        (
            "The 'model_or_table' must either be an SqlAlchemy Table or an SqlAlchemy "
            "ORM model (subclass of DeclarativeBase)."
        ),
    )
    def to_conjunction(self, model_or_table: Type[B] | sa.Table) -> sa.ColumnElement[bool]:
        """
        Return a conjunction of binary expressions that can be used with an
        SqlAlchemy ``where`` statement.

        Each expression is resolved against the table (a model is replaced
        by its ``__table__``).

        Parameters
        ----------
        model_or_table : Type[DeclarativeBase] | sqlalchemy.Table
            The ORM model class or table the expressions refer to.

        Returns
        -------
        ColumnElement[bool]
            The single expression, a ``BooleanClauseList`` for several
            expressions, or ``sqlalchemy.true()`` when ``expressions`` is
            empty.
        """
        table = (
            model_or_table
            if isinstance(model_or_table, sa.Table)
            else model_or_table.__table__
        )
        clauses = [e.to_expression(table) for e in self.expressions]
        # Calling and_() with no arguments is deprecated in SqlAlchemy.
        # Seeding it with true() is the documented replacement, and true()
        # is dropped whenever other clauses are present.
        return sa.and_(sa.true(), *clauses)
@dataclass(kw_only=False, frozen=True)
class SqlOrderExpression(YamlConfig):
    """A class that represents a basic order expression for an SQL column."""

    # NOTE(review): the ``column`` and ``ascending`` field declarations read
    # by ``to_expression`` are not visible in this listing; only their
    # docstrings are, and those are kept unchanged.
    """The name of the column to sort by."""
    """Whether to sort the column in ascending order."""

    def to_expression(self, model_or_table: Type[B] | sa.Table) -> sa.UnaryExpression:
        """
        Convert the order expression to a SqlAlchemy ``UnaryExpression``.

        For a table, ``column`` is looked up in ``Table.c``; for an ORM model
        it is read as a class attribute. The result is ``asc()`` when
        ``ascending`` is truthy and ``desc()`` otherwise.
        """
        if isinstance(model_or_table, sa.Table):
            target = model_or_table.c[self.column]
        else:
            target = getattr(model_or_table, self.column)
        if self.ascending:
            return target.asc()
        return target.desc()
@dataclass(kw_only=False, frozen=True)
class SqlOrderCriteria(YamlConfig):
    """An ordered list of :class:`SqlOrderExpression` objects."""

    expressions: list[SqlOrderExpression]
    """The list of order expressions."""

    def to_criteria(
        self, model_or_table: Type[B] | sa.Table
    ) -> list[sa.UnaryExpression]:
        """
        Convert every order expression, in order, to an SqlAlchemy
        ``UnaryExpression`` suitable for an ``order_by`` statement.
        """
        criteria = []
        for expression in self.expressions:
            criteria.append(expression.to_expression(model_or_table))
        return criteria
@dataclass(kw_only=True)
class DatabaseConnectionConfig(abc.ABC):
    """The options necessary to connect to a database using SqlAlchemy."""

    # NOTE(review): the declaration of the ``database`` field (read by
    # ``url``) is not visible in this listing; only its docstring is.
    """The name of the database or the path to the database file if using sqlite3."""

    drivername: str = "sqlite"
    """The name of the driver used to connect to the database."""

    username: Optional[str] = None
    """The user name to use when connecting to the database, if needed."""

    password: Optional[str] = None
    """The password to use when connecting to the database, if needed."""

    host: Optional[str] = None
    """The database host, if applicable."""

    port: Optional[int] = None
    """The database port, if applicable."""

    use_async: bool = False
    """Whether to use async mode."""

    @property
    def url(self) -> sa.URL:
        """Build the :class:`sqlalchemy.URL` described by these options."""
        url_parts = {
            "drivername": self.drivername,
            "username": self.username,
            "password": self.password,
            "host": self.host,
            "port": self.port,
            "database": self.database,
        }
        return sa.URL.create(**url_parts)
[docs]
class Repository:
"""A class implementing the basic CRUD operations for the data access layer."""
def __init__(self, engine: AlchemyEngine, model_class: Type[M]):
    """
    Create a repository for ``model_class``.

    Parameters
    ----------
    engine : AlchemyEngine
        The database engine wrapper.
        NOTE(review): no line visible in this listing stores ``engine``;
        presumably the session helpers need it -- confirm it is kept.
    model_class : Type[M]
        The ORM model class that the ``find_*`` methods select.
    """
    # NOTE(review): `[docs]` looks like a documentation-rendering artifact
    # rather than source code -- confirm against the real file.
    [docs]
    self.model_class = model_class
def insert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    session: Optional[Session] = None,
):
    """
    Add one or more model instances to the database.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        A single model instance or an iterable of instances.
    session : Session, optional
        A session to use. It is passed to ``_get_sync_session``, which
        provides the session that receives ``add_all``.

    Raises
    ------
    Exception
        If there is a problem adding the data to the database.
    """
    batch = self._normalize_data(instances)
    with self._get_sync_session(session) as active_session:
        active_session.add_all(batch)
async def async_insert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    session: Optional[AsyncSession] = None,
):
    """
    Asynchronously add one or more model instances to the database.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        A single model instance or an iterable of instances.
    session : AsyncSession, optional
        A session to use. It is passed to ``_get_async_session``, which
        provides the session that receives ``add_all``.

    Raises
    ------
    Exception
        If there is a problem adding the data to the database.
    """
    batch = self._normalize_data(instances)
    async with self._get_async_session(session) as active_session:
        active_session.add_all(batch)
def upsert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    on_conflict: ConflictResolutionStrategy = "do_nothing",
    session: Optional[Session] = None,
):
    """
    Upsert one or more model instances into the database.

    .. note::
       This only works with dialects that support INSERT ... ON CONFLICT.
       This should include ``postgresql``, ``mysql/mariadb``, and
       ``sqlite``. However, currently only ``postgresql`` is supported.

    .. note::
       This method creates several copies of the data, so you should be
       careful about memory management if you are inserting a large number
       of objects.

    .. warning::
       This only works if all non-null fields are provided **including
       primary and foreign keys.**

    .. warning::
       This is not well tested with complex model configurations such as
       hybrid properties and column properties.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        The data to upsert.
    on_conflict : ConflictResolutionStrategy, default="do_nothing"
        How conflicting rows are resolved; passed to
        ``_make_upsert_statements``.
    session : Session, optional
        A session to use. It is passed to ``_get_sync_session``.
    """
    statements = self._make_upsert_statements(instances, on_conflict)
    # Run every statement in one session instead of opening a new session
    # per statement, so the upsert is handled as a single unit (assumes
    # `_get_sync_session` commits/rolls back on exit -- confirm).
    with self._get_sync_session(session) as local_session:
        for stmt in statements:
            local_session.execute(stmt)
async def async_upsert(
    self,
    instances: BaseModel | Iterable[BaseModel],
    on_conflict: ConflictResolutionStrategy = "do_nothing",
    session: Optional[AsyncSession] = None,
):
    """
    Asynchronously upsert one or more model instances into the database.

    .. note::
       This only works with dialects that support INSERT ... ON CONFLICT.
       This should include ``postgresql``, ``mysql/mariadb``, and
       ``sqlite``. However, currently only ``postgresql`` is supported.

    .. note::
       This method creates several copies of the data, so you should be
       careful about memory management if you are inserting a large number
       of objects.

    .. warning::
       This only works if all non-null fields are provided **including
       primary and foreign keys.**

    .. warning::
       This is not well tested with complex model configurations such as
       hybrid properties and column properties.

    Parameters
    ----------
    instances : BaseModel | Iterable[BaseModel]
        The data to upsert.
    on_conflict : ConflictResolutionStrategy, default="do_nothing"
        How conflicting rows are resolved; passed to
        ``_make_upsert_statements``.
    session : AsyncSession, optional
        A session to use. It is passed to ``_get_async_session``.
    """
    statements = self._make_upsert_statements(instances, on_conflict)
    # `_get_async_session` is entered with `async with` everywhere else
    # (e.g. `async_insert`); a plain `with` cannot enter an asynchronous
    # context manager. A single session is used for all statements.
    async with self._get_async_session(session) as local_session:
        for stmt in statements:
            await local_session.execute(stmt)
@require(
    lambda filter_by: all(isinstance(f, SqlBinaryExpression) for f in filter_by)
)
@require(lambda order_by: all(isinstance(o, SqlOrderExpression) for o in order_by))
@require(lambda limit: limit is None or isinstance(limit, int))
def find_all(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    order_by: Iterable[SqlOrderExpression] = (),
    limit: Optional[int] = None,
    session: Optional[Session] = None,
) -> list[pydantic.BaseModel]:
    """
    Find multiple models matching the given filters and return them as DTOs.

    .. warning::
       The ``@require`` preconditions iterate ``filter_by`` and ``order_by``
       before the query is built. Pass re-iterable collections (e.g., lists
       or tuples), not generators; otherwise the query sees them empty.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        Conditions combined with AND to filter the query. Default is an
        empty iterable (no filtering).
    order_by : Iterable[SqlOrderExpression], optional
        Sort order of the query. Default is an empty iterable.
    limit : int, optional
        Maximum number of records to retrieve. If None, no limit is applied.
    session : Session, optional
        A session to use. It is passed to ``_get_sync_session``.

    Returns
    -------
    list[pydantic.BaseModel]
        The DTO instances of the retrieved models.
    """
    stmt = self._select_stmt(self.model_class, filter_by, order_by, limit)
    with self._get_sync_session(session) as local_session:
        models = local_session.scalars(stmt).all()
        # Convert while the session is still open.
        dtos = [m.to_dto() for m in models]
    return dtos
async def async_find_all(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    order_by: Iterable[SqlOrderExpression] = (),
    limit: Optional[int] = None,
    session: Optional[AsyncSession] = None,
) -> list[pydantic.BaseModel]:
    """
    Asynchronously find multiple models matching the given filters and
    return them as DTOs.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        Conditions combined with AND to filter the query. Default is an
        empty iterable (no filtering).
    order_by : Iterable[SqlOrderExpression], optional
        Sort order of the query. Default is an empty iterable.
    limit : int, optional
        Maximum number of records to retrieve. If None, no limit is applied.
    session : AsyncSession, optional
        A session to use. It is passed to ``_get_async_session``.

    Returns
    -------
    list[pydantic.BaseModel]
        The DTO instances of the retrieved models.
    """
    stmt = self._select_stmt(self.model_class, filter_by, order_by, limit)
    async with self._get_async_session(session) as local_session:
        models = (await local_session.scalars(stmt)).all()
        # Convert while the session is still open.
        dtos = [m.to_dto() for m in models]
    return dtos
def find_one(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    session: Optional[Session] = None,
    raise_on_none: bool = True,
) -> Optional[pydantic.BaseModel]:
    """
    Retrieve the single record matching the filter criteria as a DTO.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        Conditions combined with AND to filter the query. Defaults to an
        empty tuple.
    session : Session, optional
        A session to use. It is passed to ``_get_sync_session``.
    raise_on_none : bool, optional
        Whether to raise an exception if no record matches. Defaults to True.

    Returns
    -------
    pydantic.BaseModel or None
        The DTO of the matching record, or ``None`` when nothing matches
        and ``raise_on_none`` is False.

    Raises
    ------
    sqlalchemy.exc.MultipleResultsFound
        If more than one record matches.
    sqlalchemy.exc.NoResultFound
        If no record matches and ``raise_on_none`` is True.
    """
    # 'find_one' doesn't need an 'order_by', because it is an error if more
    # than one result is found by the query.
    stmt = self._select_stmt(self.model_class, filter_by, ())
    with self._get_sync_session(session) as local_session:
        result = local_session.scalars(stmt)
        model = result.one() if raise_on_none else result.one_or_none()
        dto = model.to_dto() if model is not None else None
    return dto
[docs]
async def async_find_one(
    self,
    filter_by: Iterable[SqlBinaryExpression] = (),
    session: Optional[AsyncSession] = None,
    raise_on_none: bool = True,
) -> Optional[DtoModel]:
    """
    Asynchronously, retrieve a single record from the database that matches the
    provided filter criteria. Converts the obtained database model into a data
    transfer object (DTO). An exception is raised when more than one result is
    returned or no results are found when `raise_on_none` is True.

    Parameters
    ----------
    filter_by : Iterable[SqlBinaryExpression], optional
        Filter conditions to apply when querying the database. Defaults to an
        empty tuple.
    session : Optional[AsyncSession], optional
        Asynchronous SQLAlchemy session to use for querying. If not provided, a
        new session will be created and used for the operation.
    raise_on_none : bool, optional
        Specifies whether to raise an exception if no records match the specified
        filter criteria. Defaults to True.

    Returns
    -------
    Optional[DtoModel]
        The retrieved data model in DTO form if a record is found; otherwise,
        None when `raise_on_none` is set to False.

    Raises
    ------
    sqlalchemy.exc.NoResultFound
        If no record matches and `raise_on_none` is True.
    sqlalchemy.exc.MultipleResultsFound
        If more than one record matches (regardless of `raise_on_none`).
    """
    # No 'order_by' needed: more than one match is an error anyway.
    stmt = self._select_stmt(self.model_class, filter_by, ())
    async with self._get_async_session(session) as local_session:
        result = await local_session.scalars(stmt)
        model = result.one() if raise_on_none else result.one_or_none()
        # Convert while the session scope is still open.
        dto = model.to_dto() if model is not None else None
    return dto
@classmethod
def _select_stmt(
    cls,
    model_class: Type[M],
    filter_by: Iterable[SqlBinaryExpression],
    order_by: Iterable[SqlOrderExpression],
    limit: Optional[int] = None,
):
    """
    Build a ``SELECT`` over ``model_class`` with filtering, ordering and a limit.

    Parameters
    ----------
    model_class : Type[M]
        The mapped model class to select.
    filter_by : Iterable[SqlBinaryExpression]
        Expressions combined by ``SqlSelectionCriteria.to_conjunction`` into
        the ``WHERE`` clause.
    order_by : Iterable[SqlOrderExpression]
        Expressions turned into ``ORDER BY`` criteria by
        ``SqlOrderCriteria.to_criteria``.
    limit : int, optional
        Passed straight to ``Select.limit``; ``None`` applies no limit.

    Returns
    -------
    sqlalchemy.Select
        The assembled select statement.
    """
    selection = SqlSelectionCriteria(list(filter_by))
    ordering = SqlOrderCriteria(list(order_by))

    statement = sa.select(model_class)
    statement = statement.where(selection.to_conjunction(model_class))
    statement = statement.order_by(*ordering.to_criteria(model_class))
    statement = statement.limit(limit)
    return statement
@classmethod
def _normalize_data(
    cls, data: BaseModel | Iterable[BaseModel]
) -> Iterable[BaseModel]:
    """
    Return ``data`` unchanged if ``is_iterable`` considers it iterable,
    otherwise wrap the single model in a one-element list.
    """
    return data if is_iterable(data) else [data]
@contextmanager
def _get_sync_session(self, session: Optional[Session]) -> Session:
    """
    Yield ``session`` if the caller supplied one, otherwise yield the session
    produced by ``self.engine.session_scope()``.

    A caller-supplied session is yielded as-is; this helper does not enter or
    exit any scope around it. An internally created session lives only for the
    duration of the ``with`` block, with its lifecycle handled by
    ``session_scope``.
    """
    # NOTE(review): for a generator-based context manager the precise
    # annotation would be Iterator[Session]; left unchanged here.
    if session is not None:
        yield session
        return
    with self.engine.session_scope() as scoped_session:
        yield scoped_session
@asynccontextmanager
async def _get_async_session(self, session: Optional[AsyncSession]) -> AsyncSession:
    """
    Asynchronously yield ``session`` if given, otherwise yield the session
    produced by ``self.engine.async_session_scope()``.

    A caller-supplied session is yielded as-is with no scope entered around
    it. An internally created session is only valid inside the ``async with``
    block, with its lifecycle handled by ``async_session_scope``.
    """
    # NOTE(review): the precise annotation would be AsyncIterator[AsyncSession];
    # left unchanged here.
    if session is not None:
        yield session
        return
    async with self.engine.async_session_scope() as scoped_session:
        yield scoped_session
# region Upsert Utilities
@classmethod
def _group_as_dict(
    cls, data: Iterable[BaseModel]
) -> dict[Type[B], list[dict[str, Any]]]:
    """
    Flatten a collection of model trees into per-class lists of column dicts.

    Each node visited by ``item.walk_children(..., traverse_viewonly=False)``
    for every ``item`` in ``data`` is turned into a dictionary mapping every
    column key from ``node.table_columns()`` to the node's attribute value
    (``None`` when the attribute is missing). The dictionaries are grouped by
    the node's class.

    Parameters
    ----------
    data : Iterable[BaseModel]
        The root models to walk.

    Returns
    -------
    dict[Type[B], list[dict[str, Any]]]
        A ``defaultdict(list)`` of column dictionaries keyed by model class, in
        the order the nodes were visited. Duplicates are not removed here.
    """
    result: dict[Type[B], list[dict[str, Any]]] = defaultdict(list)

    def add_to_result(n: BaseModel) -> None:
        """Convert the node to a dictionary and add it to the result."""
        result[n.__class__].append(
            {c.key: getattr(n, c.key, None) for c in n.table_columns()}
        )

    for item in data:
        item.walk_children(add_to_result, traverse_viewonly=False)

    return result
@classmethod
def _filter_duplicate_values(
    cls, data: dict[Type[B], list[dict[str, Any]]]
) -> dict[Type[B], list[dict[str, Any]]]:
    """
    Remove duplicate rows within each model's list of column dictionaries.

    Two rows are duplicates when they compare equal as dictionaries. The first
    occurrence of each row is kept and the original row order is preserved, so
    the generated statements are deterministic.

    Parameters
    ----------
    data : dict[Type[B], list[dict[str, Any]]]
        Rows grouped by model class.

    Returns
    -------
    dict[Type[B], list[dict[str, Any]]]
        A new mapping with the same keys (in the same order) whose lists
        contain no duplicate rows.
    """

    def _unique_rows(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return ``rows`` without duplicates, keeping first occurrences."""
        seen: set[frozenset] = set()
        unique: list[dict[str, Any]] = []
        for row in rows:
            try:
                # A frozenset of the items is hashable (a plain set is not),
                # which gives O(1) duplicate detection for the common case.
                fingerprint = frozenset(row.items())
            except TypeError:
                # Some value is unhashable (e.g. a JSON / list column); fall
                # back to an equality scan for this row.
                if row not in unique:
                    unique.append(row)
                continue
            if fingerprint not in seen:
                seen.add(fingerprint)
                unique.append(row)
        return unique

    return {model: _unique_rows(rows) for model, rows in data.items()}
@classmethod
def _sort_by_insertion_order(
    cls, data: dict[Type[B], list[dict[str, Any]]]
) -> dict[Type[B], list[dict[str, Any]]]:
    """
    Reorder ``data`` so its model classes follow the table insertion order.

    Each model's rank is looked up by its ``__table__`` in the mapping
    returned from ``BaseModel.table_insertion_order()``. The row lists are
    carried over untouched.
    """
    rank_of_table = BaseModel.table_insertion_order()

    def _rank(entry):
        model_class, _rows = entry
        return rank_of_table[model_class.__table__]

    ordered_entries = sorted(data.items(), key=_rank)
    return dict(ordered_entries)
def _make_upsert_statement(
    self,
    model_class: Type[B],
    model_values: list[dict[str, Any]],
    on_conflict: Optional[ConflictResolutionStrategy],
) -> Insert:
    """
    Build an ``INSERT`` for ``model_class`` that optionally resolves conflicts.

    Parameters
    ----------
    model_class : Type[B]
        The mapped model class (table) to insert into.
    model_values : list[dict[str, Any]]
        The rows to insert, each a mapping of column key to value.
    on_conflict : ConflictResolutionStrategy, optional
        ``"do_nothing"`` keeps the existing row on a conflict, ``"do_update"``
        overwrites the existing row's non primary key columns, and any other
        value (e.g. ``None``) produces a plain insert.

    Returns
    -------
    Insert
        A dialect specific insert statement when the dialect supports upserts,
        otherwise a generic insert.

    Notes
    -----
    * PostgreSQL / SQLite use ``ON CONFLICT (<primary key>) ...``. When every
      column is part of the primary key there is nothing to update, so
      ``"do_update"`` degrades to ``DO NOTHING`` (SQLAlchemy rejects an empty
      ``set_``).
    * MySQL has no ``ON CONFLICT``; ``ON DUPLICATE KEY UPDATE`` is used. It
      fires on a conflict with *any* unique key, not just the primary key.
      ``"do_nothing"`` is emulated by assigning the primary key columns to
      themselves, which leaves the existing row unchanged.
    * For other dialects the conflict strategy cannot be honoured; a warning is
      logged (only if a strategy was requested) and a regular insert returned.
    """
    dialect = self.engine.dialect
    pk_columns = model_class.primary_key_column_names()

    def _with_on_conflict(stmt):
        """Apply ``on_conflict`` via ``ON CONFLICT`` (PostgreSQL / SQLite)."""
        if on_conflict == "do_nothing":
            return stmt.on_conflict_do_nothing(index_elements=pk_columns)
        if on_conflict == "do_update":
            set_ = {
                c.name: c
                for c in stmt.excluded  # noqa SqlAlchemy
                if c.name not in pk_columns
            }
            if not set_:
                # Only primary key columns: nothing to update.
                return stmt.on_conflict_do_nothing(index_elements=pk_columns)
            return stmt.on_conflict_do_update(index_elements=pk_columns, set_=set_)
        return stmt

    def _with_on_duplicate_key(stmt):
        """Apply ``on_conflict`` via ``ON DUPLICATE KEY UPDATE`` (MySQL)."""
        if on_conflict not in ("do_nothing", "do_update"):
            return stmt
        updates: dict[str, Any] = {}
        if on_conflict == "do_update":
            updates = {
                c.name: c
                for c in stmt.inserted  # noqa SqlAlchemy
                if c.name not in pk_columns
            }
        if not updates:
            # "pk = pk" keeps the existing row as-is; MySQL has no DO NOTHING
            # and an empty update dictionary is rejected.
            updates = {name: stmt.table.c[name] for name in pk_columns}
        return stmt.on_duplicate_key_update(updates)

    if dialect == "postgresql":
        return _with_on_conflict(postgres_insert(model_class).values(model_values))
    if dialect == "sqlite":
        return _with_on_conflict(sqlite_insert(model_class).values(model_values))
    if dialect == "mysql":
        return _with_on_duplicate_key(mysql_insert(model_class).values(model_values))

    # A generic Insert has no conflict clauses at all, so fall back to a plain
    # insert (previously this crashed with an AttributeError).
    if on_conflict is not None:
        log.warning(
            f"The dialect '{self.engine.dialect}' does not support upserts. "
            f"Performing a regular insert instead."
        )
    return default_insert(model_class).values(model_values)
def _make_upsert_statements(
    self,
    data: BaseModel | Iterable[BaseModel],
    on_conflict: ConflictResolutionStrategy,
) -> list[Insert]:
    """
    Build one upsert statement per model class found in ``data``.

    The pipeline is:
    1. normalize ``data`` into an iterable of models;
    2. group every walked node's column values by model class;
    3. drop duplicate rows;
    4. order the classes by table insertion order;
    5. create an insert statement for each class with ``on_conflict``.

    Returns
    -------
    list[Insert]
        The statements, in table insertion order.
    """
    models = self._normalize_data(data)
    rows_by_model = self._group_as_dict(models)
    rows_by_model = self._filter_duplicate_values(rows_by_model)
    rows_by_model = self._sort_by_insertion_order(rows_by_model)

    statements: list[Insert] = []
    for model_class, rows in rows_by_model.items():
        statements.append(
            self._make_upsert_statement(model_class, rows, on_conflict)
        )
    return statements