Repository: github.com/tinygrad/tinygrad

Function train_one_step

test/models/test_train.py:17–27
Signature: train_one_step(model, X, Y)

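import time
import numpy as np
# Assumed import paths for a recent tinygrad checkout; these modules have moved
# between versions. extra.training is repo-local, so run from the repo root.
from tinygrad.helpers import getenv
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from extra.training import train

# Batch size, overridable with the BS environment variable (default 2).
BS = getenv("BS", 2)

def train_one_step(model,X,Y):
  # Gather the model's trainable tensors and count their scalar parameters.
  params = get_parameters(model)
  pcount = 0
  for p in params:
    pcount += np.prod(p.shape)
  optimizer = optim.SGD(params, lr=0.001)
  print("stepping %r with %.1fM params bs %d" % (type(model), pcount/1e6, BS))
  # Time a single training step (steps=1) at batch size BS.
  st = time.time()
  train(model, X, Y, optimizer, steps=1, BS=BS)
  et = time.time()-st
  print("done in %.2f ms" % (et*1000.))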
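
The function's core is one SGD step (lr=0.001) timed end to end. The sketch below mirrors that structure in plain Python so it runs without tinygrad. It rests on several assumptions: a hypothetical one-weight linear model stands in for `model`, hand-derived mean-squared-error gradients stand in for tinygrad's autograd, and `os.environ` stands in for `getenv`. The names `toy_train_one_step` and `sgd_update` are illustrative only and are not part of tinygrad.

```python
import os
import time

# Stand-in for getenv("BS", 2): read BS from the environment as an int.
# This is an assumption for the sketch, not tinygrad's helper.
BS = int(os.environ.get("BS", 2))

def sgd_update(params, grads, lr):
    """Plain SGD, in place: p <- p - lr * g for every scalar in every parameter."""
    for p, g in zip(params, grads):
        for i in range(len(p)):
            p[i] -= lr * g[i]

def toy_train_one_step(w, b, X, Y, lr=0.001, bs=BS):
    """One timed SGD step on a hypothetical 1-D linear model y = w*x + b (MSE loss).

    Takes the first `bs` samples as the batch so the result is deterministic.
    """
    pcount = len(w) + len(b)
    print("stepping toy linear model with %d params bs %d" % (pcount, bs))
    st = time.perf_counter()
    xs, ys = X[:bs], Y[:bs]
    n = len(xs)
    # Forward pass: residuals of the prediction against the targets.
    errs = [w[0] * x + b[0] - y for x, y in zip(xs, ys)]
    # Backward pass: d/dw and d/db of mean((w*x + b - y)^2).
    gw = [sum(2 * e * x for e, x in zip(errs, xs)) / n]
    gb = [sum(2 * e for e in errs) / n]
    # Optimizer step, the counterpart of optim.SGD(params, lr=0.001).
    sgd_update([w, b], [gw, gb], lr)
    et = time.perf_counter() - st
    print("done in %.2f ms" % (et * 1000.0))
    return w, b
```

Called with w = b = [0.0], X = [1.0, 2.0, 3.0], Y = [2.0, 4.0, 6.0] and bs=2, the gradients are -10 and -6, so a single step with lr=0.001 moves w to about 0.01 and b to about 0.006. The sketch times with `time.perf_counter`, which is generally preferred over `time.time` for measuring short intervals.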

Callers (5)

test_convnext (method) · 0.85
test_efficientnet (method) · 0.85
test_vit (method) · 0.85
test_transformer (method) · 0.85
test_resnet (method) · 0.85

Calls (3)

get_parameters (function) · 0.90
train (function) · 0.90
prod (method) · 0.80
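
The `prod` call is how the function counts parameters: each tensor contributes the product of its shape's dimensions, and the sum is divided by 1e6 for the "M params" figure. The following dependency-free sketch uses `math.prod` in place of `np.prod`. The layer shapes are hypothetical, and `count_params` is an illustrative name rather than a tinygrad function.

```python
import math

def count_params(shapes):
    """Total scalar count: the product of each shape's dimensions, summed.

    A scalar of shape () counts as 1, since math.prod of an empty tuple is 1.
    """
    return sum(math.prod(shape) for shape in shapes)

# Hypothetical shapes for a small two-layer MLP (weights and biases).
shapes = [(784, 128), (128,), (128, 10), (10,)]
pcount = count_params(shapes)
print("%.1fM params" % (pcount / 1e6))
```

The count here is 100352 + 128 + 1280 + 10 = 101770, which the "%.1fM" format rounds to "0.1M".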

Tested by

no test coverage detected