MCPcopy
hub / github.com/hanzhanggit/StackGAN / CondGANTrainer

Class CondGANTrainer

stageI/trainer.py:28–452  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

26
27
28class CondGANTrainer(object):
def __init__(self, model, dataset=None, exp_name="model",
             ckt_logs_dir="ckt_logs"):
    """Set up a trainer for a conditional GAN.

    :param model: the GAN model to train.
    :type model: RegularizedGAN
    :param dataset: data source, stored as ``self.dataset``;
        ``build_placeholder`` reads its ``image_shape`` and
        ``embedding_shape`` attributes.
    :param exp_name: experiment name, stored as ``self.exp_name``.
    :param ckt_logs_dir: a single directory used for both logs
        (``self.log_dir``) and checkpoints (``self.checkpoint_dir``).

    The training hyper-parameters (batch size, max epoch, snapshot
    interval, pretrained-model setting) are copied from the global
    ``cfg.TRAIN`` section when the trainer is constructed.
    """
    self.model, self.dataset, self.exp_name = model, dataset, exp_name
    # Logs and checkpoints deliberately share one directory.
    self.log_dir = self.checkpoint_dir = ckt_logs_dir

    train_cfg = cfg.TRAIN
    self.batch_size = train_cfg.BATCH_SIZE
    self.max_epoch = train_cfg.MAX_EPOCH
    self.snapshot_interval = train_cfg.SNAPSHOT_INTERVAL
    # NOTE(review): presumably a path to pretrained weights to restore
    # (may be empty when training from scratch) -- confirm with the caller.
    self.model_path = train_cfg.PRETRAINED_MODEL

    # Starts empty; populated elsewhere (not visible in this chunk).
    self.log_vars = []
50
def build_placeholder(self):
    '''Create the graph's input placeholders (helper for init_opt).

    Sets five attributes on ``self``:

    * ``images`` / ``wrong_images`` -- float32, shape
      ``[batch_size] + dataset.image_shape``. NOTE(review): "wrong"
      presumably means images mismatched with their text condition --
      confirm against init_opt.
    * ``embeddings`` -- float32, shape
      ``[batch_size] + dataset.embedding_shape``.
    * ``generator_lr`` / ``discriminator_lr`` -- float32 scalars (shape []).

    The dataset shape attributes are concatenated with a list, so they must
    be lists.
    '''
    def batched(per_example_shape):
        # Prepend the fixed batch dimension; builds a fresh list per call.
        return [self.batch_size] + per_example_shape

    self.images = tf.placeholder(
        tf.float32, batched(self.dataset.image_shape), name='real_images')
    self.wrong_images = tf.placeholder(
        tf.float32, batched(self.dataset.image_shape), name='wrong_images')
    self.embeddings = tf.placeholder(
        tf.float32, batched(self.dataset.embedding_shape),
        name='conditional_embeddings')

    # Learning rates are fed at run time as scalars.
    self.generator_lr = tf.placeholder(
        tf.float32, [], name='generator_learning_rate')
    self.discriminator_lr = tf.placeholder(
        tf.float32, [], name='discriminator_learning_rate')
73
74 def sample_encoded_context(self, embeddings):
75 '''Helper function for init_opt'''
76 c_mean_logsigma = self.model.generate_condition(embeddings)
77 mean = c_mean_logsigma[0]
78 if cfg.TRAIN.COND_AUGMENTATION:
79 # epsilon = tf.random_normal(tf.shape(mean))
80 epsilon = tf.truncated_normal(tf.shape(mean))
81 stddev = tf.exp(c_mean_logsigma[1])
82 c = mean + stddev * epsilon
83
84 kl_loss = KL_loss(c_mean_logsigma[0], c_mean_logsigma[1])
85 else:

Callers 1

run_exp.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected