chore: automatic commit 2025-04-30 12:48

This commit is contained in:
2025-04-30 12:48:06 +02:00
parent f69356473b
commit e4ab1e1bb5
5284 changed files with 868438 additions and 0 deletions

View File

@@ -0,0 +1,46 @@
from __future__ import annotations
import typing as t
from .extension import SQLAlchemy
__version__ = "3.0.5"
__all__ = [
"SQLAlchemy",
]
_deprecated_map = {
"Model": ".model.Model",
"DefaultMeta": ".model.DefaultMeta",
"Pagination": ".pagination.Pagination",
"BaseQuery": ".query.Query",
"get_debug_queries": ".record_queries.get_recorded_queries",
"SignallingSession": ".session.Session",
"before_models_committed": ".track_modifications.before_models_committed",
"models_committed": ".track_modifications.models_committed",
}
def __getattr__(name: str) -> t.Any:
import importlib
import warnings
if name in _deprecated_map:
path = _deprecated_map[name]
import_path, _, new_name = path.rpartition(".")
action = "moved and renamed"
if new_name == name:
action = "moved"
warnings.warn(
f"'{name}' has been {action} to '{path[1:]}'. The top-level import is"
" deprecated and will be removed in Flask-SQLAlchemy 3.1.",
DeprecationWarning,
stacklevel=2,
)
mod = importlib.import_module(import_path, __name__)
return getattr(mod, new_name)
raise AttributeError(name)

View File

