| 44 | |
@dataclass
class TensorInfo:
    """Metadata (name, dtype, shape) describing a tensor.

    Only metadata is stored, no data buffer. ``view`` and ``squeeze`` compute
    a new shape and return a new ``TensorInfo``; ``self`` is never mutated.
    """

    name: str
    # Presumably a TensorRT dtype. Quoted as a forward reference so the
    # annotation is not evaluated when the class is created.
    dtype: "trt.DataType"
    shape: tuple

    # add more info like strides, formats if needed

    def numel(self):
        """Return the total number of elements (product of ``shape``)."""
        return prod(self.shape)

    def view(self, *shape):
        """Return a copy of this info with its shape changed to ``shape``.

        ``shape`` may be passed as separate ints (``view(2, -1)``) or as a
        single tuple/list (``view((2, -1))``). At most one entry may be
        ``-1``; that entry is inferred so the element count is preserved.

        Raises:
            TypeError: if any dimension is not exactly an ``int``.
            ValueError: if a dimension is below ``-1``, more than one
                dimension is ``-1``, the ``-1`` cannot be inferred uniquely,
                or the new shape does not hold ``numel()`` elements.
        """
        if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
            shape = tuple(shape[0])
        # Exact type check (as before): rejects bool, float, numpy scalars.
        if any(type(d) is not int for d in shape):
            raise TypeError(f"view() expects int dimensions, got {shape!r}")
        # Only -1 is a valid placeholder; previously other negatives were
        # counted as "unknown" but never replaced, leaking into the result.
        if any(d < -1 for d in shape):
            raise ValueError(
                f"invalid dimension in shape {shape!r}: "
                "only -1 may be negative")
        n_unknown = shape.count(-1)
        if n_unknown > 1:
            raise ValueError('More than one dimensions need to be inferred!')

        numel = self.numel()
        new_shape = list(shape)
        if n_unknown == 1:
            # A list (not a generator) so any prod implementation accepts it.
            n_known_elements = prod([d for d in shape if d != -1])
            if n_known_elements == 0:
                # Any size for -1 would give 0 elements: ambiguous, and
                # dividing by it would raise ZeroDivisionError.
                raise ValueError(
                    f"cannot infer -1 in shape {shape!r} because another "
                    "dimension is 0")
            if numel % n_known_elements != 0:
                raise ValueError(
                    f"shape {shape!r} is invalid for {numel} elements")
            new_shape[new_shape.index(-1)] = numel // n_known_elements
        elif prod(shape) != numel:
            raise ValueError(
                f"shape {shape!r} is invalid for {numel} elements")
        return TensorInfo(self.name, self.dtype, tuple(new_shape))

    def __len__(self):
        """Return the size of the first dimension (``shape[0]``)."""
        return self.shape[0]

    def squeeze(self, dim=0):
        """Return a copy with dimension ``dim`` (which must be 1) removed.

        ``dim`` may be negative and then counts from the end, as with normal
        sequence indexing.

        Raises:
            IndexError: if ``dim`` is out of range for ``shape``.
            ValueError: if ``shape[dim]`` is not 1.
        """
        size = self.shape[dim]  # IndexError when dim is out of range
        if size != 1:
            raise ValueError(f"dim {dim} is {size} instead of 1!")
        # Normalise negative dims before slicing: with the raw index, dim=-1
        # produced shape[:-1] + shape[0:], duplicating dimensions.
        dim %= len(self.shape)
        return TensorInfo(self.name, self.dtype,
                          self.shape[:dim] + self.shape[dim + 1:])
| 81 | |
| 82 | |
| 83 | class Session(object): |
no outgoing calls
no test coverage detected