Repository: github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT

Method get_sample_input

model_helper.py, lines 86–115
get_sample_input(
    self,
    batch_size: int,
    latent_height: int,
    latent_width: int,
    text_len: int,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor]
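The latent_height and latent_width arguments are sizes in latent space, not pixels. The sketch below shows one way to derive them, along with text_len, from an image size and a prompt. It assumes the Stable Diffusion VAE's 8× spatial downsampling and also assumes that text_len is a multiple of 77, with one CLIP token window per prompt chunk. The helper name `sample_input_dims` is hypothetical and not part of the extension.

```python
from typing import Tuple

# Assumption: the Stable Diffusion VAE downsamples each spatial dimension by 8.
VAE_SCALE_FACTOR = 8
# Assumption: each prompt chunk occupies one 77-token CLIP window.
TOKENS_PER_CHUNK = 77


def sample_input_dims(
    image_height: int, image_width: int, prompt_chunks: int = 1
) -> Tuple[int, int, int]:
    """Hypothetical helper: map a pixel size and a prompt-chunk count to the
    (latent_height, latent_width, text_len) arguments of get_sample_input."""
    if image_height % VAE_SCALE_FACTOR or image_width % VAE_SCALE_FACTOR:
        raise ValueError("image height and width must be multiples of 8")
    if prompt_chunks < 1:
        raise ValueError("prompt_chunks must be at least 1")
    return (
        image_height // VAE_SCALE_FACTOR,
        image_width // VAE_SCALE_FACTOR,
        prompt_chunks * TOKENS_PER_CHUNK,
    )


# A 512x768 (height x width) image with a two-chunk prompt:
print(sample_input_dims(512, 768, prompt_chunks=2))  # (64, 96, 154)
```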

Source (model_helper.py, lines 84–118):

        return dyn_axes

    def get_sample_input(
        self,
        batch_size: int,
        latent_height: int,
        latent_width: int,
        text_len: int,
        device: str = "cuda",
        dtype: torch.dtype = torch.float32,
    ) -> Tuple[torch.Tensor]:
        return (
            torch.randn(
                batch_size,
                self.in_channels,
                latent_height,
                latent_width,
                dtype=dtype,
                device=device,
            ),
            torch.randn(batch_size, dtype=dtype, device=device),
            torch.randn(
                batch_size,
                text_len,
                self.embedding_dim,
                dtype=dtype,
                device=device,
            ),
            torch.randn(batch_size, self.num_xl_classes, dtype=dtype, device=device)
            if self.is_xl
            else None,
        )

    def get_input_profile(self, profile: ProfileSettings) -> dict:
        min_batch, opt_batch, max_batch = profile.get_a1111_batch_dim()
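The method returns a four-element tuple of random tensors. The first is a latent sample of shape (batch_size, in_channels, latent_height, latent_width). The second holds one value per batch element, presumably the timestep input. The third is a text-embedding tensor of shape (batch_size, text_len, embedding_dim). The fourth is present only when `self.is_xl` is true: an extra (batch_size, num_xl_classes) conditioning vector, otherwise `None`. The shape-only sketch below mirrors that logic without requiring PyTorch, so the expected shapes can be checked for different configurations. `UNetConfig` and `sample_input_shapes` are hypothetical names. The example sizes are assumptions, not values read from the extension: 4 latent channels for both models, 768-dim embeddings for SD 1.x, and 2048-dim embeddings with 2816 extra classes for SDXL.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Shape = Tuple[int, ...]


@dataclass
class UNetConfig:
    """Hypothetical stand-in for the model-helper attributes the method reads."""
    in_channels: int
    embedding_dim: int
    num_xl_classes: int
    is_xl: bool


def sample_input_shapes(
    cfg: UNetConfig,
    batch_size: int,
    latent_height: int,
    latent_width: int,
    text_len: int,
) -> Tuple[Shape, Shape, Shape, Optional[Shape]]:
    """Shapes of the tensors get_sample_input would create, in the same order."""
    sample = (batch_size, cfg.in_channels, latent_height, latent_width)
    timesteps = (batch_size,)
    text_embeddings = (batch_size, text_len, cfg.embedding_dim)
    # The fourth input exists only for SDXL-style models; otherwise it is None.
    extra = (batch_size, cfg.num_xl_classes) if cfg.is_xl else None
    return sample, timesteps, text_embeddings, extra


# Assumed example configurations (not read from the extension).
sd15 = UNetConfig(in_channels=4, embedding_dim=768, num_xl_classes=0, is_xl=False)
sdxl = UNetConfig(in_channels=4, embedding_dim=2048, num_xl_classes=2816, is_xl=True)

print(sample_input_shapes(sd15, 2, 64, 64, 77))
# ((2, 4, 64, 64), (2,), (2, 77, 768), None)
print(sample_input_shapes(sdxl, 1, 128, 128, 77))
# ((1, 4, 128, 128), (1,), (1, 77, 2048), (1, 2816))
```

The declared return type `Tuple[torch.Tensor]` describes a one-element tuple. The value actually returned has four elements, the last of which may be `None`, which is why the sketch annotates its result as `Tuple[Shape, Shape, Shape, Optional[Shape]]`.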

Callers (2)

export_lora (function) · 0.80
export_onnx (function) · 0.80

Calls

No calls to other functions in this repository; the only call is to torch.randn.

Tested by

No test coverage detected.