MCPcopy
hub / github.com/OpenPPL/ppq / MyModel

Class MyModel

tests/test_block_split.py:5–43  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

3from ppq.core.quant import TargetPlatform
4
class MyModel(torch.nn.Module):
    """Fan-out / fan-in network built only from 10 -> 10 ``Linear`` layers.

    ``forward`` runs a shared trunk (``gemm_1`` + ReLU) and then six branches:

    * five ``Linear -> ReLU -> Linear`` branches
      (``gemm_2..gemm_6`` followed by ``gemm_7..gemm_10, gemm_J``);
    * one ``gemm_Q -> ReLU -> gemm_K`` branch, whose output is also
      max-pooled (``kernel_size=2``).

    All branch outputs and the pooled tensor are concatenated along the last
    dimension. For a 2-D input of shape ``(N, 10)`` the result has shape
    ``(N, 65)``: six ``(N, 10)`` pieces plus one ``(N, 5)`` pooled piece.
    """

    def __init__(self) -> None:
        super().__init__()

        def dense() -> torch.nn.Linear:
            # Every layer in this model maps 10 features to 10 features.
            return torch.nn.Linear(in_features=10, out_features=10)

        # Registration order matters: it fixes the RNG sequence consumed by
        # weight initialisation and the order of named_children()/state_dict().
        self.gemm_1 = dense()
        self.gemm_2 = dense()
        self.gemm_3 = dense()
        self.gemm_4 = dense()
        self.gemm_5 = dense()
        self.gemm_6 = dense()
        self.gemm_7 = dense()
        self.gemm_8 = dense()
        self.gemm_9 = dense()
        self.gemm_10 = dense()
        self.gemm_J = dense()
        self.gemm_Q = dense()
        self.gemm_K = dense()
        # NOTE(review): gemm_A is registered (it shows up in parameters() and
        # state_dict()) but forward() never uses it.
        self.gemm_A = dense()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the trunk and all branches, and concatenate them on the last dim."""
        trunk = torch.relu(self.gemm_1(x))

        # The five two-layer branches are evaluated breadth-first: every
        # first-stage Linear+ReLU runs before any second-stage Linear.
        # This keeps the operator execution order fixed.
        first_stage = (self.gemm_2, self.gemm_3, self.gemm_4, self.gemm_5, self.gemm_6)
        second_stage = (self.gemm_7, self.gemm_8, self.gemm_9, self.gemm_10, self.gemm_J)
        hidden = [torch.relu(layer(trunk)) for layer in first_stage]
        branch_outputs = [layer(h) for layer, h in zip(second_stage, hidden)]

        # The sixth branch is concatenated both as-is and max-pooled.
        tail = self.gemm_K(torch.relu(self.gemm_Q(trunk)))
        pooled = torch.max_pool1d(tail, kernel_size=2)

        return torch.cat([*branch_outputs, tail, pooled], dim=-1)
44
# Build the test model on the GPU and run one smoke-test forward pass with a
# (10, 10) zero input; the output is not used here.
# Call the module instance rather than ``model.forward(...)``: ``__call__``
# goes through nn.Module's hook machinery, whereas calling ``forward``
# directly silently skips any registered hooks.
model = MyModel().cuda()
model(torch.zeros(size=[10, 10]).cuda())

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected