| 1297 | return lazy_compile |
| 1298 | |
class PatternMatcher:
  """Ordered collection of (UPat, rewrite-function) pairs, bucketed by the op(s) each UPat names.

  Every bucket entry is a mutable list ``[pattern, matcher, early_reject]``. The matcher slot holds either an
  interpreted matcher (``upat_interpret``) or a deferred-compile matcher (``upat_deferred_compile``). In the
  deferred case the entry list itself is handed to the compiler, presumably so it can swap in a compiled
  matcher later. The same list object is shared by every op bucket the pattern is registered under.
  """
  def __init__(self, patterns:Sequence[tuple[UPat, Callable|tuple]], compiled=bool(getenv("UPAT_COMPILE", 1))):
    # A tuple in place of a function means this matcher is being unpickled: __reduce__ ships lambdas as
    # deconstruct_function(...) tuples, and they are rebuilt into real functions here.
    rebuilt: list[tuple[UPat, Callable]] = []
    for pat, fn in patterns:
      rebuilt.append((pat, types.FunctionType(*fn) if isinstance(fn, tuple) else fn))
    self.patterns:list[tuple[UPat, Callable]] = rebuilt

    # NOTE: deliberately a plain dict, not a defaultdict. A defaultdict would insert an empty bucket for every
    # op ever looked up, and those keys would live for the lifetime of the PatternMatcher.
    self.pdict: dict[Ops, list[list]] = {}
    # The pattern's op is required (it is the bucket key); its arg is optional.
    for pat, fn in self.patterns:
      assert pat.op is not None
      # The slot exists before its matcher so the deferred compiler can be given a reference to it.
      slot: list = [pat, None, pat.early_reject]
      slot[1] = upat_interpret(pat, fn) if not compiled else upat_deferred_compile(pat, fn, slot)
      for op in pat.op: self.pdict.setdefault(op, []).append(slot)

  def __reduce__(self):
    # Lambdas are turned into deconstruct_function tuples for pickling; __init__ rebuilds them.
    # Named functions are passed through unchanged.
    picklable = [(pat, deconstruct_function(fn) if fn.__name__ == "<lambda>" else fn) for pat, fn in self.patterns]
    return PatternMatcher, (picklable,)

  @functools.cache  # pylint: disable=method-cache-max-size-none
  def __add__(self, more:PatternMatcher) -> PatternMatcher:
    # Cached, so adding the same pair of matchers again returns the same combined PatternMatcher.
    # The unbounded cache keeps both operands alive.
    combined = self.patterns + more.patterns
    return PatternMatcher(combined)

  def rewrite(self, uop:UOp, ctx=None):
    """Try each pattern registered for ``uop.op``, in registration order.

    Returns the first matcher result that is neither None nor ``uop`` itself, or None if no pattern applies.
    """
    candidates = self.pdict.get(uop.op, [])
    if not candidates: return None
    # The set of source ops is memoized on the uop so repeated rewrite calls don't rebuild it.
    src_ops = uop.__dict__.get('_src_ops')
    if src_ops is None:
      src_ops = {s.op for s in uop.src}
      uop.__dict__['_src_ops'] = src_ops
    for _, matcher, early_reject in candidates:
      # Cheap pre-filter: only run the matcher when every op in early_reject appears among uop's source ops.
      if early_reject.issubset(src_ops) and (ret := matcher(uop, ctx)) is not None and ret is not uop: return ret
    return None
| 1324 | |
| 1325 | # *** tracking pattern matcher *** |
| 1326 |
no outgoing calls
searching dependent graphs…