| 54 | return prod(self.shape) |
| 55 | |
def view(self, *shape):
    """Return a copy of this info with the same name/dtype and a new shape.

    Follows ``Tensor.view``-style rules: the new shape must hold exactly
    ``self.numel()`` elements, and at most one entry may be ``-1``, in which
    case its size is inferred from the other entries.

    Args:
        *shape: The new dimension sizes as plain ``int`` values. Entries must
            be non-negative, except for at most one ``-1`` placeholder.

    Returns:
        ``TensorInfo(self.name, self.dtype, new_shape)``, where ``new_shape``
        is ``shape`` as a tuple with any ``-1`` replaced by its inferred size.

    Raises:
        AssertionError: If ``shape`` is empty or contains a non-``int`` entry,
            if the element count does not match ``self.numel()``, or if the
            inferred dimension would not be a whole number. (These checks use
            ``assert`` and are skipped when Python runs with ``-O``.)
        ValueError: If a negative size other than ``-1`` is given, if more
            than one ``-1`` is given, or if ``-1`` is combined with a
            zero-sized dimension.
    """
    # Exact type check: rejects bool, float and numpy integer types; an
    # empty shape also fails here because set() != {int}.
    assert set(map(type, shape)) == {int}

    # Only -1 may be negative. Any other negative entry would otherwise be
    # treated as "unknown" but never replaced, leaking into the result.
    bad_dims = [dim for dim in shape if dim < -1]
    if bad_dims:
        raise ValueError(
            f'Invalid negative dimension(s) {bad_dims} in shape {shape}'
        )

    n_unknown = shape.count(-1)
    new_shape = list(shape)
    if n_unknown == 0:
        assert prod(shape) == self.numel()
    elif n_unknown == 1:
        n_known_elements = prod(dim for dim in shape if dim != -1)
        if n_known_elements == 0:
            # The known dims multiply to 0, so no unique size exists for the
            # -1 entry (and dividing by it would raise ZeroDivisionError).
            raise ValueError(
                f'Cannot infer the -1 dimension of shape {shape} '
                'alongside a zero-sized dimension'
            )
        numel = self.numel()
        assert numel % n_known_elements == 0
        new_shape[new_shape.index(-1)] = numel // n_known_elements
    else:
        raise ValueError('More than one dimensions need to be inferred!')
    return TensorInfo(self.name, self.dtype, tuple(new_shape))
| 72 | |
def __len__(self):
    """Return the size of the first dimension, as ``len()`` does for tensors.

    Raises:
        TypeError: If ``self.shape`` is empty. A 0-d (scalar) shape has no
            length, and ``TypeError`` is what ``len()`` raises for unsized
            objects (instead of a confusing ``IndexError``).
    """
    # NOTE(review): assumes self.shape is a tuple/list of ints, as built by view().
    if not self.shape:
        raise TypeError('len() of a 0-d tensor')
    return self.shape[0]