Builds the BASNet input pipeline.
build_inputs(self,
             params: exp_cfg.DataConfig,
             input_context: Optional[tf.distribute.InputContext] = None)
| 116 | ckpt_dir_or_file) |
| 117 | |
def build_inputs(self,
                 params: exp_cfg.DataConfig,
                 input_context: Optional[tf.distribute.InputContext] = None):
  """Creates the input dataset that feeds BASNet.

  Args:
    params: Data configuration for this split. This method reads its
      `output_size`, `crop_size`, `aug_rand_hflip`, `dtype`, `file_type`
      and `is_training` fields, and passes the whole object to the reader.
    input_context: Optional distribution input context, forwarded unchanged
      to `InputReader.read`.

  Returns:
    Whatever `input_reader.InputReader.read` returns for this configuration.
  """
  # The parser reuses the label value configured for the loss, so both
  # sides agree on which label is ignored.
  loss_ignore_label = self.task_config.losses.ignore_label

  example_decoder = segmentation_input.Decoder()
  example_parser = segmentation_input.Parser(
      output_size=params.output_size,
      crop_size=params.crop_size,
      ignore_label=loss_ignore_label,
      aug_rand_hflip=params.aug_rand_hflip,
      dtype=params.dtype)

  # The dataset constructor is chosen from the configured file type; the
  # parser function depends on whether this is a training split.
  return input_reader.InputReader(
      params,
      dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
      decoder_fn=example_decoder.decode,
      parser_fn=example_parser.parse_fn(params.is_training),
  ).read(input_context=input_context)
| 142 | |
| 143 | def build_losses(self, label, model_outputs, aux_losses=None): |
| 144 | """Hybrid loss proposed in BASNet. |