Wraps an AsyncSession to spy on rollback() and optionally fail commit().
| 1129 | # on errors |
| 1130 | # --------------------------------------------------------------------------- |
| 1131 | class _RollbackSpySession: |
| 1132 | """Wrap an AsyncSession: record whether rollback() is called and, when fail_commit is true, make commit() raise RuntimeError instead of committing; every other attribute is forwarded to the wrapped session.""" |
| 1133 | |
| 1134 | def __init__(self, real_session, *, fail_commit=False):  # fail_commit is keyword-only |
| 1135 | self._real = real_session  # the wrapped session; rebound in __aenter__ |
| 1136 | self._fail_commit = fail_commit  # when True, commit() raises without touching the real session |
| 1137 | self.rollback_called = False  # set to True by rollback() |
| 1138 | |
| 1139 | async def __aenter__(self): |
| 1140 | self._real = await self._real.__aenter__()  # keep whatever the real __aenter__ returns (presumably the same session object — confirm) |
| 1141 | return self  # `async with` targets receive the spy, not the raw session |
| 1142 | |
| 1143 | async def __aexit__(self, *args): |
| 1144 | return await self._real.__aexit__(*args)  # pass the real return value through so exception suppression is unchanged |
| 1145 | |
| 1146 | async def commit(self): |
| 1147 | if self._fail_commit: |
| 1148 | raise RuntimeError('simulated commit failure')  # the real commit() is never awaited on this path |
| 1149 | return await self._real.commit() |
| 1150 | |
| 1151 | async def rollback(self): |
| 1152 | self.rollback_called = True  # recorded before delegating, so it stays True even if the real rollback raises |
| 1153 | return await self._real.rollback() |
| 1154 | |
| 1155 | def __getattr__(self, name):  # only reached for names not found on the spy itself |
| 1156 | return getattr(self._real, name)  # NOTE(review): if _real is ever unset (e.g. instance made without __init__), self._real re-enters __getattr__ and recurses |
| 1157 | |
| 1158 | |
| 1159 | @pytest.mark.asyncio |