| 31 | |
| 32 | |
class StaticDynamicShape(object):
    """Keep a tensor's shape in two parallel per-axis views.

    ``static`` is the list returned by ``tensor.shape.as_list()``; an entry is
    ``None`` when that axis is not known statically.  ``dynamic`` is a list of
    the same length whose entries are zero-argument callables; calling one
    yields the value for that axis.  Entries are only invoked by
    :meth:`get_dynamic`, so any work they do is deferred until then.
    """

    def __init__(self, tensor):
        """Build both views from ``tensor``, which must be a ``tf.Tensor``."""
        assert isinstance(tensor, tf.Tensor), tensor
        shape = tensor.shape
        rank = shape.ndims
        self.static = shape.as_list()
        if shape.is_fully_defined():
            # Placeholder copy; every entry is overwritten by the loop below
            # because no static size is None in this branch.
            self.dynamic = list(self.static)
        else:
            runtime_shape = tf.shape(tensor)
            # NOTE(review): DynamicLazyAxis is defined outside this block;
            # presumably it returns a callable producing runtime_shape[axis].
            self.dynamic = [
                DynamicLazyAxis(runtime_shape, axis) for axis in range(rank)
            ]

        # Prefer the statically known size wherever one exists.
        for axis in range(rank):
            size = self.static[axis]
            if size is None:
                continue
            self.dynamic[axis] = StaticLazyAxis(size)

    def apply(self, axis, f):
        """Transform the size of ``axis`` with ``f``, in place.

        If the axis has a static size, ``f`` is first tried on it.  When that
        succeeds, the result becomes the new static size and the dynamic entry
        is rebuilt from it.  If the static size is ``None``, or the attempt
        raises ``TypeError``, the static size is set to ``None`` and the
        dynamic entry is replaced by a callable that applies ``f`` to the
        previous dynamic value when it is eventually invoked.
        """
        known = self.static[axis]
        if known is not None:
            try:
                updated = f(known)
                self.static[axis] = updated
                self.dynamic[axis] = StaticLazyAxis(updated)
                return
            except TypeError:
                # Fall through to the lazy path below.
                pass

        previous = self.dynamic[axis]

        def deferred():
            return f(previous())

        self.static[axis] = None
        self.dynamic[axis] = deferred

    def get_static(self):
        """Return the static shape list (the internal list, not a copy)."""
        return self.static

    @property
    def ndims(self):
        """Number of axes, taken from the length of the static list."""
        return len(self.static)

    def get_dynamic(self, axis=None):
        """Evaluate the dynamic entries.

        With ``axis`` given, call and return that single entry; otherwise call
        every entry and return the results as a list.
        """
        if axis is not None:
            return self.dynamic[axis]()
        return [self.dynamic[idx]() for idx in range(self.ndims)]
| 72 | |
| 73 | |
| 74 | if __name__ == '__main__': |
no outgoing calls
no test coverage detected