(img, flow)
| 73 | |
| 74 | |
def resample(img, flow):
    """Warp ``img`` by a per-pixel displacement field using bilinear sampling.

    For every output location ``(y, x)`` the value is read from ``img`` at
    ``(y + dy, x + dx)`` and blended from the four surrounding pixels.
    Neighbour indices are clamped to the image bounds, while the blend
    weights come from the unclamped coordinates.

    Args:
        img: image tensor in NCHW layout.
        flow: displacement tensor in N2HW layout; channel 0 is added to the
            x (width) coordinate and channel 1 to the y (height) coordinate.

    Returns:
        The warped image in NCHW layout. The channel dimension of the result
        is taken from the static shape of ``img``.
    """
    dyn_shape = tf.shape(img)
    batch = dyn_shape[0]
    channels = dyn_shape[1]
    height = dyn_shape[2]
    width = dyn_shape[3]

    # One row per pixel (N*H*W rows, C columns) so a single flat index
    # addresses any pixel of any image in the batch.
    pixels = tf.reshape(tf.transpose(img, [0, 2, 3, 1]), [-1, channels])

    # Absolute sampling coordinates: regular pixel grid plus displacement.
    flow_x, flow_y = tf.unstack(flow, axis=1)
    grid_x, grid_y = tf.meshgrid(
        tf.cast(tf.range(width), tf.float32),
        tf.cast(tf.range(height), tf.float32),
    )
    sample_x = grid_x + flow_x
    sample_y = grid_y + flow_y

    floor_x = tf.floor(sample_x)
    floor_y = tf.floor(sample_y)

    # Fractional offsets inside the cell, with a trailing axis so they
    # broadcast across channels.
    frac_x = tf.expand_dims(sample_x - floor_x, axis=-1)
    frac_y = tf.expand_dims(sample_y - floor_y, axis=-1)

    def clamp_index(coord, upper):
        # Integer index restricted to the valid range [0, upper].
        return tf.clip_by_value(tf.cast(coord, dtype=tf.int32), 0, upper)

    x_left = clamp_index(floor_x, width - 1)
    x_right = clamp_index(floor_x + 1, width - 1)
    y_top = clamp_index(floor_y, height - 1)
    y_bottom = clamp_index(floor_y + 1, height - 1)

    # Batch index of every output pixel, shape [N, H, W].
    batch_index = tf.tile(
        tf.reshape(tf.range(batch), [-1, 1, 1]), [1, height, width])

    def fetch(rows, cols):
        # Gather the pixels at (rows, cols) and restore the NHWC layout.
        flat = tf.reshape(
            batch_index * height * width + rows * width + cols, [-1])
        flat = tf.cast(flat, tf.int32)
        return tf.reshape(
            tf.gather(pixels, flat), [-1, height, width, channels])

    # (x-weight, y-weight, row indices, column indices) for each of the four
    # neighbours: top-left, top-right, bottom-left, bottom-right.
    corners = (
        (1 - frac_x, 1 - frac_y, y_top, x_left),
        (frac_x, 1 - frac_y, y_top, x_right),
        (1 - frac_x, frac_y, y_bottom, x_left),
        (frac_x, frac_y, y_bottom, x_right),
    )

    warped = tf.zeros_like(frac_x)
    for weight_x, weight_y, rows, cols in corners:
        warped += weight_x * weight_y * fetch(rows, cols)

    # The channel dimension must be known at graph-construction time, so it
    # is read from the static shape instead of the dynamic one.
    static_channels = img.shape.as_list()[1]
    return tf.reshape(
        tf.transpose(warped, [0, 3, 1, 2]),
        [-1, static_channels, height, width])
| 113 | |
| 114 | |
| 115 | def resize(x, mode, factor=4): |
no test coverage detected