| 1474 | |
| 1475 | @derived_from(np.linalg) |
| 1476 | def norm(x, ord=None, axis=None, keepdims=False): |
| 1477 | if axis is None: |
| 1478 | axis = tuple(range(x.ndim)) |
| 1479 | elif isinstance(axis, Number): |
| 1480 | axis = (int(axis),) |
| 1481 | else: |
| 1482 | axis = tuple(axis) |
| 1483 | |
| 1484 | if len(axis) > 2: |
| 1485 | raise ValueError("Improper number of dimensions to norm.") |
| 1486 | |
| 1487 | if ord == "fro": |
| 1488 | ord = None |
| 1489 | if len(axis) == 1: |
| 1490 | raise ValueError("Invalid norm order for vectors.") |
| 1491 | |
| 1492 | r = abs(x) |
| 1493 | |
| 1494 | if ord is None: |
| 1495 | r = (r**2).sum(axis=axis, keepdims=keepdims) ** 0.5 |
| 1496 | elif ord == "nuc": |
| 1497 | if len(axis) == 1: |
| 1498 | raise ValueError("Invalid norm order for vectors.") |
| 1499 | if x.ndim > 2: |
| 1500 | raise NotImplementedError("SVD based norm not implemented for ndim > 2") |
| 1501 | |
| 1502 | r = svd(x)[1][None].sum(keepdims=keepdims) |
| 1503 | elif ord == np.inf: |
| 1504 | if len(axis) == 1: |
| 1505 | r = r.max(axis=axis, keepdims=keepdims) |
| 1506 | else: |
| 1507 | r = r.sum(axis=axis[1], keepdims=True).max(axis=axis[0], keepdims=True) |
| 1508 | if keepdims is False: |
| 1509 | r = r.squeeze(axis=axis) |
| 1510 | elif ord == -np.inf: |
| 1511 | if len(axis) == 1: |
| 1512 | r = r.min(axis=axis, keepdims=keepdims) |
| 1513 | else: |
| 1514 | r = r.sum(axis=axis[1], keepdims=True).min(axis=axis[0], keepdims=True) |
| 1515 | if keepdims is False: |
| 1516 | r = r.squeeze(axis=axis) |
| 1517 | elif ord == 0: |
| 1518 | if len(axis) == 2: |
| 1519 | raise ValueError("Invalid norm order for matrices.") |
| 1520 | |
| 1521 | r = (r != 0).astype(r.dtype).sum(axis=axis, keepdims=keepdims) |
| 1522 | elif ord == 1: |
| 1523 | if len(axis) == 1: |
| 1524 | r = r.sum(axis=axis, keepdims=keepdims) |
| 1525 | else: |
| 1526 | r = r.sum(axis=axis[0], keepdims=True).max(axis=axis[1], keepdims=True) |
| 1527 | if keepdims is False: |
| 1528 | r = r.squeeze(axis=axis) |
| 1529 | elif len(axis) == 2 and ord == -1: |
| 1530 | r = r.sum(axis=axis[0], keepdims=True).min(axis=axis[1], keepdims=True) |
| 1531 | if keepdims is False: |
| 1532 | r = r.squeeze(axis=axis) |
| 1533 | elif len(axis) == 2 and ord == 2: |