@@ -0,0 +1,16 @@
from __future__ import annotations
import typing as t
from flask import current_app
def add_models_to_shell() -> dict[str, t.Any]:
"""Registered with :meth:`~flask.Flask.shell_context_processor` if
``add_models_to_shell`` is enabled. Adds the ``db`` instance and all model classes
to ``flask shell``.
"""
db = current_app.extensions["sqlalchemy"]
out = {m.class_.__name__: m.class_ for m in db.Model._sa_registry.mappers}
out["db"] = db
return out

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,214 @@
from __future__ import annotations
import re
import typing as t
import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
from .query import Query
if t.TYPE_CHECKING:
from .extension import SQLAlchemy
class _QueryProperty:
"""A class property that creates a query object for a model.
:meta private:
"""
@t.overload
def __get__(self, obj: None, cls: type[Model]) -> Query:
...
@t.overload
def __get__(self, obj: Model, cls: type[Model]) -> Query:
...
def __get__(self, obj: Model | None, cls: type[Model]) -> Query:
return cls.query_class(
cls, session=cls.__fsa__.session() # type: ignore[arg-type]
)
class Model:
"""The base class of the :attr:`.SQLAlchemy.Model` declarative model class.
To define models, subclass :attr:`db.Model <.SQLAlchemy.Model>`, not this. To
customize ``db.Model``, subclass this and pass it as ``model_class`` to
:class:`.SQLAlchemy`. To customize ``db.Model`` at the metaclass level, pass an
already created declarative model class as ``model_class``.
"""
__fsa__: t.ClassVar[SQLAlchemy]
"""Internal reference to the extension object.
:meta private:
"""
query_class: t.ClassVar[type[Query]] = Query
"""Query class used by :attr:`query`. Defaults to :attr:`.SQLAlchemy.Query`, which
defaults to :class:`.Query`.
"""
query: t.ClassVar[Query] = _QueryProperty() # type: ignore[assignment]
"""A SQLAlchemy query for a model. Equivalent to ``db.session.query(Model)``. Can be
customized per-model by overriding :attr:`query_class`.
.. warning::
The query interface is considered legacy in SQLAlchemy. Prefer using
``session.execute(select())`` instead.
"""
def __repr__(self) -> str:
state = sa.inspect(self)
assert state is not None
if state.transient:
pk = f"(transient {id(self)})"
elif state.pending:
pk = f"(pending {id(self)})"
else:
pk = ", ".join(map(str, state.identity))
return f"<{type(self).__name__} {pk}>"
class BindMetaMixin(type):
"""Metaclass mixin that sets a model's ``metadata`` based on its ``__bind_key__``.
If the model sets ``metadata`` or ``__table__`` directly, ``__bind_key__`` is
ignored. If the ``metadata`` is the same as the parent model, it will not be set
directly on the child model.
"""
__fsa__: SQLAlchemy
metadata: sa.MetaData
def __init__(
cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any], **kwargs: t.Any
) -> None:
if not ("metadata" in cls.__dict__ or "__table__" in cls.__dict__):
bind_key = getattr(cls, "__bind_key__", None)
parent_metadata = getattr(cls, "metadata", None)
metadata = cls.__fsa__._make_metadata(bind_key)
if metadata is not parent_metadata:
cls.metadata = metadata
super().__init__(name, bases, d, **kwargs)
class NameMetaMixin(type):
"""Metaclass mixin that sets a model's ``__tablename__`` by converting the
``CamelCase`` class name to ``snake_case``. A name is set for non-abstract models
that do not otherwise define ``__tablename__``. If a model does not define a primary
key, it will not generate a name or ``__table__``, for single-table inheritance.
"""
metadata: sa.MetaData
__tablename__: str
__table__: sa.Table
def __init__(
cls, name: str, bases: tuple[type, ...], d: dict[str, t.Any], **kwargs: t.Any
) -> None:
if should_set_tablename(cls):
cls.__tablename__ = camel_to_snake_case(cls.__name__)
super().__init__(name, bases, d, **kwargs)
# __table_cls__ has run. If no table was created, use the parent table.
if (
"__tablename__" not in cls.__dict__
and "__table__" in cls.__dict__
and cls.__dict__["__table__"] is None
):
del cls.__table__
def __table_cls__(cls, *args: t.Any, **kwargs: t.Any) -> sa.Table | None:
"""This is called by SQLAlchemy during mapper setup. It determines the final
table object that the model will use.
If no primary key is found, that indicates single-table inheritance, so no table
will be created and ``__tablename__`` will be unset.
"""
schema = kwargs.get("schema")
if schema is None:
key = args[0]
else:
key = f"{schema}.{args[0]}"
# Check if a table with this name already exists. Allows reflected tables to be
# applied to models by name.
if key in cls.metadata.tables:
return sa.Table(*args, **kwargs)
# If a primary key is found, create a table for joined-table inheritance.
for arg in args:
if (isinstance(arg, sa.Column) and arg.primary_key) or isinstance(
arg, sa.PrimaryKeyConstraint
):
return sa.Table(*args, **kwargs)
# If no base classes define a table, return one that's missing a primary key
# so SQLAlchemy shows the correct error.
for base in cls.__mro__[1:-1]:
if "__table__" in base.__dict__:
break
else:
return sa.Table(*args, **kwargs)
# Single-table inheritance, use the parent table name. __init__ will unset
# __table__ based on this.
if "__tablename__" in cls.__dict__:
del cls.__tablename__
return None
def should_set_tablename(cls: type) -> bool:
"""Determine whether ``__tablename__`` should be generated for a model.
- If no class in the MRO sets a name, one should be generated.
- If a declared attr is found, it should be used instead.
- If a name is found, it should be used if the class is a mixin, otherwise one
should be generated.
- Abstract models should not have one generated.
Later, ``__table_cls__`` will determine if the model looks like single or
joined-table inheritance. If no primary key is found, the name will be unset.
"""
if cls.__dict__.get("__abstract__", False) or not any(
isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:]
):
return False
for base in cls.__mro__:
if "__tablename__" not in base.__dict__:
continue
if isinstance(base.__dict__["__tablename__"], sa_orm.declared_attr):
return False
return not (
base is cls
or base.__dict__.get("__abstract__", False)
or not isinstance(base, sa_orm.DeclarativeMeta)
)
return True
def camel_to_snake_case(name: str) -> str:
"""Convert a ``CamelCase`` name to ``snake_case``."""
name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name)
return name.lower().lstrip("_")
class DefaultMeta(BindMetaMixin, NameMetaMixin, sa_orm.DeclarativeMeta):
"""SQLAlchemy declarative metaclass that provides ``__bind_key__`` and
``__tablename__`` support.
"""

View File

