MCPcopy Index your code
hub / github.com/tensorpack/tensorpack / resample

Function resample

examples/OpticalFlow/flownet_models.py:75–112  ·  view source on GitHub ↗
(img, flow)

Source from the content-addressed store, hash-verified

73
74
def resample(img, flow):
    """Backward-warp ``img`` by a dense flow field using bilinear sampling.

    For every output pixel (y, x) of batch element n, the input is sampled at
    the fractional location (y + flow[n, 1, y, x], x + flow[n, 0, y, x]) and
    the four surrounding pixels are blended bilinearly. Integer sample
    coordinates are clamped to the image, so samples that fall outside the
    image replicate the border pixels.

    Args:
        img: image tensor in NCHW layout.
        flow: flow tensor in N2HW layout. Channel 0 is added to the x (column)
            coordinate and channel 1 to the y (row) coordinate. The sampling
            grid is built in float32, so ``flow`` must be float32 as well.
            It is expected to share ``img``'s spatial size.

    Returns:
        The warped image in NCHW layout, with the same height and width as
        ``img``. The channel dimension is fixed statically whenever ``img``
        has a statically known channel count.
    """
    B = tf.shape(img)[0]
    c = tf.shape(img)[1]
    h = tf.shape(img)[2]
    w = tf.shape(img)[3]
    # Flatten to (N*H*W, C) so that any pixel can be fetched with one gather index.
    img_flat = tf.reshape(tf.transpose(img, [0, 2, 3, 1]), [-1, c])

    dx, dy = tf.unstack(flow, axis=1)
    # Base pixel grid of shape (H, W). Broadcasting against the (N, H, W) flow
    # yields absolute sample coordinates for every batch element.
    xf, yf = tf.meshgrid(tf.cast(tf.range(w), tf.float32), tf.cast(tf.range(h), tf.float32))
    xf = xf + dx
    yf = yf + dy

    x0 = tf.floor(xf)
    y0 = tf.floor(yf)
    # Interpolation weights come from the unclamped coordinates. When a sample is
    # out of range, both neighbours clamp to the same border pixel, and their
    # weights still sum to 1.
    alpha = tf.expand_dims(xf - x0, axis=-1)
    beta = tf.expand_dims(yf - y0, axis=-1)

    xL = tf.clip_by_value(tf.cast(x0, dtype=tf.int32), 0, w - 1)
    xR = tf.clip_by_value(tf.cast(x0 + 1, dtype=tf.int32), 0, w - 1)
    yT = tf.clip_by_value(tf.cast(y0, dtype=tf.int32), 0, h - 1)
    yB = tf.clip_by_value(tf.cast(y0 + 1, dtype=tf.int32), 0, h - 1)

    # Batch index of every output pixel, shape (N, H, W).
    batch_ids = tf.tile(tf.expand_dims(tf.expand_dims(tf.range(B), axis=-1), axis=-1), [1, h, w])

    def get(y, x):
        # Fetch the (N, H, W, C) block of pixels at the int32 coordinates (y, x).
        idx = tf.reshape(batch_ids * h * w + y * w + x, [-1])
        return tf.reshape(tf.gather(img_flat, idx), [-1, h, w, c])

    # Bilinear blend of the four neighbours. The (N, H, W, 1) weights broadcast over C.
    val = ((1 - alpha) * (1 - beta) * get(yT, xL) +
           alpha * (1 - beta) * get(yT, xR) +
           (1 - alpha) * beta * get(yB, xL) +
           alpha * beta * get(yB, xR))

    # Expose the channel dimension statically when it is known, so that downstream
    # layers can build their weights. Otherwise fall back to the dynamic channel
    # count instead of passing None to tf.reshape.
    static_c = img.shape.as_list()[1]
    out_c = static_c if static_c is not None else c
    return tf.reshape(tf.transpose(val, [0, 3, 1, 2]), [-1, out_c, h, w])
113
114
115def resize(x, mode, factor=4):

Callers 1

graph_structure — Method · 0.85

Calls 2

get — Function · 0.85
shape — Method · 0.80

Tested by

no test coverage detected