Method forward(self, x)

github.com/NVIDIA/TensorRT-LLM · tensorrt_llm/models/stdit/model.py, lines 812–834


```python
def forward(self, x):
    # Only static shapes are supported: the `%` checks below need
    # concrete integer dimensions from x.shape.
    if not USE_STATIC_SHAPE:
        raise NotImplementedError('Only static shape is supported')
    _, _, D, H, W = x.shape
    # Right-pad W, then H, then D so each becomes a multiple of its patch
    # size. The pad tuple lists (left, right) pairs starting from the last dim.
    if W % self.patch_size[2] != 0:
        x = pad(x, (0, self.patch_size[2] - W % self.patch_size[2]))
    if H % self.patch_size[1] != 0:
        x = pad(x, (0, 0, 0, self.patch_size[1] - H % self.patch_size[1]))
    if D % self.patch_size[0] != 0:
        x = pad(
            x, (0, 0, 0, 0, 0, self.patch_size[0] - D % self.patch_size[0]))
    x = self.proj(x)  # (B C T H W)
    if self.norm is not None:
        # Normalize over channels: BCTHW -> B(THW)C, apply norm, restore BCTHW.
        D = shape(x, 2)
        Wh = shape(x, 3)
        Ww = shape(x, 4)
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        x = x.transpose(1, 2).view([-1, self.embed_dim, D, Wh, Ww])
    if self.flatten:
        x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
    # Register the result as a named network output.
    self.register_network_output('output', x)
    return x
```
837class STDiT3Block(Module):

Callers

Nothing calls this directly.

Calls (6)

transpose · method · 0.80
flatten · method · 0.80
pad · function · 0.50
shape · function · 0.50
view · method · 0.45

Tested by

No test coverage detected.
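With no tests detected, a NumPy reference of the same computation could serve as a baseline to compare engine outputs against. This is a sketch under stated assumptions: `proj` is treated as a Conv3d with `kernel_size == stride == patch_size` (as in Open-Sora's PatchEmbed3D, which this method appears to port), `norm` is taken to be a LayerNorm over channels without learned scale or shift, and `patch_embed_reference` and `layer_norm` are hypothetical names.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    # LayerNorm over the last axis, without learned scale/shift (assumption).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def patch_embed_reference(x, weight, bias, patch_size, use_norm=False,
                          flatten=True):
    """NumPy sketch of forward() for a (B, C, D, H, W) input.

    weight has shape (E, C, pt, ph, pw) and bias shape (E,), i.e. the
    parameters of a Conv3d whose stride equals its kernel size (assumed).
    """
    pt, ph, pw = patch_size
    _, _, d, h, w = x.shape
    # Right-pad D, H, W to multiples of the patch size. -n % p equals
    # p - n % p when n % p != 0, and 0 otherwise. np.pad takes one
    # (before, after) pair per axis in axis order.
    x = np.pad(x, [(0, 0), (0, 0), (0, -d % pt), (0, -h % ph), (0, -w % pw)])
    b, c, d, h, w = x.shape
    t_out, h_out, w_out = d // pt, h // ph, w // pw
    # Split each of D, H, W into (num_patches, patch_len).
    patches = x.reshape(b, c, t_out, pt, h_out, ph, w_out, pw)
    # Kernel == stride makes the convolution one linear map per patch.
    y = np.einsum('bctphqwr,ecpqr->bethw', patches, weight)
    y = y + bias[None, :, None, None, None]
    e = y.shape[1]
    if use_norm:
        # BETHW -> B(THW)E, normalize over E, then restore BETHW.
        tokens = layer_norm(y.reshape(b, e, -1).transpose(0, 2, 1))
        y = tokens.transpose(0, 2, 1).reshape(b, e, t_out, h_out, w_out)
    if flatten:
        # BETHW -> BNE with N = T' * H' * W'.
        y = y.reshape(b, e, -1).transpose(0, 2, 1)
    return y


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 4, 3, 5, 5)).astype(np.float32)
weight = rng.standard_normal((8, 4, 1, 2, 2)).astype(np.float32)
bias = np.zeros(8, dtype=np.float32)
print(patch_embed_reference(x, weight, bias, (1, 2, 2), use_norm=True).shape)
# → (1, 27, 8)
```

The einsum form makes the reason for padding first visible: a strided convolution without padding would silently drop trailing rows or columns that do not fill a whole patch, whereas zero-padding lets every input element land in some patch.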