| 1253 | |
| 1254 | |
def test_normalize_compose_layout():
    """Exercise ``canonicalize()`` on a SwizzleLayout composed with a TileLayout.

    Two scenarios are checked, each with a freshly built swizzle and tile:

    * With tile ``S[(8, 64) : (64, 1)]``, canonicalizing the composition must
      be structurally equal to the swizzle on its own.
    * With tile ``S[(64, 4, 64) : (64, 4096, 1)]``, canonicalizing the
      composition must be structurally equal to the composition itself.

    NOTE(review): ``S[shape : stride]`` is presumably a shape/stride spec. If
    so, the first tile looks row-major contiguous and the second has permuted
    strides. That would explain why only the first composition collapses, but
    ``S`` is defined elsewhere, so confirm.
    """
    # Scenario 1: the TileLayout side is expected to vanish on canonicalization.
    swizzle = SwizzleLayout(per_element=3, swizzle_len=3, atom_len=3)
    tile = TileLayout(S[(8, 64) : (64, 1)])
    composed = ComposeLayout(swizzle, tile.canonicalize())
    assert_structural_equal(composed.canonicalize(), swizzle)

    # Scenario 2: canonicalization is expected to leave the composition as-is.
    swizzle = SwizzleLayout(per_element=3, swizzle_len=3, atom_len=3)
    tile = TileLayout(S[(64, 4, 64) : (64, 4096, 1)])
    composed = ComposeLayout(swizzle, tile.canonicalize())
    assert_structural_equal(composed.canonicalize(), composed)
| 1271 | |
| 1272 | |
| 1273 | def test_normalize_trainium_layout(): |