@@ -0,0 +1,364 @@
from __future__ import annotations
import typing as t
from math import ceil
import sqlalchemy as sa
import sqlalchemy.orm as sa_orm
from flask import abort
from flask import request
class Pagination:
"""Apply an offset and limit to the query based on the current page and number of
items per page.
Don't create pagination objects manually. They are created by
:meth:`.SQLAlchemy.paginate` and :meth:`.Query.paginate`.
This is a base class, a subclass must implement :meth:`_query_items` and
:meth:`_query_count`. Those methods will use arguments passed as ``kwargs`` to
perform the queries.
:param page: The current page, used to calculate the offset. Defaults to the
``page`` query arg during a request, or 1 otherwise.
:param per_page: The maximum number of items on a page, used to calculate the
offset and limit. Defaults to the ``per_page`` query arg during a request,
or 20 otherwise.
:param max_per_page: The maximum allowed value for ``per_page``, to limit a
user-provided value. Use ``None`` for no limit. Defaults to 100.
:param error_out: Abort with a ``404 Not Found`` error if no items are returned
and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if
either are not ints.
:param count: Calculate the total number of values by issuing an extra count
query. For very complex queries this may be inaccurate or slow, so it can be
disabled and set manually if necessary.
:param kwargs: Information about the query to paginate. Different subclasses will
require different arguments.
.. versionchanged:: 3.0
Iterating over a pagination object iterates over its items.
.. versionchanged:: 3.0
Creating instances manually is not a public API.
"""
def __init__(
self,
page: int | None = None,
per_page: int | None = None,
max_per_page: int | None = 100,
error_out: bool = True,
count: bool = True,
**kwargs: t.Any,
) -> None:
self._query_args = kwargs
page, per_page = self._prepare_page_args(
page=page,
per_page=per_page,
max_per_page=max_per_page,
error_out=error_out,
)
self.page: int = page
"""The current page."""
self.per_page: int = per_page
"""The maximum number of items on a page."""
self.max_per_page: int | None = max_per_page
"""The maximum allowed value for ``per_page``."""
items = self._query_items()
if not items and page != 1 and error_out:
abort(404)
self.items: list[t.Any] = items
"""The items on the current page. Iterating over the pagination object is
equivalent to iterating over the items.
"""
if count:
total = self._query_count()
else:
total = None
self.total: int | None = total
"""The total number of items across all pages."""
@staticmethod
def _prepare_page_args(
*,
page: int | None = None,
per_page: int | None = None,
max_per_page: int | None = None,
error_out: bool = True,
) -> tuple[int, int]:
if request:
if page is None:
try:
page = int(request.args.get("page", 1))
except (TypeError, ValueError):
if error_out:
abort(404)
page = 1
if per_page is None:
try:
per_page = int(request.args.get("per_page", 20))
except (TypeError, ValueError):
if error_out:
abort(404)
per_page = 20
else:
if page is None:
page = 1
if per_page is None:
per_page = 20
if max_per_page is not None:
per_page = min(per_page, max_per_page)
if page < 1:
if error_out:
abort(404)
else:
page = 1
if per_page < 1:
if error_out:
abort(404)
else:
per_page = 20
return page, per_page
@property
def _query_offset(self) -> int:
"""The index of the first item to query, passed to ``offset()``.
:meta private:
.. versionadded:: 3.0
"""
return (self.page - 1) * self.per_page
def _query_items(self) -> list[t.Any]:
"""Execute the query to get the items on the current page.
Uses init arguments stored in :attr:`_query_args`.
:meta private:
.. versionadded:: 3.0
"""
raise NotImplementedError
def _query_count(self) -> int:
"""Execute the query to get the total number of items.
Uses init arguments stored in :attr:`_query_args`.
:meta private:
.. versionadded:: 3.0
"""
raise NotImplementedError
@property
def first(self) -> int:
"""The number of the first item on the page, starting from 1, or 0 if there are
no items.
.. versionadded:: 3.0
"""
if len(self.items) == 0:
return 0
return (self.page - 1) * self.per_page + 1
@property
def last(self) -> int:
"""The number of the last item on the page, starting from 1, inclusive, or 0 if
there are no items.
.. versionadded:: 3.0
"""
first = self.first
return max(first, first + len(self.items) - 1)
@property
def pages(self) -> int:
"""The total number of pages."""
if self.total == 0 or self.total is None:
return 0
return ceil(self.total / self.per_page)
@property
def has_prev(self) -> bool:
"""``True`` if this is not the first page."""
return self.page > 1
@property
def prev_num(self) -> int | None:
"""The previous page number, or ``None`` if this is the first page."""
if not self.has_prev:
return None
return self.page - 1
def prev(self, *, error_out: bool = False) -> Pagination:
"""Query the :class:`Pagination` object for the previous page.
:param error_out: Abort with a ``404 Not Found`` error if no items are returned
and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if
either are not ints.
"""
p = type(self)(
page=self.page - 1,
per_page=self.per_page,
error_out=error_out,
count=False,
**self._query_args,
)
p.total = self.total
return p
@property
def has_next(self) -> bool:
"""``True`` if this is not the last page."""
return self.page < self.pages
@property
def next_num(self) -> int | None:
"""The next page number, or ``None`` if this is the last page."""
if not self.has_next:
return None
return self.page + 1
def next(self, *, error_out: bool = False) -> Pagination:
"""Query the :class:`Pagination` object for the next page.
:param error_out: Abort with a ``404 Not Found`` error if no items are returned
and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if
either are not ints.
"""
p = type(self)(
page=self.page + 1,
per_page=self.per_page,
max_per_page=self.max_per_page,
error_out=error_out,
count=False,
**self._query_args,
)
p.total = self.total
return p
def iter_pages(
self,
*,
left_edge: int = 2,
left_current: int = 2,
right_current: int = 4,
right_edge: int = 2,
) -> t.Iterator[int | None]:
"""Yield page numbers for a pagination widget. Skipped pages between the edges
and middle are represented by a ``None``.
For example, if there are 20 pages and the current page is 7, the following
values are yielded.
.. code-block:: python
1, 2, None, 5, 6, 7, 8, 9, 10, 11, None, 19, 20
:param left_edge: How many pages to show from the first page.
:param left_current: How many pages to show left of the current page.
:param right_current: How many pages to show right of the current page.
:param right_edge: How many pages to show from the last page.
.. versionchanged:: 3.0
Improved efficiency of calculating what to yield.
.. versionchanged:: 3.0
``right_current`` boundary is inclusive.
.. versionchanged:: 3.0
All parameters are keyword-only.
"""
pages_end = self.pages + 1
if pages_end == 1:
return
left_end = min(1 + left_edge, pages_end)
yield from range(1, left_end)
if left_end == pages_end:
return
mid_start = max(left_end, self.page - left_current)
mid_end = min(self.page + right_current + 1, pages_end)
if mid_start - left_end > 0:
yield None
yield from range(mid_start, mid_end)
if mid_end == pages_end:
return
right_start = max(mid_end, pages_end - right_edge)
if right_start - mid_end > 0:
yield None
yield from range(right_start, pages_end)
def __iter__(self) -> t.Iterator[t.Any]:
yield from self.items
class SelectPagination(Pagination):
"""Returned by :meth:`.SQLAlchemy.paginate`. Takes ``select`` and ``session``
arguments in addition to the :class:`Pagination` arguments.
.. versionadded:: 3.0
"""
def _query_items(self) -> list[t.Any]:
select = self._query_args["select"]
select = select.limit(self.per_page).offset(self._query_offset)
session = self._query_args["session"]
return list(session.execute(select).unique().scalars())
def _query_count(self) -> int:
select = self._query_args["select"]
sub = select.options(sa_orm.lazyload("*")).order_by(None).subquery()
session = self._query_args["session"]
out = session.execute(sa.select(sa.func.count()).select_from(sub)).scalar()
return out # type: ignore[no-any-return]
class QueryPagination(Pagination):
"""Returned by :meth:`.Query.paginate`. Takes a ``query`` argument in addition to
the :class:`Pagination` arguments.
.. versionadded:: 3.0
"""
def _query_items(self) -> list[t.Any]:
query = self._query_args["query"]
out = query.limit(self.per_page).offset(self._query_offset).all()
return out # type: ignore[no-any-return]
def _query_count(self) -> int:
# Query.count automatically disables eager loads
out = self._query_args["query"].order_by(None).count()
return out # type: ignore[no-any-return]

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
import typing as t
import sqlalchemy.exc as sa_exc
import sqlalchemy.orm as sa_orm
from flask import abort
from .pagination import Pagination
from .pagination import QueryPagination
class Query(sa_orm.Query): # type: ignore[type-arg]
"""SQLAlchemy :class:`~sqlalchemy.orm.query.Query` subclass with some extra methods
useful for querying in a web application.
This is the default query class for :attr:`.Model.query`.
.. versionchanged:: 3.0
Renamed to ``Query`` from ``BaseQuery``.
"""
def get_or_404(self, ident: t.Any, description: str | None = None) -> t.Any:
"""Like :meth:`~sqlalchemy.orm.Query.get` but aborts with a ``404 Not Found``
error instead of returning ``None``.
:param ident: The primary key to query.
:param description: A custom message to show on the error page.
"""
rv = self.get(ident)
if rv is None:
abort(404, description=description)
return rv
def first_or_404(self, description: str | None = None) -> t.Any:
"""Like :meth:`~sqlalchemy.orm.Query.first` but aborts with a ``404 Not Found``
error instead of returning ``None``.
:param description: A custom message to show on the error page.
"""
rv = self.first()
if rv is None:
abort(404, description=description)
return rv
def one_or_404(self, description: str | None = None) -> t.Any:
"""Like :meth:`~sqlalchemy.orm.Query.one` but aborts with a ``404 Not Found``
error instead of raising ``NoResultFound`` or ``MultipleResultsFound``.
:param description: A custom message to show on the error page.
.. versionadded:: 3.0
"""
try:
return self.one()
except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound):
abort(404, description=description)
def paginate(
self,
*,
page: int | None = None,
per_page: int | None = None,
max_per_page: int | None = None,
error_out: bool = True,
count: bool = True,
) -> Pagination:
"""Apply an offset and limit to the query based on the current page and number
of items per page, returning a :class:`.Pagination` object.
:param page: The current page, used to calculate the offset. Defaults to the
``page`` query arg during a request, or 1 otherwise.
:param per_page: The maximum number of items on a page, used to calculate the
offset and limit. Defaults to the ``per_page`` query arg during a request,
or 20 otherwise.
:param max_per_page: The maximum allowed value for ``per_page``, to limit a
user-provided value. Use ``None`` for no limit. Defaults to 100.
:param error_out: Abort with a ``404 Not Found`` error if no items are returned
and ``page`` is not 1, or if ``page`` or ``per_page`` is less than 1, or if
either are not ints.
:param count: Calculate the total number of values by issuing an extra count
query. For very complex queries this may be inaccurate or slow, so it can be
disabled and set manually if necessary.
.. versionchanged:: 3.0
All parameters are keyword-only.
.. versionchanged:: 3.0
The ``count`` query is more efficient.
.. versionchanged:: 3.0
``max_per_page`` defaults to 100.
"""
return QueryPagination(
query=self,
page=page,
per_page=per_page,
max_per_page=max_per_page,
error_out=error_out,
count=count,
)

