(self, *args, **kwargs)
def __get__(self, obj, objtype=None):
  # Descriptor protocol: when looked up on the owning class itself (no instance), return the wrapper unchanged.
  if obj is None: return self
  # When looked up on an instance, pre-bind that instance as the first positional argument, method-style.
  return functools.partial(self.__call__, obj)
| 36 | |
def __call__(self, *args, **kwargs) -> ReturnType:
  """Trace `self.fxn` on the given arguments and wrap the traced graph in a `.call(...)` uop.

  Every Tensor/UOp found in (args, kwargs) is an explicit input. In the traced graph it is substituted by
  `x.param_like(i)`, where i is its slot after deduplication. The `pm_ctx` rewrite then runs with `call_uops`
  as context; BUFFERs found past the explicit entries are treated as implicit inputs.
  NOTE(review): the appending is inferred from the `call_uops[num_explicit:]` slice; pm_ctx is not visible here.

  Returns a Tensor, or a tuple of Tensors when `self.fxn` returned a tuple. Each wraps `fret.gettuple(i)`
  of the resulting call uop.

  Raises:
    RuntimeError: if `self.fxn` returns something other than a Tensor or a tuple of Tensors, or if
      implicit BUFFER inputs are found while `self.allow_implicit` is False.
  """
  st = time.perf_counter()

  # every Tensor/UOp leaf in (args, kwargs) is an explicit input
  params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()

  # deduplicate input uops, keeping the first occurrence index for each unique uop
  call_uops: list[UOp] = dedup([(t.uop if isinstance(t, Tensor) else t) for t in params])

  # disable realize/schedule while this is running
  # run it and do surgery later
  with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
    _function.depth += 1
    # always restore the nesting depth, even if fxn raises; otherwise every later DEBUG line is mis-indented
    try:
      ret = self.fxn(*args, **kwargs)
    finally:
      _function.depth -= 1
    if isinstance(ret, Tensor):
      uret = ret.uop
    elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):
      uret = UOp.maketuple(*[x.uop for x in ret])
    else:
      raise RuntimeError(f"function return type {type(ret)} not supported")

  # replace the known inputs with params (using deduplicated slots)
  uret = uret.substitute({x: x.param_like(i) for i, x in enumerate(call_uops)})

  # the BUFFERs that are left are the implicit inputs; anything pm_ctx adds lands after num_explicit
  num_explicit = len(call_uops)
  uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
  name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
  if not self.allow_implicit:
    implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
    if implicit_buffers:
      buf_strs = '\n '.join(f"{i}: dtype={b.dtype}, size={b.arg}, device={b.device}" for i,b in enumerate(implicit_buffers))
      raise RuntimeError(f"function {name} has {len(implicit_buffers)} implicit buffer(s), but allow_implicit=False\n {buf_strs}")

  fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
                   precompile_backward=self.precompile_backward)

  if DEBUG >= 2:
    # indentation mirrors how deeply nested this function call is
    print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}")

  # unpack the call's outputs into Tensors mirroring the structure fxn returned
  if isinstance(ret, tuple):
    return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
  return cast(ReturnType, Tensor(fret.gettuple(0)))
| 94 |
nothing calls this directly
no test coverage detected