Repository: github.com/hanzhanggit/StackGAN

Function build_model

demo/demo.py:52–76
(sess, embedding_dim, batch_size)


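import tensorflow as tf
import prettytensor as pt

# cfg, CondGAN and sample_encoded_context are defined elsewhere in the
# StackGAN repository and are not shown on this page.


def build_model(sess, embedding_dim, batch_size):
    # Stage I generates LR_IMSIZE images; Stage II upscales them by hr_lr_ratio.
    model = CondGAN(
        lr_imsize=cfg.TEST.LR_IMSIZE,
        hr_lr_ratio=int(cfg.TEST.HR_IMSIZE / cfg.TEST.LR_IMSIZE))

    # Text embeddings are fed in at run time through this placeholder.
    embeddings = tf.placeholder(
        tf.float32, [batch_size, embedding_dim],
        name='conditional_embeddings')
    # Build both generators in inference (test) mode.
    with pt.defaults_scope(phase=pt.Phase.test):
        with tf.variable_scope("g_net"):
            # Stage I: conditioning vector c plus Gaussian noise z.
            c = sample_encoded_context(embeddings, model)
            z = tf.random_normal([batch_size, cfg.Z_DIM])
            # Pre-1.0 TensorFlow argument order: tf.concat(concat_dim, values).
            fake_images = model.get_generator(tf.concat(1, [c, z]))
        with tf.variable_scope("hr_g_net"):
            # Stage II: refine the Stage I images with a second conditioning vector.
            hr_c = sample_encoded_context(embeddings, model)
            hr_fake_images = model.hr_get_generator(fake_images, hr_c)

    # Restore pretrained weights; the check only looks for '.ckpt' in the path.
    # If it fails, the graph is returned without restored weights.
    ckt_path = cfg.TEST.PRETRAINED_MODEL
    if ckt_path.find('.ckpt') != -1:
        print("Reading model parameters from %s" % ckt_path)
        # tf.all_variables() was deprecated in TF 0.12 in favour of tf.global_variables().
        saver = tf.train.Saver(tf.all_variables())
        saver.restore(sess, ckt_path)
    else:
        print("Input a valid model path.")
    return embeddings, fake_images, hr_fake_images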
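
To make the tensor shapes in build_model concrete, the sketch below walks the same two-stage pipeline with plain integers. It is an illustration only: the helper name stackgan_shapes is hypothetical, and the default sizes (64×64 Stage I output, 256×256 Stage II output, a 128-dimensional conditioning vector, 100-dimensional noise, a 1024-dimensional text embedding, 3-channel NHWC images) are assumptions about a typical configuration, not values read from this repository's config files. It also assumes Stage II upsamples by exactly hr_lr_ratio.

```python
def stackgan_shapes(batch_size, embedding_dim=1024, c_dim=128, z_dim=100,
                    lr_imsize=64, hr_imsize=256, channels=3):
    """Hypothetical helper: tensor shapes along build_model's two-stage graph.

    All default sizes are assumptions for illustration, not repo config values.
    """
    # Same integer ratio that build_model passes to CondGAN as hr_lr_ratio.
    hr_lr_ratio = int(hr_imsize / lr_imsize)
    return {
        "embeddings": (batch_size, embedding_dim),
        "c": (batch_size, c_dim),
        "z": (batch_size, z_dim),
        # Concatenating c and z along axis 1 adds their feature widths.
        "generator_input": (batch_size, c_dim + z_dim),
        "fake_images": (batch_size, lr_imsize, lr_imsize, channels),
        "hr_lr_ratio": hr_lr_ratio,
        # Assumes Stage II upsamples the Stage I images by exactly hr_lr_ratio.
        "hr_fake_images": (batch_size, lr_imsize * hr_lr_ratio,
                           lr_imsize * hr_lr_ratio, channels),
    }


shapes = stackgan_shapes(batch_size=16)
print(shapes["generator_input"])  # (16, 228)
print(shapes["hr_fake_images"])   # (16, 256, 256, 3)
```

Because int(...) truncates, an HR_IMSIZE that is not an exact multiple of LR_IMSIZE would give a smaller Stage II size under this assumption: with 250 and 64 the ratio becomes 3, so the output would be 192×192 rather than 250×250.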

Callers (1)

demo.py (File) · 0.70

Calls (4)

get_generator (Method) · 0.95
hr_get_generator (Method) · 0.95
CondGAN (Class) · 0.90
sample_encoded_context (Function) · 0.70
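
The sample_encoded_context call is where the text embedding becomes a conditioning vector. Its body is not shown on this page. The StackGAN paper describes this step as conditioning augmentation, sampling the conditioning vector from a Gaussian whose mean and diagonal covariance are predicted from the embedding. The sketch below shows only that reparameterized sampling step, c = μ + σ ⊙ ε, in plain Python. The function name sample_condition is hypothetical, the mean and log-standard-deviation vectors are taken as given inputs rather than produced by a learned layer, and it should not be read as the repository's implementation.

```python
import math
import random


def sample_condition(mean, log_sigma, rng=random):
    """Hypothetical conditioning-augmentation step: c = mean + exp(log_sigma) * eps.

    mean and log_sigma are equal-length lists assumed to come from a learned
    projection of the text embedding (not modelled here); eps ~ N(0, 1).
    """
    if len(mean) != len(log_sigma):
        raise ValueError("mean and log_sigma must have the same length")
    # Draw one standard-normal eps per dimension and scale it by sigma.
    return [m + math.exp(ls) * rng.gauss(0.0, 1.0)
            for m, ls in zip(mean, log_sigma)]


rng = random.Random(0)
c = sample_condition([0.5, -1.0, 2.0], [-1.0, 0.0, -3.0], rng=rng)
print(len(c))  # 3
```

In build_model, sample_encoded_context is called once under the "g_net" variable scope and again under "hr_g_net", so Stage I and Stage II each get their own sampling op built from the same embeddings placeholder.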

Tested by

no test coverage detected
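
The only branch in build_model that can be exercised without TensorFlow is the checkpoint-path check. As a possible starting point for coverage, the sketch below factors that condition into a hypothetical helper, looks_like_checkpoint, which reproduces the substring test ckt_path.find('.ckpt') != -1 and pins its behaviour down with assertions. The helper name and the example paths are made up for illustration.

```python
def looks_like_checkpoint(path):
    """Hypothetical helper mirroring build_model's check: any '.ckpt' substring passes."""
    return path.find('.ckpt') != -1


def test_looks_like_checkpoint():
    # A plain checkpoint prefix is accepted.
    assert looks_like_checkpoint("models/example.ckpt")
    # The check is a substring match, so names like '.ckpt.index' pass too.
    assert looks_like_checkpoint("models/example.ckpt.index")
    # Paths without '.ckpt' are rejected; build_model then only prints a message.
    assert not looks_like_checkpoint("models/")
    assert not looks_like_checkpoint("models/example.pb")


test_looks_like_checkpoint()
print("ok")
```

Because the match is on a substring, a directory name containing ".ckpt" would also pass; a stricter variant could test the suffix instead, at the cost of diverging from the current behaviour.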