MCPcopy Index your code
hub / github.com/tinygrad/tinygrad / __call__

Method __call__

tinygrad/function.py:37–93  ·  view source on GitHub ↗
(self, *args, **kwargs)

Source from the content-addressed store, hash-verified

def __get__(self, obj, objtype=None):
  """Descriptor hook so the wrapper can be used on methods.

  Accessed through an instance, return a partial of ``self.__call__`` with that instance
  pre-filled as the first positional argument. Accessed with ``obj`` set to None, return the
  wrapper itself. ``objtype`` is accepted for the descriptor protocol and is not used.
  """
  # compare against None explicitly: a falsy instance still has to be bound
  if obj is None: return self
  return functools.partial(self.__call__, obj)
36
def __call__(self, *args, **kwargs) -> ReturnType:
  """Run ``self.fxn`` on the arguments and wrap the uop graph it returns in a call.

  The Tensors/UOps that ``get_state_dict`` collects from ``(args, kwargs)`` are deduplicated and
  become the explicit inputs of the call; inside the returned graph they are substituted by
  ``param_like(i)`` placeholders, where ``i`` is the input's deduplicated slot.

  Returns:
    A single Tensor when ``self.fxn`` returned a Tensor, otherwise a tuple with one Tensor per
    element of the returned tuple.

  Raises:
    RuntimeError: if ``self.fxn`` returns anything other than a Tensor or a tuple of Tensors, or
      if ``self.allow_implicit`` is false and BUFFER uops are found among the implicit inputs.
  """
  st = time.perf_counter()

  params = get_state_dict((args, kwargs), tensor_type=(Tensor, UOp)).values()

  # deduplicate input_uops, keeping the first occurrence index for each unique uop
  call_uops: list[UOp] = dedup([(t.uop if isinstance(t, Tensor) else t) for t in params])

  # disable realize/schedule while this is running
  # run it and do surgery later
  with Context(ALLOW_DEVICE_USAGE=getenv("DEVICE_IN_FUNCTION_BUG", 0)):
    _function.depth += 1
    # always undo the increment: if fxn raises, a leaked depth would skew the DEBUG
    # indentation of every later call
    try: ret = self.fxn(*args, **kwargs)
    finally: _function.depth -= 1
  if isinstance(ret, Tensor):
    uret = ret.uop
  elif isinstance(ret, tuple) and all(isinstance(x, Tensor) for x in ret):
    uret = UOp.maketuple(*[x.uop for x in ret])
  else:
    raise RuntimeError(f"function return type {type(ret)} not supported")

  # replace the known inputs with params (using deduplicated slots)
  subs = {x: x.param_like(i) for i,x in enumerate(call_uops)}
  uret = uret.substitute(subs)

  # add contiguous to call_uops
  #call_uops = [x.contiguous() for x in call_uops]

  # the BUFFERs that are left are the implicit inputs
  # NOTE(review): pm_ctx receives call_uops as context and presumably appends the implicit inputs
  # to it, which is why entries past num_explicit are inspected below -- confirm against pm_ctx
  num_explicit = len(call_uops)
  uret = graph_rewrite(uret, pm_ctx, (call_uops, itertools.count(0)), bottom_up=True, name="get_implicit_inputs")
  name = getattr(self.fxn, '__qualname__', None) or type(self.fxn).__qualname__
  if not self.allow_implicit:
    implicit_buffers = [x for x in call_uops[num_explicit:] if x.op is Ops.BUFFER]
    if implicit_buffers:
      buf_strs = '\n '.join(f"{i}: dtype={b.dtype}, size={b.arg}, device={b.device}" for i,b in enumerate(implicit_buffers))
      raise RuntimeError(f"function {name} has {len(implicit_buffers)} implicit buffer(s), but allow_implicit=False\n {buf_strs}")

  # assign output
  #pbuffer = uret.param_like(len(call_uops))
  #assigned = pbuffer.assign(uret).sink()
  #buffer = UOp.new_buffer(pbuffer.device, pbuffer.size, pbuffer.dtype).reshape(uret.shape)
  #call = assigned.call(*call_uops, buffer, name=name)
  #ret = buffer.after(call)

  fret = uret.call(*call_uops, grad_fxn=self.grad_fxn, name=name, precompile=self.precompile,
                   precompile_backward=self.precompile_backward)

  if DEBUG >= 2:
    #signature = [(x._shape, x.dtype, x.device) for x in call_uops]
    print(" "*_function.depth+f"function {uret.key.hex()[:8]} in {(time.perf_counter()-st)*1000:8.2f} ms: {name}") # with sig {signature}")

  # mirror the shape of what fxn returned: a tuple of Tensors, or a single Tensor
  if isinstance(ret, tuple):
    return cast(ReturnType, tuple(Tensor(fret.gettuple(i)) for i in range(len(ret))))
  return cast(ReturnType, Tensor(fret.gettuple(0)))
94

Callers

nothing calls this directly

Calls 14

get_state_dict Function · 0.90
dedup Function · 0.90
Context Class · 0.90
getenv Function · 0.90
graph_rewrite Function · 0.90
Tensor Class · 0.90
cast Function · 0.85
maketuple Method · 0.80
param_like Method · 0.80
substitute Method · 0.80
gettuple Method · 0.80
fxn Method · 0.45

Tested by

no test coverage detected