| 27 | |
| 28 | |
class Model(tpu_utils.Model):
    """Denoising network paired with a ``GaussianDiffusion`` process.

    ``_denoise`` builds the network selected by ``model_name``.
    ``train_fn`` returns the mean diffusion training loss.
    ``samples_fn`` returns samples from ``self.diffusion.p_sample_loop``.
    """

    def __init__(self, *, model_name, betas: np.ndarray, loss_type: str, num_classes: int,
                 dropout: float, randflip, block_size: int):
        """Store model hyperparameters and build the diffusion process.

        Args:
            model_name: architecture identifier. ``_denoise`` only accepts
                'unet2d16b2c112244' and raises NotImplementedError otherwise.
            betas: noise schedule, passed to ``GaussianDiffusion``.
            loss_type: loss identifier, passed to ``GaussianDiffusion``.
            num_classes: forwarded to ``unet.model``.
            dropout: dropout rate used by ``train_fn`` (sampling uses 0).
            randflip: if truthy, ``train_fn`` applies random left/right flips.
            block_size: if != 1, images are folded with ``space_to_depth``
                before the network and unfolded with ``depth_to_space`` after.
        """
        self.model_name = model_name
        self.diffusion = GaussianDiffusion(betas=betas, loss_type=loss_type)
        self.num_classes = num_classes
        self.dropout = dropout
        self.randflip = randflip
        self.block_size = block_size

    def _denoise(self, x, t, y, dropout):
        """Run the denoising network on a batch of noisy images.

        Args:
            x: float32 tensor with static shape [B, H, W, C].
            t: int32/int64 timestep per example, shape [B].
            y: int32/int64 label per example, shape [B]. Its shape and dtype
                are checked, but it is not passed to the network (see below).
            dropout: dropout rate for this call.

        Returns:
            Network output with static shape [B, H, W, C], the same as ``x``.

        Raises:
            NotImplementedError: if ``self.model_name`` is not a known architecture.
        """
        B, H, W, C = x.shape.as_list()
        assert x.dtype == tf.float32
        assert t.shape == [B] and t.dtype in [tf.int32, tf.int64]
        assert y.shape == [B] and y.dtype in [tf.int32, tf.int64]
        orig_out_ch = out_ch = C

        if self.block_size != 1:
            # Fold each block_size x block_size spatial patch into channels so the
            # network runs at a lower spatial resolution. The network must then
            # emit block_size**2 times as many channels; this is undone below.
            x = tf.nn.space_to_depth(x, self.block_size)
            out_ch *= self.block_size ** 2

        # The label is discarded here, so the network is built without class
        # conditioning.
        # NOTE(review): presumably unet.model does not support y != None; confirm
        # before wiring labels through.
        y = None
        if self.model_name == 'unet2d16b2c112244':  # 114M for block_size=1
            out = unet.model(
                x, t=t, y=y, name='model', ch=128, ch_mult=(1, 1, 2, 2, 4, 4), num_res_blocks=2,
                attn_resolutions=(16,), out_ch=out_ch, num_classes=self.num_classes, dropout=dropout,
            )
        else:
            raise NotImplementedError(self.model_name)

        if self.block_size != 1:
            out = tf.nn.depth_to_space(out, self.block_size)
        assert out.shape == [B, H, W, orig_out_ch]
        return out

    def train_fn(self, x, y):
        """Compute the mean diffusion training loss for a batch.

        One timestep per example is drawn uniformly from
        [0, self.diffusion.num_timesteps). The per-example losses from
        ``self.diffusion.p_losses`` are then averaged.

        Args:
            x: image tensor with static shape [B, H, W, C].
            y: per-example labels, forwarded to ``_denoise``.

        Returns:
            dict with key 'loss' holding the scalar mean loss.
        """
        # as_list() yields plain ints (not TF1 Dimension objects), consistent
        # with _denoise and samples_fn.
        B, H, W, C = x.shape.as_list()
        if self.randflip:
            x = tf.image.random_flip_left_right(x)
            assert x.shape == [B, H, W, C]
        t = tf.random_uniform([B], 0, self.diffusion.num_timesteps, dtype=tf.int32)
        losses = self.diffusion.p_losses(
            denoise_fn=functools.partial(self._denoise, y=y, dropout=self.dropout), x_start=x, t=t)
        assert losses.shape == t.shape == [B]
        return {'loss': tf.reduce_mean(losses)}

    def samples_fn(self, dummy_noise, y):
        """Generate samples with ``self.diffusion.p_sample_loop``.

        Args:
            dummy_noise: tensor whose static shape sets the sample shape. Only
                its shape is read, not its values.
            y: per-example labels, forwarded to ``_denoise``.

        Returns:
            dict with key 'samples' holding the output of ``p_sample_loop``.
        """
        return {
            'samples': self.diffusion.p_sample_loop(
                # Dropout is disabled at sampling time.
                denoise_fn=functools.partial(self._denoise, y=y, dropout=0),
                shape=dummy_noise.shape.as_list(),
                noise_fn=tf.random_normal,
            )
        }
| 83 | |
| 84 | |
| 85 | def evaluation( |
no outgoing calls
no test coverage detected