MCPcopy
hub / github.com/hojonathanho/diffusion / Model

Class Model

scripts/run_lsun.py:29–82  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

27
28
class Model(tpu_utils.Model):
    """Diffusion model for this training script: a U-Net denoiser inside a Gaussian diffusion process.

    ``train_fn`` builds the scalar training loss and ``samples_fn`` runs the
    reverse-process sampling loop; both call the network through ``_denoise``.
    """

    def __init__(self, *, model_name: str, betas: np.ndarray, loss_type: str, num_classes: int,
                 dropout: float, randflip, block_size: int):
        """Store the configuration and construct the diffusion process.

        Args:
            model_name: Architecture key. ``_denoise`` only implements
                ``'unet2d16b2c112244'`` and raises ``NotImplementedError`` otherwise.
            betas: Noise schedule, forwarded to ``GaussianDiffusion``.
            loss_type: Loss identifier, forwarded to ``GaussianDiffusion``.
            num_classes: Forwarded to ``unet.model``.
            dropout: Dropout rate used by ``train_fn`` (``samples_fn`` always uses 0).
            randflip: If truthy, ``train_fn`` randomly flips inputs left/right.
            block_size: If not 1, inputs are packed with ``space_to_depth`` by this
                factor before the network and unpacked with ``depth_to_space`` after.
        """
        self.model_name = model_name
        self.diffusion = GaussianDiffusion(betas=betas, loss_type=loss_type)
        self.num_classes = num_classes
        self.dropout = dropout
        self.randflip = randflip
        self.block_size = block_size

    def _denoise(self, x, t, y, dropout):
        """Run the denoising network on a batch of noisy images.

        Args:
            x: 4-D float32 tensor ``[B, H, W, C]``.
            t: int32/int64 tensor of shape ``[B]`` (diffusion timesteps).
            y: int32/int64 tensor of shape ``[B]``. Its shape and dtype are checked,
                but it is NOT passed to the network (see below).
            dropout: Dropout rate forwarded to the network.

        Returns:
            Tensor of shape ``[B, H, W, C]`` (same shape as ``x``).

        Raises:
            NotImplementedError: if ``self.model_name`` is not a known architecture.
        """
        B, H, W, C = x.shape.as_list()
        assert x.dtype == tf.float32
        assert t.shape == [B] and t.dtype in [tf.int32, tf.int64]
        assert y.shape == [B] and y.dtype in [tf.int32, tf.int64]
        orig_out_ch = out_ch = C

        if self.block_size != 1:
            # Trade spatial resolution for channels; undone by depth_to_space below,
            # so the network must emit block_size**2 times as many channels.
            x = tf.nn.space_to_depth(x, self.block_size)
            out_ch *= self.block_size ** 2

        # The network is run unconditionally: labels are validated above but
        # deliberately dropped here. NOTE(review): presumably unet.model does not
        # support class conditioning -- confirm before passing y through.
        y = None
        if self.model_name == 'unet2d16b2c112244':  # 114M for block_size=1
            out = unet.model(
                x, t=t, y=y, name='model', ch=128, ch_mult=(1, 1, 2, 2, 4, 4), num_res_blocks=2,
                attn_resolutions=(16,), out_ch=out_ch, num_classes=self.num_classes, dropout=dropout
            )
        else:
            raise NotImplementedError(self.model_name)

        if self.block_size != 1:
            out = tf.nn.depth_to_space(out, self.block_size)
        assert out.shape == [B, H, W, orig_out_ch]
        return out

    def train_fn(self, x, y):
        """Build the training loss for one batch.

        Draws one timestep per example uniformly from ``[0, num_timesteps)`` and
        averages the per-example losses from ``GaussianDiffusion.p_losses``.

        Args:
            x: 4-D image batch ``[B, H, W, C]``.
            y: Tensor of shape ``[B]``; validated by ``_denoise`` but unused by the network.

        Returns:
            Dict with a scalar ``'loss'`` tensor.
        """
        # Use plain ints (as in _denoise) rather than TensorShape Dimensions.
        B, H, W, C = x.shape.as_list()
        if self.randflip:
            x = tf.image.random_flip_left_right(x)
            assert x.shape == [B, H, W, C]
        t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
        losses = self.diffusion.p_losses(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout),
            x_start=x, t=t)
        assert losses.shape == t.shape == [B]
        return {'loss': tf.reduce_mean(losses)}

    def samples_fn(self, dummy_noise, y):
        """Draw samples with the reverse diffusion process, with dropout disabled.

        Args:
            dummy_noise: Tensor whose static shape sets the sample shape; only its
                shape is read here, not its values.
            y: Tensor of shape ``[B]``; validated by ``_denoise`` but unused by the network.

        Returns:
            Dict with ``'samples'``: the result of ``GaussianDiffusion.p_sample_loop``.
        """
        return {
            'samples': self.diffusion.p_sample_loop(
                denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
                shape=dummy_noise.shape.as_list(),
                noise_fn=tf.random_normal
            )
        }
83
84
85def evaluation(

Callers 2

evaluation (function) · 0.70
train (function) · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected