| 155 | ]) |
| 156 | |
def finalize_after(ctx:AllocCtx, x:UOp):
  """Finalize `x` (going by the name, presumably an AFTER UOp) during allocation.

  Untagged `x`: it is appended to `ctx.assigns` (assigns for the call body) and None is returned.
  Tagged `x`: a copy with the tag cleared is returned. For every index in `x.tag`, the original
  pre-rewrite UOp `ctx.uop_list[index]` is mapped in `ctx.buffer_map` to the stripped buffer, shrunk
  to that original UOp's shape. The stripped buffer is the untagged copy with any chain of AFTER
  wrappers peeled off. The returned untagged copy is reprocessed later and lands in the untagged branch.
  """
  if x.tag is None:
    # untagged: only record it as an assign for the call body
    ctx.assigns.append(x)
    return None
  untagged = x.replace(tag=None)
  # descend through nested AFTER nodes (via src[0]) to reach the stripped buffer
  base = untagged
  while base.op is Ops.AFTER:
    base = base.src[0]
  for idx in x.tag:
    pre_rewrite: UOp = ctx.uop_list[idx]
    ctx.buffer_map[pre_rewrite] = base.shrink_to(pre_rewrite.shape)
  return untagged
| 170 | |
| 171 | def replace_input_buffer(ctx:AllocCtx, b:UOp): |
| 172 | ctx.replacements.append(b) |