| 1271 | |
| 1272 | |
def test_normalize_trainium_layout():
    """Check ``TileLayout.canonicalize`` on several Trainium-style layouts.

    Each case builds a layout with the ``S[shape : stride]`` shorthand, where
    every stride is tagged with ``P`` or ``F`` through ``@``. It then compares
    ``layout.canonicalize()`` against an expected layout using
    ``assert_structural_equal``. When no expected layout is given, the input
    layout must already be in canonical form.
    """

    def check_canonical(layout, expected=None):
        # No explicit expectation means canonicalize() must leave the layout
        # structurally unchanged.
        target = layout if expected is None else expected
        assert_structural_equal(target, layout.canonicalize())

    # Two dimensions, one strided on P and one on F: already canonical.
    check_canonical(TileLayout(S[(8, 8) : (8 @ P, 1 @ F)]))

    # The extent-1 P dimension disappears. The remaining F dimensions
    # (extent 8, stride 8 and extent 8, stride 1) collapse into one
    # 64-element F dimension with stride 1.
    check_canonical(
        TileLayout(S[(8, 1, 8) : (8 @ F, 1 @ P, 1 @ F)]),
        TileLayout(S[64 : 1 @ F]),
    )

    # Same strides, but the P dimension now has extent 8. The expected result
    # is the input unchanged. Presumably the non-trivial P dimension between
    # the two F dimensions prevents them from fusing (inferred, confirm).
    check_canonical(TileLayout(S[(8, 8, 8) : (8 @ F, 1 @ P, 1 @ F)]))
| 1292 | |
| 1293 | |
| 1294 | def test_direct_sum(): |