View File

@@ -0,0 +1,141 @@
from __future__ import annotations
import dataclasses
import inspect
import typing as t
from time import perf_counter
import sqlalchemy as sa
import sqlalchemy.event as sa_event
from flask import current_app
from flask import g
from flask import has_app_context
def get_recorded_queries() -> list[_QueryInfo]:
"""Get the list of recorded query information for the current session. Queries are
recorded if the config :data:`.SQLALCHEMY_RECORD_QUERIES` is enabled.
Each query info object has the following attributes:
``statement``
The string of SQL generated by SQLAlchemy with parameter placeholders.
``parameters``
The parameters sent with the SQL statement.
``start_time`` / ``end_time``
Timing info about when the query started execution and when the results where
returned. Accuracy and value depends on the operating system.
``duration``
The time the query took in seconds.
``location``
A string description of where in your application code the query was executed.
This may not be possible to calculate, and the format is not stable.
.. versionchanged:: 3.0
Renamed from ``get_debug_queries``.
.. versionchanged:: 3.0
The info object is a dataclass instead of a tuple.
.. versionchanged:: 3.0
The info object attribute ``context`` is renamed to ``location``.
.. versionchanged:: 3.0
Not enabled automatically in debug or testing mode.
"""
return g.get("_sqlalchemy_queries", []) # type: ignore[no-any-return]
@dataclasses.dataclass
class _QueryInfo:
"""Information about an executed query. Returned by :func:`get_recorded_queries`.
.. versionchanged:: 3.0
Renamed from ``_DebugQueryTuple``.
.. versionchanged:: 3.0
Changed to a dataclass instead of a tuple.
.. versionchanged:: 3.0
``context`` is renamed to ``location``.
"""
statement: str | None
parameters: t.Any
start_time: float
end_time: float
location: str
@property
def duration(self) -> float:
return self.end_time - self.start_time
@property
def context(self) -> str:
import warnings
warnings.warn(
"'context' is renamed to 'location'. The old name is deprecated and will be"
" removed in Flask-SQLAlchemy 3.1.",
DeprecationWarning,
stacklevel=2,
)
return self.location
def __getitem__(self, key: int) -> object:
import warnings
name = ("statement", "parameters", "start_time", "end_time", "location")[key]
warnings.warn(
"Query info is a dataclass, not a tuple. Lookup by index is deprecated and"
f" will be removed in Flask-SQLAlchemy 3.1. Use 'info.{name}' instead.",
DeprecationWarning,
stacklevel=2,
)
return getattr(self, name)
def _listen(engine: sa.engine.Engine) -> None:
sa_event.listen(engine, "before_cursor_execute", _record_start, named=True)
sa_event.listen(engine, "after_cursor_execute", _record_end, named=True)
def _record_start(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None:
if not has_app_context():
return
context._fsa_start_time = perf_counter() # type: ignore[attr-defined]
def _record_end(context: sa.engine.ExecutionContext, **kwargs: t.Any) -> None:
if not has_app_context():
return
if "_sqlalchemy_queries" not in g:
g._sqlalchemy_queries = []
import_top = current_app.import_name.partition(".")[0]
import_dot = f"{import_top}."
frame = inspect.currentframe()
while frame:
name = frame.f_globals.get("__name__")
if name and (name == import_top or name.startswith(import_dot)):
code = frame.f_code
location = f"{code.co_filename}:{frame.f_lineno} ({code.co_name})"
break
frame = frame.f_back
else:
location = "<unknown>"
g._sqlalchemy_queries.append(
_QueryInfo(
statement=context.statement,
parameters=context.parameters,
start_time=context._fsa_start_time, # type: ignore[attr-defined]
end_time=perf_counter(),
location=location,
)
)

View File

@@ -0,0 +1,102 @@
from __future__ import annotations
import typing as t
import sqlalchemy as sa
import sqlalchemy.exc as sa_exc
import sqlalchemy.orm as sa_orm
from flask.globals import app_ctx
if t.TYPE_CHECKING:
from .extension import SQLAlchemy
class Session(sa_orm.Session):
"""A SQLAlchemy :class:`~sqlalchemy.orm.Session` class that chooses what engine to
use based on the bind key associated with the metadata associated with the thing
being queried.
To customize ``db.session``, subclass this and pass it as the ``class_`` key in the
``session_options`` to :class:`.SQLAlchemy`.
.. versionchanged:: 3.0
Renamed from ``SignallingSession``.
"""
def __init__(self, db: SQLAlchemy, **kwargs: t.Any) -> None:
super().__init__(**kwargs)
self._db = db
self._model_changes: dict[object, tuple[t.Any, str]] = {}
def get_bind(
self,
mapper: t.Any | None = None,
clause: t.Any | None = None,
bind: sa.engine.Engine | sa.engine.Connection | None = None,
**kwargs: t.Any,
) -> sa.engine.Engine | sa.engine.Connection:
"""Select an engine based on the ``bind_key`` of the metadata associated with
the model or table being queried. If no bind key is set, uses the default bind.
.. versionchanged:: 3.0.3
Fix finding the bind for a joined inheritance model.
.. versionchanged:: 3.0
The implementation more closely matches the base SQLAlchemy implementation.
.. versionchanged:: 2.1
Support joining an external transaction.
"""
if bind is not None:
return bind
engines = self._db.engines
if mapper is not None:
try:
mapper = sa.inspect(mapper)
except sa_exc.NoInspectionAvailable as e:
if isinstance(mapper, type):
raise sa_orm.exc.UnmappedClassError(mapper) from e
raise
engine = _clause_to_engine(mapper.local_table, engines)
if engine is not None:
return engine
if clause is not None:
engine = _clause_to_engine(clause, engines)
if engine is not None:
return engine
if None in engines:
return engines[None]
return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs)
def _clause_to_engine(
clause: t.Any | None, engines: t.Mapping[str | None, sa.engine.Engine]
) -> sa.engine.Engine | None:
"""If the clause is a table, return the engine associated with the table's
metadata's bind key.
"""
if isinstance(clause, sa.Table) and "bind_key" in clause.metadata.info:
key = clause.metadata.info["bind_key"]
if key not in engines:
raise sa_exc.UnboundExecutionError(
f"Bind key '{key}' is not in 'SQLALCHEMY_BINDS' config."
)
return engines[key]
return None
def _app_ctx_id() -> int:
"""Get the id of the current Flask application context for the session scope."""
return id(app_ctx._get_current_object()) # type: ignore[attr-defined]

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
import typing as t
import sqlalchemy as sa
import sqlalchemy.sql.schema as sa_sql_schema
class _Table(sa.Table):
@t.overload
def __init__(
self,
name: str,
*args: sa_sql_schema.SchemaItem,
bind_key: str | None = None,
**kwargs: t.Any,
) -> None:
...
@t.overload
def __init__(
self,
name: str,
metadata: sa.MetaData,
*args: sa_sql_schema.SchemaItem,
**kwargs: t.Any,
) -> None:
...
@t.overload
def __init__(
self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any
) -> None:
...
def __init__(
self, name: str, *args: sa_sql_schema.SchemaItem, **kwargs: t.Any
) -> None:
super().__init__(name, *args, **kwargs) # type: ignore[arg-type]

