MCPcopy Index your code
hub / github.com/albertpumarola/GANimation / forward

Method forward

models/ganimation.py:123–196  ·  view source on GitHub ↗
(self, keep_data_for_visuals=False, return_estimates=False)

Source from the content-addressed store, hash-verified

121 return OrderedDict([('real_img', self._input_real_img_path)])
122
123 def forward(self, keep_data_for_visuals=False, return_estimates=False):
124 if not self._is_train:
125 # convert tensor to variables
126 real_img = Variable(self._input_real_img, volatile=True)
127 real_cond = Variable(self._input_real_cond, volatile=True)
128 desired_cond = Variable(self._input_desired_cond, volatile=True)
129
130 # generate fake images
131 fake_imgs, fake_img_mask = self._G.forward(real_img, desired_cond)
132 fake_img_mask = self._do_if_necessary_saturate_mask(fake_img_mask, saturate=self._opt.do_saturate_mask)
133 fake_imgs_masked = fake_img_mask * real_img + (1 - fake_img_mask) * fake_imgs
134
135 rec_real_img_rgb, rec_real_img_mask = self._G.forward(fake_imgs_masked, real_cond)
136 rec_real_img_mask = self._do_if_necessary_saturate_mask(rec_real_img_mask, saturate=self._opt.do_saturate_mask)
137 rec_real_imgs = rec_real_img_mask * fake_imgs_masked + (1 - rec_real_img_mask) * rec_real_img_rgb
138
139 imgs = None
140 data = None
141 if return_estimates:
142 # normalize mask for better visualization
143 fake_img_mask_max = fake_imgs_masked.view(fake_img_mask.size(0), -1).max(-1)[0]
144 fake_img_mask_max = torch.unsqueeze(torch.unsqueeze(torch.unsqueeze(fake_img_mask_max, -1), -1), -1)
145 # fake_img_mask_norm = fake_img_mask / fake_img_mask_max
146 fake_img_mask_norm = fake_img_mask
147
148 # generate images
149 im_real_img = util.tensor2im(real_img.data)
150 im_fake_imgs = util.tensor2im(fake_imgs.data)
151 im_fake_img_mask_norm = util.tensor2maskim(fake_img_mask_norm.data)
152 im_fake_imgs_masked = util.tensor2im(fake_imgs_masked.data)
153 im_rec_imgs = util.tensor2im(rec_real_img_rgb.data)
154 im_rec_img_mask_norm = util.tensor2maskim(rec_real_img_mask.data)
155 im_rec_imgs_masked = util.tensor2im(rec_real_imgs.data)
156 im_concat_img = np.concatenate([im_real_img, im_fake_imgs_masked, im_fake_img_mask_norm, im_fake_imgs,
157 im_rec_imgs, im_rec_img_mask_norm, im_rec_imgs_masked],
158 1)
159
160 im_real_img_batch = util.tensor2im(real_img.data, idx=-1, nrows=1)
161 im_fake_imgs_batch = util.tensor2im(fake_imgs.data, idx=-1, nrows=1)
162 im_fake_img_mask_norm_batch = util.tensor2maskim(fake_img_mask_norm.data, idx=-1, nrows=1)
163 im_fake_imgs_masked_batch = util.tensor2im(fake_imgs_masked.data, idx=-1, nrows=1)
164 im_concat_img_batch = np.concatenate([im_real_img_batch, im_fake_imgs_masked_batch,
165 im_fake_img_mask_norm_batch, im_fake_imgs_batch],
166 1)
167
168 imgs = OrderedDict([('real_img', im_real_img),
169 ('fake_imgs', im_fake_imgs),
170 ('fake_img_mask', im_fake_img_mask_norm),
171 ('fake_imgs_masked', im_fake_imgs_masked),
172 ('concat', im_concat_img),
173 ('real_img_batch', im_real_img_batch),
174 ('fake_imgs_batch', im_fake_imgs_batch),
175 ('fake_img_mask_batch', im_fake_img_mask_norm_batch),
176 ('fake_imgs_masked_batch', im_fake_imgs_masked_batch),
177 ('concat_batch', im_concat_img_batch),
178 ])
179
180 data = OrderedDict([('real_path', self._input_real_img_path),

Callers 2

_forward_GMethod · 0.45
_forward_DMethod · 0.45

Calls 1

Tested by

no test coverage detected