Repository: github.com/Vchitect/Latte

Method sample

diffusion/timestep_sampler.py:44–59

Importance-samples timesteps for a batch. batch_size is the number of timesteps to draw, and device is the torch device on which the results are placed. Returns a tuple (timesteps, weights): timesteps is a tensor of timestep indices, and weights is a tensor of weights used to scale the resulting losses.
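The returned weights make the weighted loss an unbiased estimate of the loss averaged uniformly over all timesteps. Write w_t for the entries returned by weights(), N = len(p) for the number of timesteps, λ_t for the returned weight of timestep t, and L_t for a per-timestep loss (λ_t and L_t are notation chosen here for illustration, not names from the source):

```latex
p_t = \frac{w_t}{\sum_{s=0}^{N-1} w_s}, \qquad
\lambda_t = \frac{1}{N\,p_t}, \qquad
\mathbb{E}_{t \sim p}\!\left[\lambda_t L_t\right]
  = \sum_{t=0}^{N-1} p_t \cdot \frac{1}{N\,p_t}\,L_t
  = \frac{1}{N}\sum_{t=0}^{N-1} L_t
```

The identity assumes p_t > 0 for every t; a timestep with zero weight is never drawn, so its loss drops out of the estimate.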

(self, batch_size, device)


from abc import ABC

import numpy as np
import torch as th


class ScheduleSampler(ABC):
    # ... class docstring and abstract weights() method omitted ...

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.
        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        # Normalize the (possibly unnormalized) per-timestep weights into probabilities.
        w = self.weights()
        p = w / np.sum(w)
        # Draw batch_size timestep indices with replacement according to p.
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # Importance weights 1 / (N * p_t) keep the weighted loss an unbiased
        # estimate of the loss averaged uniformly over all N timesteps.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights
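For experimenting outside the repository, the sketch below reproduces the same normalize, draw, and reweight steps with NumPy alone. The function name importance_sample, the example loss values, and the seeded np.random.default_rng Generator (the method above uses the global np.random.choice) are assumptions for illustration; the conversion to torch tensors on a device is omitted.

```python
import numpy as np


def importance_sample(w, batch_size, rng):
    """NumPy-only mirror of the logic in ScheduleSampler.sample() (hypothetical helper).

    w: positive per-timestep weights; they need not sum to 1.
    Returns (indices, weights) as NumPy arrays rather than torch tensors.
    """
    w = np.asarray(w, dtype=np.float64)
    p = w / np.sum(w)                                  # normalize to a distribution
    indices = rng.choice(len(p), size=(batch_size,), p=p)
    weights = 1.0 / (len(p) * p[indices])              # importance weights 1 / (N * p_t)
    return indices, weights


rng = np.random.default_rng(0)

# Uniform weights: p_t = 1/N, so every importance weight is 1.
_, uniform_weights = importance_sample(np.ones(1000), batch_size=8, rng=rng)

# Non-uniform weights: the weighted loss still estimates the uniform mean.
losses = np.array([4.0, 3.0, 2.0, 1.0])               # hypothetical per-timestep losses
idx, lam = importance_sample([1.0, 2.0, 3.0, 4.0], batch_size=200_000, rng=rng)
estimate = np.mean(lam * losses[idx])                  # close to losses.mean() == 2.5

# Weights proportional to the loss: every weighted term equals the uniform mean.
idx2, lam2 = importance_sample(losses, batch_size=16, rng=rng)
terms = lam2 * losses[idx2]                            # each term is 2.5 up to rounding

print(uniform_weights, estimate, terms)
```

The last case shows why a non-uniform schedule can pay off: when p_t is proportional to the per-step loss, every term λ_t L_t is the same constant, so the estimate has zero variance. Real per-step losses are not known in advance, so this is an idealized case.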

Callers 6

main (function) · 0.80
main (function) · 0.80
training_step (method) · 0.80
training_step (method) · 0.80
random_frame_sampling (function) · 0.80

Calls 1

weights (method) · 0.95

Tested by

no test coverage detected