| 337 | |
| 338 | |
def layout_sub_figures(data: Union[Tensor, Field],
                       row_dims: Shape,
                       col_dims: Shape,
                       animate: Shape,  # do not reduce these dims, has priority
                       overlay: Shape,
                       offset_row: int,
                       offset_col: int,
                       positioning: Union[Dict[Tuple[int, int], List], None] = None,
                       indices: Union[Dict[Tuple[int, int], List[dict]], None] = None,
                       base_index: Union[Dict[str, Union[int, str]], None] = None) -> Tuple[Dict[Tuple[int, int], List[Field]], Dict[Tuple[int, int], List[dict]]]:
    """
    Recursively assign the plottable parts of `data` to cells of a sub-figure grid.

    An object-dtype `Tensor` (a group of elements) is split along its first dimension:

    * a dimension in `overlay` places all of its entries into the same cells,
    * a dimension in `row_dims` / `col_dims` lays its entries out below / beside
      each other, advancing the row / column offset,
    * any other dimension lays all of its entries out at the current offset.

    A plottable leaf is converted with `to_field` and sliced into one entry per
    (row, column, overlay) index combination. Each entry is appended to the
    corresponding `(row, col)` cell.

    Args:
        data: A plottable value, or an object-dtype `Tensor` grouping several of them.
        row_dims: Dimensions laid out along grid rows.
        col_dims: Dimensions laid out along grid columns. A dimension that is also
            in `row_dims` is treated as a row dimension.
        animate: Dimensions that should not be reduced into rows/columns. On a
            plottable leaf, those that are not also overlay dims are removed from
            the leaf's own row/column dims.
        overlay: Dimensions whose entries are drawn into the same grid cell.
        offset_row: Row index of the first cell available to `data`.
        offset_col: Column index of the first cell available to `data`.
        positioning: Accumulator mapping `(row, col)` to the list of sub-data placed
            in that cell. Created when `None`; otherwise filled in place.
        indices: Accumulator mapping `(row, col)` to the index dicts that selected
            each entry of `positioning`, in the same order. Created when `None`.
        base_index: Index already applied to `data` by enclosing calls.
            Treated as `{}` when `None`.

    Returns:
        `(positioning, indices)`: the (possibly newly created) accumulators.
    """
    # Default each accumulator on its own instead of asserting. An `assert` is
    # stripped under `python -O`, which would silently discard caller-supplied
    # dicts. Passing `positioning` without `indices` would otherwise fail later
    # on `None.setdefault`.
    if positioning is None:
        positioning = {}
    if indices is None:
        indices = {}
    if base_index is None:
        base_index = {}
    # --- if data is a group of elements, lay them out recursively ---
    if isinstance(data, Tensor) and data.dtype.kind == object:  # layout
        if not data.shape:  # nothing to plot
            return positioning, indices
        dim0 = data.shape[0]
        if dim0.only(overlay):
            # Every entry along an overlay dim is placed at the same offset, i.e. into the same cells.
            for overlay_index in dim0.only(overlay).meshgrid(names=True):  # overlay these fields
                # ToDo expand constants along rows/cols
                layout_sub_figures(data[overlay_index], row_dims, col_dims, animate, overlay, offset_row, offset_col, positioning, indices, {**base_index, **overlay_index})
            return positioning, indices
        elif dim0.only(animate):
            # NOTE(review): no recursion happens here. Execution falls through to
            # `to_field(data)` below with the whole object tensor. Confirm this is intended.
            pass
        else:
            elements = math.unstack(data, dim0.name)
            offset = 0  # rows/columns already consumed by previous elements
            for item_name, e in zip(dim0.get_item_names(dim0.name) or range(dim0.size), elements):
                index = dict(base_index, **{dim0.name: item_name})
                if dim0.only(row_dims):
                    # Stack vertically; each element advances by the volume of its own row dims.
                    layout_sub_figures(e, row_dims, col_dims, animate, overlay, offset_row + offset, offset_col, positioning, indices, index)
                    offset += shape(e).only(row_dims).volume
                elif dim0.only(col_dims):
                    # Stack horizontally; each element advances by the volume of its own column dims.
                    layout_sub_figures(e, row_dims, col_dims, animate, overlay, offset_row, offset_col + offset, positioning, indices, index)
                    offset += shape(e).only(col_dims).volume
                else:
                    # dim0 is neither a row, column, overlay nor animate dim: all elements share the current offset.
                    layout_sub_figures(e, row_dims, col_dims, animate, overlay, offset_row, offset_col, positioning, indices, index)
            return positioning, indices
    # --- data must be a plottable object ---
    data = to_field(data)
    # Overlay takes precedence over animate; both are excluded from the row/column split.
    overlay = data.shape.only(overlay)
    animate = data.shape.only(animate).without(overlay)
    row_shape = data.shape.only(row_dims).without(animate).without(overlay)
    # A dim listed in both row_dims and col_dims is used as a row dim only.
    col_shape = data.shape.only(col_dims).without(row_dims).without(animate).without(overlay)
    # Presumably merges in the row/col dims not already selected by base_index
    # (see `Shape.after_gather`); TODO confirm.
    row_shape &= row_dims.after_gather(base_index)
    col_shape &= col_dims.after_gather(base_index)
    for ri, r in enumerate(row_shape.meshgrid(names=True)):
        for ci, c in enumerate(col_shape.meshgrid(names=True)):
            cell = (offset_row + ri, offset_col + ci)
            for o in overlay.meshgrid(names=True):
                sub_data = data[r][c][o]
                positioning.setdefault(cell, []).append(sub_data)
                indices.setdefault(cell, []).append(dict(base_index, **r, **c, **o))
    return positioning, indices
| 395 | |
| 396 | |