| 18 | |
| 19 | |
| 20 | class Model(tpu_utils.Model): |
def __init__(self, *, model_name, betas: np.ndarray, loss_type: str, num_classes: int,
             dropout: float, randflip, block_size: int):
    """Configure the model wrapper and its GaussianDiffusion process.

    Args:
      model_name: Architecture identifier; `_denoise` raises
        NotImplementedError for names it does not recognise.
      betas: Forwarded to `GaussianDiffusion` (presumably the per-step
        noise schedule -- confirm against GaussianDiffusion).
      loss_type: Forwarded to `GaussianDiffusion`.
      num_classes: Forwarded to the network as `num_classes`.
      dropout: Dropout rate used when building the training loss
        (`train_fn`); sampling uses dropout=0.
      randflip: If truthy, `train_fn` randomly flips images left-right.
      block_size: Spatial block size for `tf.nn.space_to_depth` applied
        before the network (and undone afterwards); 1 disables it.

    Raises:
      ValueError: if `block_size` is smaller than 1. Such values would
        otherwise only fail later, when `space_to_depth` is built.
    """
    # Fail fast: space_to_depth/depth_to_space need block_size >= 2, and
    # block_size == 1 is the "disabled" value, so anything below 1 is invalid.
    if block_size < 1:
        raise ValueError('block_size must be >= 1, got {!r}'.format(block_size))
    self.model_name = model_name
    self.diffusion = GaussianDiffusion(betas=betas, loss_type=loss_type)
    self.num_classes = num_classes
    self.dropout = dropout
    self.randflip = randflip
    self.block_size = block_size
| 29 | |
def _denoise(self, x, t, y, dropout):
    """Apply the denoising network to a batch of noisy images.

    Args:
      x: float32 image batch; its static shape is unpacked via `as_list()`
        as [B, H, W, C].
      t: int32/int64 timestep tensor with static shape [B].
      y: int32/int64 label tensor with static shape [B]. It is only checked
        here; the network is always called with y=None, so labels never
        reach it.
      dropout: Dropout rate forwarded to the network.

    Returns:
      The network output, asserted to have static shape [B, H, W, C].

    Raises:
      NotImplementedError: if `self.model_name` is not a known architecture.
    """
    batch, height, width, channels = x.shape.as_list()
    assert x.dtype == tf.float32
    assert t.shape == [batch] and t.dtype in [tf.int32, tf.int64]
    assert y.shape == [batch] and y.dtype in [tf.int32, tf.int64]

    use_blocks = self.block_size != 1
    net_in = x
    net_out_ch = channels
    if use_blocks:
        # Move block_size x block_size spatial patches into the channel axis
        # so the network sees a smaller spatial grid (reduces memory use).
        net_in = tf.nn.space_to_depth(x, self.block_size)
        net_out_ch = channels * self.block_size ** 2

    if self.model_name != 'unet2d16b2c112244':
        raise NotImplementedError(self.model_name)
    # 114M for block_size=1. Labels are intentionally withheld (y=None).
    net_out = unet.model(
        net_in, t=t, y=None, name='model', ch=128, ch_mult=(1, 1, 2, 2, 4, 4), num_res_blocks=2,
        attn_resolutions=(16,), out_ch=net_out_ch, num_classes=self.num_classes, dropout=dropout
    )

    if use_blocks:
        # Undo the space_to_depth rearrangement to recover [B, H, W, C].
        net_out = tf.nn.depth_to_space(net_out, self.block_size)
    assert net_out.shape == [batch, height, width, channels]
    return net_out
| 54 | |
def train_fn(self, x, y):
    """Build the mean diffusion training loss for one batch.

    Args:
      x: Image batch with a rank-4 static shape [B, H, W, C].
      y: Labels, passed through to `_denoise`.

    Returns:
      A dict whose 'loss' entry is the mean of the per-example losses
      returned by `self.diffusion.p_losses`.
    """
    batch, height, width, channels = x.shape
    if self.randflip:
        x = tf.image.random_flip_left_right(x)
        assert x.shape == [batch, height, width, channels]
    # One timestep per example; maxval is exclusive, so t is in [0, num_timesteps).
    t = tf.random_uniform([batch], 0, self.diffusion.num_timesteps, dtype=tf.int32)
    denoise_fn = functools.partial(self._denoise, y=y, dropout=self.dropout)
    per_example_losses = self.diffusion.p_losses(denoise_fn=denoise_fn, x_start=x, t=t)
    assert per_example_losses.shape == t.shape == [batch]
    return {'loss': tf.reduce_mean(per_example_losses)}
| 65 | |
def samples_fn(self, dummy_noise, y):
    """Generate samples with `self.diffusion.p_sample_loop`.

    Args:
      dummy_noise: Tensor used only for its static shape, which becomes
        the requested sample shape.
      y: Labels, passed through to `_denoise`.

    Returns:
      A dict whose 'samples' entry is the result of `p_sample_loop`.
    """
    # Dropout is turned off when sampling.
    denoise_fn = functools.partial(self._denoise, y=y, dropout=0)
    sample_shape = dummy_noise.shape.as_list()
    samples = self.diffusion.p_sample_loop(
        denoise_fn=denoise_fn,
        shape=sample_shape,
        noise_fn=tf.random_normal
    )
    return {'samples': samples}
| 74 | |
| 75 | def samples_fn_denoising_trajectory(self, dummy_noise, y, repeat_noise_steps=0): |
| 76 | times, imgs = self.diffusion.p_sample_loop_trajectory( |
| 77 | denoise_fn=functools.partial(self._denoise, y=y, dropout=0), |
no outgoing calls
no test coverage detected