hub / github.com/apache/tvm / test_adam_simple

Function test_adam_simple

tests/python/relax/test_training_optimizer.py:301–396
()


def test_adam_simple():
    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3,), "float32"))
    adam = Adam(0.01).init([x, y]).get_function()

    @R.function
    def adam_expected(
        params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        optim_states: R.Tuple(
            R.Tensor((), "int64"),
            R.Tensor((), "float32"),
            R.Tensor((), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
        ),
    ) -> R.Tuple(
        R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        R.Tuple(
            R.Tensor((), "int64"),
            R.Tensor((), "float32"),
            R.Tensor((), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
        ),
    ):
        R.func_attr({"global_symbol": "Adam"})
        # block 0
        with R.dataflow():
            num_steps: R.Tensor((), "int64") = optim_states[0]
            num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64"))
            lv: R.Tensor((), "float32") = optim_states[1]
            beta1_prod: R.Tensor((), "float32") = R.multiply(lv, R.const(0.9, "float32"))
            lv1: R.Tensor((), "float32") = optim_states[2]
            beta2_prod: R.Tensor((), "float32") = R.multiply(lv1, R.const(0.999, "float32"))
            x: R.Tensor((3, 3), "float32") = params[0]
            x_grad: R.Tensor((3, 3), "float32") = gradients[0]
            x_m: R.Tensor((3, 3), "float32") = optim_states[3]
            x_v: R.Tensor((3, 3), "float32") = optim_states[5]
            lv2: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.9, "float32"), x_m)
            lv3: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.1, "float32"), x_grad)
            x_m_new: R.Tensor((3, 3), "float32") = R.add(lv2, lv3)
            lv4: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.999, "float32"), x_v)
            lv5: R.Tensor((3, 3), "float32") = R.multiply(x_grad, x_grad)
            lv6: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.001, "float32"), lv5)
            x_v_new: R.Tensor((3, 3), "float32") = R.add(lv4, lv6)
            lv7: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta1_prod)
            x_m_hat: R.Tensor((3, 3), "float32") = R.divide(x_m_new, lv7)
            lv8: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta2_prod)
            x_v_hat: R.Tensor((3, 3), "float32") = R.divide(x_v_new, lv8)
            lv9: R.Tensor((3, 3), "float32") = R.sqrt(x_v_hat)
            lv10: R.Tensor((3, 3), "float32") = R.add(lv9, R.const(1e-08, "float32"))
            lv11: R.Tensor((3, 3), "float32") = R.divide(x_m_hat, lv10)
            lv12: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.01, "float32"), lv11)
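
For readers who want to check the arithmetic, the per-parameter update in the dataflow block above can be sketched in NumPy. This is an illustrative sketch, not TVM code: the function name `adam_step` and the flattened `state` tuple are hypothetical, the hyperparameters are the constants visible in the listing (0.01, 0.9, 0.999, 1e-08), and the final subtraction from the parameter is the standard Adam step, which lies beyond the point where the listing is cut off.

```python
import numpy as np


def adam_step(param, grad, state, lr=0.01, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single parameter (hypothetical helper).

    `state` is an assumed tuple (num_steps, beta1_prod, beta2_prod, m, v),
    a simplified per-parameter view of `optim_states` in the listing.
    """
    num_steps, beta1_prod, beta2_prod, m, v = state

    # Step counter and running products of the betas (num_steps_new, beta*_prod).
    num_steps_new = num_steps + 1
    beta1_prod = beta1_prod * beta1
    beta2_prod = beta2_prod * beta2

    # First and second moment estimates (x_m_new, x_v_new).
    m_new = beta1 * m + (1 - beta1) * grad
    v_new = beta2 * v + (1 - beta2) * grad * grad

    # Bias-corrected moments (x_m_hat, x_v_hat).
    m_hat = m_new / (1 - beta1_prod)
    v_hat = v_new / (1 - beta2_prod)

    # Scaled step (lv12 in the listing).
    step = lr * m_hat / (np.sqrt(v_hat) + eps)

    # Assumption: standard Adam subtracts the step from the parameter;
    # this line is not visible in the truncated listing.
    param_new = param - step
    return param_new, (num_steps_new, beta1_prod, beta2_prod, m_new, v_new)
```

Starting from zero moments and beta products of 1.0 (an assumed initial state), a call such as `adam_step(np.ones((3, 3)), np.ones((3, 3)), (0, 1.0, 1.0, np.zeros((3, 3)), np.zeros((3, 3))))` moves each entry of the parameter by roughly the learning rate on the first step, because the bias correction cancels the small initial moments.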

Callers

nothing calls this directly

Calls 5

Adam (Class) · 0.90
assert_structural_equal (Function) · 0.90
Tensor (Method) · 0.80
get_function (Method) · 0.45
init (Method) · 0.45

Tested by

no test coverage detected
