MCPcopy
hub / github.com/zai-org/CogVideo / sample

Method sample

sat/diffusion_video.py:213–246  ·  view source on GitHub ↗
(
        self,
        cond: Dict,
        uc: Union[Dict, None] = None,
        batch_size: int = 16,
        shape: Union[None, Tuple, List] = None,
        prefix=None,
        concat_images=None,
        **kwargs,
    )

Source from the content-addressed store, hash-verified

211
@torch.no_grad()
def sample(
    self,
    cond: Dict,
    uc: Union[Dict, None] = None,
    batch_size: int = 16,
    shape: Union[None, Tuple, List] = None,
    prefix=None,
    concat_images=None,
    **kwargs,
):
    """Draw a batch of samples by running ``self.sampler`` from Gaussian noise.

    Args:
        cond: Conditioning, passed through unchanged to ``self.sampler``.
        uc: Optional unconditional conditioning, forwarded to ``self.sampler``
            as ``uc=``.
        batch_size: Number of samples; the leading dimension of the noise.
        shape: Per-sample noise shape (every dimension after the batch
            dimension). Required despite the ``None`` default.
        prefix: Optional tensor whose first ``prefix.shape[1]`` entries along
            dim 1 replace the start of the noise (presumably known/conditioning
            latent frames -- TODO confirm against callers).
        concat_images: Forwarded to ``self.denoiser`` on every denoiser call
            as ``concat_images=``.
        **kwargs: Accepted for call-site compatibility; not used.

    Returns:
        The tensor produced by ``self.sampler``, cast to ``self.dtype``.

    Raises:
        TypeError: If ``shape`` is not provided.
    """
    if shape is None:
        # Without this guard, ``*shape`` below fails with an opaque
        # "argument after * must be an iterable" TypeError.
        raise TypeError("sample() requires `shape` (the per-sample noise shape)")

    # Draw the noise first, then move/cast it in one step (equivalent to the
    # previous ``.to(torch.float32).to(self.device)`` chain).
    noise = torch.randn(batch_size, *shape).to(device=self.device, dtype=torch.float32)
    if hasattr(self, "seeded_noise"):
        # Optional hook that lets the owner replace/transform the raw noise.
        noise = self.seeded_noise(noise)

    if prefix is not None:
        # Overwrite the first prefix.shape[1] positions along dim 1 with prefix.
        noise = torch.cat([prefix, noise[:, prefix.shape[1]:]], dim=1)

    # All ranks of a model-parallel group must start from identical noise, so
    # broadcast (in place) from the group's first rank. Computing the source
    # as (rank // mp_size) * mp_size assumes each model-parallel group is a
    # contiguous block of global ranks -- NOTE(review): confirm with mpu setup.
    mp_size = mpu.get_model_parallel_world_size()
    if mp_size > 1:
        group_index = torch.distributed.get_rank() // mp_size
        src_rank = group_index * mp_size
        torch.distributed.broadcast(
            noise, src=src_rank, group=mpu.get_model_parallel_group()
        )

    def denoiser(input, sigma, c, **additional_model_inputs):
        # Bind the model and concat_images so the sampler only has to supply
        # (input, sigma, c, ...). Parameter names kept as before in case the
        # sampler passes them by keyword.
        return self.denoiser(
            self.model, input, sigma, c, concat_images=concat_images, **additional_model_inputs
        )

    # scale / scale_emb are always passed as None here; presumably the sampler
    # falls back to its own configuration in that case -- TODO confirm.
    samples = self.sampler(denoiser, noise, cond, uc=uc, scale=None, scale_emb=None)
    return samples.to(self.dtype)
247
248 @torch.no_grad()
249 def log_conditionings(self, batch: Dict, n: int) -> Dict:

Callers 4

log_video — Method · 0.95
encode_video — Function · 0.45
collate_fn — Function · 0.45
collate_fn — Function · 0.45

Calls

no outgoing calls

Tested by

no test coverage detected