()
| 397 | |
| 398 | |
| 399 | def test_adam_complex(): |
| 400 | x = relax.Var("x", R.Tensor((3, 3), "float32")) |
| 401 | y = relax.Var("y", R.Tensor((3,), "float32")) |
| 402 | adam = Adam(0.01, (0.8, 0.85), 1e-7, 0.1).init([x, y]).get_function() |
| 403 | |
| 404 | @R.function |
| 405 | def adam_expected( |
| 406 | params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), |
| 407 | gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), |
| 408 | optim_states: R.Tuple( |
| 409 | R.Tensor((), "int64"), |
| 410 | R.Tensor((), "float32"), |
| 411 | R.Tensor((), "float32"), |
| 412 | R.Tensor((3, 3), "float32"), |
| 413 | R.Tensor((3,), "float32"), |
| 414 | R.Tensor((3, 3), "float32"), |
| 415 | R.Tensor((3,), "float32"), |
| 416 | ), |
| 417 | ) -> R.Tuple( |
| 418 | R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), |
| 419 | R.Tuple( |
| 420 | R.Tensor((), "int64"), |
| 421 | R.Tensor((), "float32"), |
| 422 | R.Tensor((), "float32"), |
| 423 | R.Tensor((3, 3), "float32"), |
| 424 | R.Tensor((3,), "float32"), |
| 425 | R.Tensor((3, 3), "float32"), |
| 426 | R.Tensor((3,), "float32"), |
| 427 | ), |
| 428 | ): |
| 429 | R.func_attr({"global_symbol": "Adam"}) |
| 430 | # block 0 |
| 431 | with R.dataflow(): |
| 432 | num_steps: R.Tensor((), "int64") = optim_states[0] |
| 433 | num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64")) |
| 434 | lv: R.Tensor((), "float32") = optim_states[1] |
| 435 | beta1_prod: R.Tensor((), "float32") = R.multiply(lv, R.const(0.8, "float32")) |
| 436 | lv1: R.Tensor((), "float32") = optim_states[2] |
| 437 | beta2_prod: R.Tensor((), "float32") = R.multiply(lv1, R.const(0.85, "float32")) |
| 438 | x: R.Tensor((3, 3), "float32") = params[0] |
| 439 | x_grad: R.Tensor((3, 3), "float32") = gradients[0] |
| 440 | x_m: R.Tensor((3, 3), "float32") = optim_states[3] |
| 441 | x_v: R.Tensor((3, 3), "float32") = optim_states[5] |
| 442 | lv2: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.1, "float32"), x) |
| 443 | x_grad_new: R.Tensor((3, 3), "float32") = R.add(lv2, x_grad) |
| 444 | lv3: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.8, "float32"), x_m) |
| 445 | lv4: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.2, "float32"), x_grad_new) |
| 446 | x_m_new: R.Tensor((3, 3), "float32") = R.add(lv3, lv4) |
| 447 | lv5: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.85, "float32"), x_v) |
| 448 | lv6: R.Tensor((3, 3), "float32") = R.multiply(x_grad_new, x_grad_new) |
| 449 | lv7: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.15, "float32"), lv6) |
| 450 | x_v_new: R.Tensor((3, 3), "float32") = R.add(lv5, lv7) |
| 451 | lv8: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta1_prod) |
| 452 | x_m_hat: R.Tensor((3, 3), "float32") = R.divide(x_m_new, lv8) |
| 453 | lv9: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta2_prod) |
| 454 | x_v_hat: R.Tensor((3, 3), "float32") = R.divide(x_v_new, lv9) |
| 455 | lv10: R.Tensor((3, 3), "float32") = R.sqrt(x_v_hat) |
| 456 | lv11: R.Tensor((3, 3), "float32") = R.add(lv10, R.const(1e-07, "float32")) |
nothing calls this directly
no test coverage detected
searching dependent graphs…