Transaction context manager with support for nested transactions via savepoints.
| 15 | |
| 16 | |
| 17 | class Transaction: |
| 18 | """ |
| 19 | Transaction context manager with support for nested transactions via savepoints. |
| 20 | """ |
| 21 | |
def __init__(
    self,
    database: "DatabaseConnection",
    force_rollback: bool = False,
) -> None:
    """
    Create an idle transaction bound to ``database``.

    Nothing touches the database here; the connection and transaction
    handles stay empty until the context is entered.

    :param database: DatabaseConnection instance the transaction runs against
    :param force_rollback: If True, always rollback (used for testing)
    """
    # Handles filled in by ``__aenter__``.
    self._connection: Optional[AsyncConnection] = None
    self._transaction: Optional[AsyncTransaction] = None
    # Nesting level captured on entry; stays 0 until the context is entered.
    self._depth: int = 0

    # Constructor arguments, kept verbatim.
    self._database = database
    self._force_rollback = force_rollback
| 38 | |
async def __aenter__(self) -> "Transaction":
    """
    Enter the transaction context.

    At the outermost level (depth 0) a new connection is opened on the
    main engine and a transaction is begun on it. Only then is the
    connection registered with the database. At deeper levels the
    registered connection is reused and a savepoint is opened with
    ``begin_nested()``.

    :return: this ``Transaction`` instance
    :raises RuntimeError: if a nested transaction is entered but the
        database has no registered transaction connection
    """
    depth = _transaction_depth.get()
    self._depth = depth

    if depth == 0:
        # Outermost transaction: use the main engine (not the AUTOCOMMIT
        # view used by standalone queries) so ``connection.begin()`` and the
        # nested ``begin_nested()`` savepoints both work.
        connection = await self._database.engine.connect().__aenter__()
        try:
            transaction = await connection.begin()
            # SQLite requires an explicit BEGIN before SAVEPOINTs to prevent
            # RELEASE SAVEPOINT from auto-committing when no outer transaction
            # exists. Issue after conn.begin() to avoid conflicting with
            # SQLAlchemy's autobegin.
            if self._database.engine.dialect.name == "sqlite":  # pragma: nocover
                await connection.exec_driver_sql("BEGIN")
        except BaseException:
            # __aexit__ is not called when __aenter__ raises, so release the
            # connection here instead of leaking it. BaseException also
            # covers task cancellation.
            await connection.close()
            raise
        # Register only after the transaction is open, so a failed start
        # never leaves a dead connection registered on the database.
        self._database.set_transaction_connection(connection)
        self._connection = connection
        self._transaction = transaction
    else:
        # Nested transaction: reuse the outer connection and use a savepoint.
        connection = self._database.get_transaction_connection()
        if connection is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise RuntimeError(
                "Nested transaction entered without an active transaction "
                "connection"
            )
        self._connection = connection
        self._transaction = await connection.begin_nested()

    _transaction_depth.set(depth + 1)

    return self
| 65 | |
| 66 | async def __aexit__( |
| 67 | self, |
| 68 | exc_type: Optional[Type[BaseException]] = None, |
| 69 | exc_value: Optional[BaseException] = None, |
| 70 | traceback: Optional[TracebackType] = None, |
| 71 | ) -> None: |
| 72 | """Exit transaction context.""" |
| 73 | try: |
| 74 | _transaction_depth.set(self._depth) |
no outgoing calls
no test coverage detected