MCPcopy
hub / github.com/hojonathanho/diffusion / Model

Class Model

scripts/run_celebahq.py:20–100  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

18
19
20class Model(tpu_utils.Model):
def __init__(self, *, model_name, betas: np.ndarray, loss_type: str, num_classes: int,
             dropout: float, randflip, block_size: int):
    """Store model configuration and build the diffusion helper.

    Args:
      model_name: Architecture key; `_denoise` raises NotImplementedError for
        any value it does not recognize.
      betas: Array passed to `GaussianDiffusion` together with `loss_type`.
      loss_type: Loss identifier passed to `GaussianDiffusion`.
      num_classes: Forwarded to the network as its `num_classes` argument.
      dropout: Dropout rate applied by `train_fn`; sampling always uses 0.
      randflip: When truthy, `train_fn` randomly flips images left-right.
      block_size: `space_to_depth` factor used by `_denoise`; 1 disables it.
    """
    self.model_name = model_name
    self.diffusion = GaussianDiffusion(betas=betas, loss_type=loss_type)
    # The remaining hyperparameters are kept verbatim for the other methods.
    self.num_classes, self.dropout, self.randflip, self.block_size = (
        num_classes, dropout, randflip, block_size)
29
def _denoise(self, x, t, y, dropout):
    """Apply the denoising network to a batch.

    Args:
      x: float32 tensor whose static shape unpacks as (B, H, W, C).
      t: int32/int64 tensor of shape [B] (timestep indices).
      y: int32/int64 tensor of shape [B]. It is shape/dtype-checked, but the
        network is always called with y=None, so its values are not used.
      dropout: Dropout rate forwarded to the network.

    Returns:
      Tensor of shape [B, H, W, C], matching the input layout.

    Raises:
      NotImplementedError: if `self.model_name` is not a supported architecture.
    """
    batch, height, width, channels = x.shape.as_list()
    assert x.dtype == tf.float32
    assert t.shape == [batch] and t.dtype in (tf.int32, tf.int64)
    assert y.shape == [batch] and y.dtype in (tf.int32, tf.int64)

    blocked = self.block_size != 1
    net_input = x
    net_out_ch = channels
    if blocked:
        # Fold spatial blocks into channels; this can reduce memory consumption.
        net_input = tf.nn.space_to_depth(net_input, self.block_size)
        net_out_ch = channels * self.block_size ** 2

    if self.model_name != 'unet2d16b2c112244':
        raise NotImplementedError(self.model_name)
    # 114M parameters for block_size=1.
    # Labels are deliberately not forwarded (y=None).
    # NOTE(review): presumably the network is unconditional here; confirm in unet.model.
    out = unet.model(
        net_input, t=t, y=None, name='model', ch=128, ch_mult=(1, 1, 2, 2, 4, 4),
        num_res_blocks=2, attn_resolutions=(16,), out_ch=net_out_ch,
        num_classes=self.num_classes, dropout=dropout)

    if blocked:
        # Undo the space_to_depth fold so the output matches x spatially.
        out = tf.nn.depth_to_space(out, self.block_size)
    assert out.shape == [batch, height, width, channels]
    return out
54
def train_fn(self, x, y):
    """Build the training loss for one batch.

    Args:
      x: Image batch whose static shape unpacks as (B, H, W, C).
      y: Label tensor forwarded to `_denoise`, which validates it as shape [B]
        but does not pass its values to the network.

    Returns:
      dict with key 'loss': the mean of the per-example losses returned by
      `self.diffusion.p_losses`.
    """
    # Unpack as plain Python ints/None (as `_denoise` does) rather than TF1
    # Dimension objects, so shape arguments and asserts receive ordinary values.
    B, H, W, C = x.shape.as_list()
    if self.randflip:
        x = tf.image.random_flip_left_right(x)
        # The flip must not change the static shape.
        assert x.shape == [B, H, W, C]
    # One timestep per example, drawn uniformly from [0, num_timesteps).
    t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
    losses = self.diffusion.p_losses(
        denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout),
        x_start=x, t=t)
    assert losses.shape == t.shape == [B]
    return {'loss': tf.reduce_mean(losses)}
65
def samples_fn(self, dummy_noise, y):
    """Draw samples by running `self.diffusion.p_sample_loop` with dropout off.

    Args:
      dummy_noise: Tensor used only for its static shape; its values are
        never read here.
      y: Label tensor forwarded to `_denoise`.

    Returns:
      dict with key 'samples' holding the result of `p_sample_loop`.
    """
    denoiser = functools.partial(self._denoise, y=y, dropout=0)
    sample_shape = dummy_noise.shape.as_list()
    samples = self.diffusion.p_sample_loop(
        denoise_fn=denoiser, shape=sample_shape, noise_fn=tf.random_normal)
    return {'samples': samples}
74
75 def samples_fn_denoising_trajectory(self, dummy_noise, y, repeat_noise_steps=0):
76 times, imgs = self.diffusion.p_sample_loop_trajectory(
77 denoise_fn=functools.partial(self._denoise, y=y, dropout=0),

Callers 2

evaluationFunction · 0.70
trainFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected