r"""Produce output given segmentation and other conditioning inputs. random style will be used if neither z nor style_img is provided. Args: label (N x C x H x W tensor): One-hot segmentation mask of shape. z: Style vector. style_img: Style image.
(self, label, z=None, style_img=None)
| 43 | print('[GauGANLoader] GauGAN loading complete.') |
| 44 | |
def eval(self, label, z=None, style_img=None):
    r"""Run the wrapped generator (``self.net_GG``) on a segmentation mask.

    The style source is chosen in this order: ``z`` if given, otherwise
    ``style_img`` if given, otherwise the generator is asked for a random
    style. Every tensor handed to the generator is detached and cast to
    half precision first.

    Args:
        label (N x C x H x W tensor): One-hot segmentation mask. The last
            channel along dim 1 is sliced off before the mask is passed on.
            NOTE(review): presumably a channel the generator does not
            consume -- confirm against the caller.
        z: Style vector. Takes precedence over ``style_img``.
        style_img: Style image. Only used when ``z`` is None.

    Returns:
        The ``'fake_images'`` entry of the generator's output.
    """
    # Build the generator input dict; key insertion order is label first,
    # then at most one style entry.
    inputs = {'label': label[:, :-1].detach().half()}
    if z is not None:
        inputs['z'] = z.detach().half()
    elif style_img is not None:
        inputs['images'] = style_img.detach().half()

    # A random style is requested only when no style source was supplied.
    random_style = z is None and style_img is None

    return self.net_GG(inputs, random_style=random_style)['fake_images']
| 67 | |
| 68 | |
| 69 | class Trainer(BaseTrainer): |
no outgoing calls