Repository: github.com/apache/tvm

Function test_adam_complex()

tests/python/relax/test_training_optimizer.py, lines 399–498

```python
def test_adam_complex():
    x = relax.Var("x", R.Tensor((3, 3), "float32"))
    y = relax.Var("y", R.Tensor((3,), "float32"))
    adam = Adam(0.01, (0.8, 0.85), 1e-7, 0.1).init([x, y]).get_function()

    @R.function
    def adam_expected(
        params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        optim_states: R.Tuple(
            R.Tensor((), "int64"),
            R.Tensor((), "float32"),
            R.Tensor((), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
        ),
    ) -> R.Tuple(
        R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")),
        R.Tuple(
            R.Tensor((), "int64"),
            R.Tensor((), "float32"),
            R.Tensor((), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
            R.Tensor((3, 3), "float32"),
            R.Tensor((3,), "float32"),
        ),
    ):
        R.func_attr({"global_symbol": "Adam"})
        # block 0
        with R.dataflow():
            # Step counter and running products of beta1 / beta2 (used for bias correction)
            num_steps: R.Tensor((), "int64") = optim_states[0]
            num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64"))
            lv: R.Tensor((), "float32") = optim_states[1]
            beta1_prod: R.Tensor((), "float32") = R.multiply(lv, R.const(0.8, "float32"))
            lv1: R.Tensor((), "float32") = optim_states[2]
            beta2_prod: R.Tensor((), "float32") = R.multiply(lv1, R.const(0.85, "float32"))
            # Parameter x: its gradient, first moment (m) and second moment (v)
            x: R.Tensor((3, 3), "float32") = params[0]
            x_grad: R.Tensor((3, 3), "float32") = gradients[0]
            x_m: R.Tensor((3, 3), "float32") = optim_states[3]
            x_v: R.Tensor((3, 3), "float32") = optim_states[5]
            # Weight decay folded into the gradient: g = 0.1 * x + x_grad
            lv2: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.1, "float32"), x)
            x_grad_new: R.Tensor((3, 3), "float32") = R.add(lv2, x_grad)
            # First moment: m = 0.8 * m + 0.2 * g
            lv3: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.8, "float32"), x_m)
            lv4: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.2, "float32"), x_grad_new)
            x_m_new: R.Tensor((3, 3), "float32") = R.add(lv3, lv4)
            # Second moment: v = 0.85 * v + 0.15 * g * g
            lv5: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.85, "float32"), x_v)
            lv6: R.Tensor((3, 3), "float32") = R.multiply(x_grad_new, x_grad_new)
            lv7: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.15, "float32"), lv6)
            x_v_new: R.Tensor((3, 3), "float32") = R.add(lv5, lv7)
            # Bias correction: divide each moment by (1 - running beta product)
            lv8: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta1_prod)
            x_m_hat: R.Tensor((3, 3), "float32") = R.divide(x_m_new, lv8)
            lv9: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta2_prod)
            x_v_hat: R.Tensor((3, 3), "float32") = R.divide(x_v_new, lv9)
            # Update denominator: sqrt(v_hat) + eps
            lv10: R.Tensor((3, 3), "float32") = R.sqrt(x_v_hat)
            lv11: R.Tensor((3, 3), "float32") = R.add(lv10, R.const(1e-07, "float32"))
            # ... (the listing stops here; the function continues through line 498)
```
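To check the constants in the expected function by hand, the dataflow block can be mirrored in plain Python. The sketch below is a hypothetical reference (`adam_reference` is an illustrative name, not part of TVM) that follows the listing line by line: weight decay folded into the gradient, moment updates with 0.8/0.2 and 0.85/0.15, bias correction by the running beta products, and the `sqrt(v_hat) + eps` denominator. Three things are assumptions rather than facts from the listing: the final parameter step `p - lr * m_hat / denom` (the listing is cut off before it), the placement of `y`'s moments at indices 4 and 6 (inferred from the `(3,)` shapes in `optim_states`), and an initial state with beta products of 1.0.

```python
import math

# Hyperparameters copied from the test: Adam(0.01, (0.8, 0.85), 1e-7, 0.1).
LR = 0.01
BETA1, BETA2 = 0.8, 0.85
EPS = 1e-7
WEIGHT_DECAY = 0.1


def adam_reference(params, grads, state):
    """Hypothetical pure-Python mirror of the Relax "Adam" function.

    params, grads: one flat list of floats per parameter.
    state: flat tuple (num_steps, beta1_prod, beta2_prod, m_0..m_{n-1}, v_0..v_{n-1}),
    assumed to follow the optim_states layout in the listing (x_m at 3, x_v at 5).
    Returns (new_params, new_state).
    """
    n = len(params)
    num_steps = state[0] + 1              # num_steps_new
    beta1_prod = state[1] * BETA1         # beta1_prod
    beta2_prod = state[2] * BETA2         # beta2_prod
    ms = list(state[3:3 + n])
    vs = list(state[3 + n:3 + 2 * n])

    new_params, new_ms, new_vs = [], [], []
    for p, g, m, v in zip(params, grads, ms, vs):
        p_out, m_out, v_out = [], [], []
        for p_i, g_i, m_i, v_i in zip(p, g, m, v):
            g_i = WEIGHT_DECAY * p_i + g_i                # x_grad_new
            m_i = BETA1 * m_i + (1 - BETA1) * g_i         # x_m_new
            v_i = BETA2 * v_i + (1 - BETA2) * g_i * g_i   # x_v_new
            m_hat = m_i / (1 - beta1_prod)                # x_m_hat
            v_hat = v_i / (1 - beta2_prod)                # x_v_hat
            denom = math.sqrt(v_hat) + EPS                # lv11
            # ASSUMPTION: the lines not shown end with the standard Adam step.
            p_out.append(p_i - LR * m_hat / denom)
            m_out.append(m_i)
            v_out.append(v_i)
        new_params.append(p_out)
        new_ms.append(m_out)
        new_vs.append(v_out)

    new_state = (num_steps, beta1_prod, beta2_prod, *new_ms, *new_vs)
    return new_params, new_state


# Usage with an ASSUMED initial state: zero steps, beta products of 1.0, zero moments.
params = [[1.0], [2.0, -1.0]]
grads = [[0.5], [0.0, 0.0]]
state = (0, 1.0, 1.0, [0.0], [0.0, 0.0], [0.0], [0.0, 0.0])
new_params, new_state = adam_reference(params, grads, state)
print(new_params)
```

On the first step with zero moments, bias correction cancels the (1 - beta) factors, so `m_hat` equals the decayed gradient and `sqrt(v_hat)` equals its magnitude; each element therefore moves by roughly `lr` against the sign of the decayed gradient. This is why the parameters end up near 0.99, 1.99 and -0.99, even where the raw gradient is zero and only the 0.1 weight decay term is active.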

Callers

nothing calls this directly

Calls (5)

Adam (class) · 0.90
assert_structural_equal (function) · 0.90
Tensor (method) · 0.80
get_function (method) · 0.45
init (method) · 0.45

Tested by

no test coverage detected
