(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0)
| 402 | print(f"{self.__class__.__name__} is using checkpointing") |
| 403 | |
def forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0):
    """Run this block, going through ``checkpoint`` when ``self.checkpoint`` is set.

    Args:
        x: Primary input, forwarded to ``self._forward``.
        context: Optional conditioning input, forwarded to ``self._forward``.
        additional_tokens: Optional extra input accepted by ``self._forward``.
            Only forwarded when ``self.checkpoint`` is falsy (see note below).
        n_times_crossframe_attn_in_self: Option accepted by ``self._forward``.
            Only forwarded when ``self.checkpoint`` is falsy (see note below).

    Returns:
        The result of ``self._forward``, either directly or via ``checkpoint``.
    """
    if not self.checkpoint:
        # Call _forward directly so that every argument it accepts actually
        # reaches it. The checkpoint call below only forwards (x, context).
        return self._forward(
            x,
            context=context,
            additional_tokens=additional_tokens,
            n_times_crossframe_attn_in_self=n_times_crossframe_attn_in_self,
        )

    # NOTE(review): `checkpoint` is defined outside this view. It appears to take
    # (func, inputs_tuple, params, flag). Only (x, context) are passed here, so
    # additional_tokens and n_times_crossframe_attn_in_self are still ignored
    # when checkpointing is enabled. Confirm that the helper's backward pass
    # tolerates None/int inputs before adding them to the tuple.
    return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
| 418 | |
| 419 | def _forward(self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): |
| 420 | x = ( |
nothing calls this directly
no test coverage detected