| 29 | |
| 30 | |
class SimpleModel(nn.Module):
    """Stack of ``TransformerBlock`` modules applied one after another.

    Args:
        num_blocks: Number of blocks to create. Block ``i`` is constructed
            as ``TransformerBlock(i)``.
    """

    def __init__(self, num_blocks=16):
        super().__init__()
        self.blocks = nn.ModuleList(
            [TransformerBlock(i) for i in range(num_blocks)]
        )

    def forward(self, x):
        """Feed ``x`` through every block in order and return the final output.

        With zero blocks, ``x`` is returned unchanged.
        """
        for block in self.blocks:
            x = block(x)
        return x

    @property
    def device(self):
        """Device of the model's first parameter.

        Raises:
            RuntimeError: If the model has no parameters (e.g. ``num_blocks=0``),
                so there is no parameter whose device can be reported.
        """
        # Calling next() without a default would leak StopIteration out of this
        # property. That can silently terminate an enclosing iteration (e.g. a
        # map() over models) instead of failing loudly.
        first_param = next(self.parameters(), None)
        if first_param is None:
            raise RuntimeError(
                "SimpleModel has no parameters; its device is undefined"
            )
        return first_param.device
| 46 | |
| 47 | |
| 48 | # Device Synchronization Tests |
no outgoing calls