Return a reshaped array. Only supported for contiguous arrays. Args: shape : An int or tuple of ints specifying the shape of the returned array.
(self, shape)
| 4267 | return a |
| 4268 | |
| 4269 | def reshape(self, shape): |
| 4270 | """Return a reshaped array. Only supported for contiguous arrays. |
| 4271 | |
| 4272 | Args: |
| 4273 | shape : An int or tuple of ints specifying the shape of the returned array. |
| 4274 | """ |
| 4275 | if not self.is_contiguous: |
| 4276 | raise RuntimeError("Reshaping non-contiguous arrays is unsupported.") |
| 4277 | |
| 4278 | # convert shape to tuple |
| 4279 | if shape is None: |
| 4280 | raise RuntimeError("shape parameter is required.") |
| 4281 | if isinstance(shape, int): |
| 4282 | shape = (shape,) |
| 4283 | elif not isinstance(shape, tuple): |
| 4284 | shape = tuple(shape) |
| 4285 | |
| 4286 | if len(shape) > ARRAY_MAX_DIMS: |
| 4287 | raise RuntimeError( |
| 4288 | f"Arrays may only have {ARRAY_MAX_DIMS} dimensions maximum, trying to create array with {len(shape)} dims." |
| 4289 | ) |
| 4290 | |
| 4291 | # check for -1 dimension and reformat |
| 4292 | if -1 in shape: |
| 4293 | idx = self.size |
| 4294 | denom = 1 |
| 4295 | minus_one_count = 0 |
| 4296 | for i, d in enumerate(shape): |
| 4297 | if d == -1: |
| 4298 | idx = i |
| 4299 | minus_one_count += 1 |
| 4300 | else: |
| 4301 | denom *= d |
| 4302 | if minus_one_count > 1: |
| 4303 | raise RuntimeError("Cannot infer shape if more than one index is -1.") |
| 4304 | new_shape = list(shape) |
| 4305 | new_shape[idx] = int(self.size / denom) |
| 4306 | shape = tuple(new_shape) |
| 4307 | |
| 4308 | size = 1 |
| 4309 | for d in shape: |
| 4310 | size *= d |
| 4311 | |
| 4312 | if size != self.size: |
| 4313 | raise RuntimeError("Reshaped array must have the same total size as the original.") |
| 4314 | |
| 4315 | a = array( |
| 4316 | ptr=self.ptr, |
| 4317 | dtype=self.dtype, |
| 4318 | shape=shape, |
| 4319 | strides=None, |
| 4320 | device=self.device, |
| 4321 | pinned=self.pinned, |
| 4322 | copy=False, |
| 4323 | grad=None if self.grad is None else self.grad.reshape(shape), |
| 4324 | ) |
| 4325 | |
| 4326 | # transfer read flag |