MCPcopy
hub / github.com/hanzhanggit/StackGAN / CondGANTrainer

Class CondGANTrainer

stageII/trainer.py:29–669  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

27
28
29class CondGANTrainer(object):
def __init__(self,
             model,
             dataset=None,
             exp_name="model",
             ckt_logs_dir="ckt_logs",
             ):
    """Store the model and dataset, read training settings, derive image shapes.

    Batch size, max epoch, snapshot interval and pretrained-model path are
    read from the ``cfg.TRAIN`` configuration section.

    :param model: the model to train.
    :type model: RegularizedGAN
    :param dataset: data source; must expose ``image_shape`` (used as the
        high-resolution image shape) and ``hr_lr_ratio`` (divisor applied to
        the first two dimensions to get the low-resolution shape). It
        defaults to ``None`` only for signature compatibility -- it is
        required.
    :param exp_name: stored as ``self.exp_name``.
    :param ckt_logs_dir: used as both ``self.log_dir`` and
        ``self.checkpoint_dir``.
    :raises ValueError: if ``dataset`` is ``None``.
    """
    if dataset is None:
        # Every image shape below is derived from the dataset, so fail with
        # a clear message instead of an opaque AttributeError on ``None``.
        raise ValueError("CondGANTrainer requires a dataset; got None")

    self.model = model
    self.dataset = dataset
    self.exp_name = exp_name
    self.log_dir = ckt_logs_dir
    self.checkpoint_dir = ckt_logs_dir

    self.batch_size = cfg.TRAIN.BATCH_SIZE
    self.max_epoch = cfg.TRAIN.MAX_EPOCH
    self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL
    self.model_path = cfg.TRAIN.PRETRAINED_MODEL

    self.log_vars = []

    # Store a list copy: placeholder shapes are later built with
    # ``[batch_size] + shape``, which raises TypeError for a tuple, and a
    # copy keeps edits here from leaking back into the dataset's attribute.
    # NOTE(review): assumes at least 3 dims, presumably
    # (height, width, channels) -- indices 0, 1, 2 are read below.
    self.hr_image_shape = list(self.dataset.image_shape)
    ratio = self.dataset.hr_lr_ratio
    # Low-resolution shape: first two dims divided by ``ratio`` (truncated
    # to int), third dim kept unchanged.
    self.lr_image_shape = [int(self.hr_image_shape[0] / ratio),
                           int(self.hr_image_shape[1] / ratio),
                           self.hr_image_shape[2]]
    print('hr_image_shape', self.hr_image_shape)
    print('lr_image_shape', self.lr_image_shape)
59
def build_placeholder(self):
    '''Helper function for init_opt: create the graph's input placeholders.

    Creates float32 placeholders for:
    - the real and "wrong" high-resolution image batches, with shape
      ``[batch_size] + hr_image_shape``;
    - the conditional embeddings, with shape
      ``[batch_size] + dataset.embedding_shape``;
    - the generator and discriminator learning rates (scalars).

    It then bilinearly resizes both image batches to
    ``self.lr_image_shape[:2]``.

    Sets: hr_images, hr_wrong_images, embeddings, generator_lr,
    discriminator_lr, images, wrong_images.
    '''
    batch_dim = [self.batch_size]
    # ``list(...)`` lets the shapes be given as tuples as well as lists;
    # ``[batch_size] + some_tuple`` would raise TypeError.
    hr_batch_shape = batch_dim + list(self.hr_image_shape)
    embedding_batch_shape = batch_dim + list(self.dataset.embedding_shape)

    self.hr_images = tf.placeholder(
        tf.float32, hr_batch_shape,
        name='real_hr_images')
    self.hr_wrong_images = tf.placeholder(
        tf.float32, hr_batch_shape,
        name='wrong_hr_images'
    )
    self.embeddings = tf.placeholder(
        tf.float32, embedding_batch_shape,
        name='conditional_embeddings'
    )

    # Learning rates are fed at run time as scalar placeholders (shape []).
    self.generator_lr = tf.placeholder(
        tf.float32, [],
        name='generator_learning_rate'
    )
    self.discriminator_lr = tf.placeholder(
        tf.float32, [],
        name='discriminator_learning_rate'
    )

    # Low-resolution versions of the real and wrong images, resized from
    # the high-resolution placeholders to the first two lr_image_shape dims.
    self.images = tf.image.resize_bilinear(self.hr_images,
                                           self.lr_image_shape[:2])
    self.wrong_images = tf.image.resize_bilinear(self.hr_wrong_images,
                                                 self.lr_image_shape[:2])

Callers 1

run_exp.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected