| 3 | from ppq.core.quant import TargetPlatform |
| 4 | |
class MyModel(torch.nn.Module):
    """Small test network made only of 10 -> 10 ``Linear`` layers.

    Topology (as wired in ``forward``):

    * a shared trunk: ``gemm_1`` followed by ReLU;
    * five parallel two-layer branches on the trunk output
      (``gemm_2..gemm_6`` + ReLU, then ``gemm_7, gemm_8, gemm_9, gemm_10,
      gemm_J`` with no activation);
    * a sixth branch ``gemm_Q`` + ReLU -> ``gemm_K``, whose output is also
      max-pooled (kernel 2) along the last dimension;
    * everything is concatenated along the last dimension.

    For a 2-D input of shape ``(N, 10)`` the result is ``(N, 65)``
    (six 10-wide branches plus one 5-wide pooled branch).

    NOTE(review): ``gemm_A`` is registered but never used in ``forward``;
    possibly intentional (e.g. to exercise unused-parameter handling) —
    confirm before removing, since dropping it changes ``state_dict``.
    """

    def __init__(self) -> None:
        super().__init__()

        def dense() -> torch.nn.Linear:
            # Every layer in this model is a square 10 -> 10 projection with bias.
            return torch.nn.Linear(in_features=10, out_features=10)

        # Assignment order is significant: it fixes both the RNG consumption
        # of parameter initialisation and the ordering of ``state_dict``.
        self.gemm_1 = dense()
        self.gemm_2 = dense()
        self.gemm_3 = dense()
        self.gemm_4 = dense()
        self.gemm_5 = dense()
        self.gemm_6 = dense()
        self.gemm_7 = dense()
        self.gemm_8 = dense()
        self.gemm_9 = dense()
        self.gemm_10 = dense()
        self.gemm_J = dense()
        self.gemm_Q = dense()
        self.gemm_K = dense()
        self.gemm_A = dense()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the trunk, the six branches and the pooled side output.

        Args:
            x: input tensor whose last dimension is 10 (``in_features`` of
                ``gemm_1``).

        Returns:
            Concatenation, along the last dimension, of the five two-layer
            branches, the ``gemm_Q``/``gemm_K`` branch, and its max-pooled copy.
        """
        trunk = torch.relu(self.gemm_1(x))

        # Stage one of each parallel branch (Linear + ReLU), evaluated in the
        # same order as the layers are declared so any traced graph keeps
        # the same op sequence.
        first_layers = (self.gemm_2, self.gemm_3, self.gemm_4, self.gemm_5, self.gemm_6)
        second_layers = (self.gemm_7, self.gemm_8, self.gemm_9, self.gemm_10, self.gemm_J)
        hidden = [torch.relu(layer(trunk)) for layer in first_layers]

        # Stage two: a plain Linear per branch, no activation.
        branches = [layer(act) for layer, act in zip(second_layers, hidden)]

        side = self.gemm_K(torch.relu(self.gemm_Q(trunk)))
        # Halves the last dimension (10 -> 5) with non-overlapping windows.
        pooled = torch.max_pool1d(side, kernel_size=2)

        return torch.cat([*branches, side, pooled], dim=-1)
| 44 | |
# Smoke run: build the model and push one all-zeros (10, 10) batch through it.
# Falls back to CPU when CUDA is unavailable (the original hard-required a GPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
# The output is discarded, so skip building an autograd graph.
with torch.no_grad():
    # Call the module itself rather than ``.forward`` so nn.Module hooks run.
    model(torch.zeros(size=[10, 10], device=device))