View File

@@ -0,0 +1,88 @@
from __future__ import annotations
import typing as t
import sqlalchemy as sa
import sqlalchemy.event as sa_event
import sqlalchemy.orm as sa_orm
from flask import current_app
from flask import has_app_context
from flask.signals import Namespace # type: ignore[attr-defined]
if t.TYPE_CHECKING:
from .session import Session
_signals = Namespace()
models_committed = _signals.signal("models-committed")
"""This Blinker signal is sent after the session is committed if there were changed
models in the session.
The sender is the application that emitted the changes. The receiver is passed the
``changes`` argument with a list of tuples in the form ``(instance, operation)``.
The operations are ``"insert"``, ``"update"``, and ``"delete"``.
"""
before_models_committed = _signals.signal("before-models-committed")
"""This signal works exactly like :data:`models_committed` but is emitted before the
commit takes place.
"""
def _listen(session: sa_orm.scoped_session[Session]) -> None:
sa_event.listen(session, "before_flush", _record_ops, named=True)
sa_event.listen(session, "before_commit", _record_ops, named=True)
sa_event.listen(session, "before_commit", _before_commit)
sa_event.listen(session, "after_commit", _after_commit)
sa_event.listen(session, "after_rollback", _after_rollback)
def _record_ops(session: Session, **kwargs: t.Any) -> None:
if not has_app_context():
return
if not current_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]:
return
for targets, operation in (
(session.new, "insert"),
(session.dirty, "update"),
(session.deleted, "delete"),
):
for target in targets:
state = sa.inspect(target)
key = state.identity_key if state.has_identity else id(target)
session._model_changes[key] = (target, operation)
def _before_commit(session: Session) -> None:
if not has_app_context():
return
app = current_app._get_current_object() # type: ignore[attr-defined]
if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]:
return
if session._model_changes:
changes = list(session._model_changes.values())
before_models_committed.send(app, changes=changes)
def _after_commit(session: Session) -> None:
if not has_app_context():
return
app = current_app._get_current_object() # type: ignore[attr-defined]
if not app.config["SQLALCHEMY_TRACK_MODIFICATIONS"]:
return
if session._model_changes:
changes = list(session._model_changes.values())
models_committed.send(app, changes=changes)
session._model_changes.clear()
def _after_rollback(session: Session) -> None:
session._model_changes.clear()