()
| 299 | |
| 300 | |
| 301 | def test_adam_simple(): |
| 302 | x = relax.Var("x", R.Tensor((3, 3), "float32")) |
| 303 | y = relax.Var("y", R.Tensor((3,), "float32")) |
| 304 | adam = Adam(0.01).init([x, y]).get_function() |
| 305 | |
| 306 | @R.function |
| 307 | def adam_expected( |
| 308 | params: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), |
| 309 | gradients: R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), |
| 310 | optim_states: R.Tuple( |
| 311 | R.Tensor((), "int64"), |
| 312 | R.Tensor((), "float32"), |
| 313 | R.Tensor((), "float32"), |
| 314 | R.Tensor((3, 3), "float32"), |
| 315 | R.Tensor((3,), "float32"), |
| 316 | R.Tensor((3, 3), "float32"), |
| 317 | R.Tensor((3,), "float32"), |
| 318 | ), |
| 319 | ) -> R.Tuple( |
| 320 | R.Tuple(R.Tensor((3, 3), "float32"), R.Tensor((3,), "float32")), |
| 321 | R.Tuple( |
| 322 | R.Tensor((), "int64"), |
| 323 | R.Tensor((), "float32"), |
| 324 | R.Tensor((), "float32"), |
| 325 | R.Tensor((3, 3), "float32"), |
| 326 | R.Tensor((3,), "float32"), |
| 327 | R.Tensor((3, 3), "float32"), |
| 328 | R.Tensor((3,), "float32"), |
| 329 | ), |
| 330 | ): |
| 331 | R.func_attr({"global_symbol": "Adam"}) |
| 332 | # block 0 |
| 333 | with R.dataflow(): |
| 334 | num_steps: R.Tensor((), "int64") = optim_states[0] |
| 335 | num_steps_new: R.Tensor((), "int64") = R.add(num_steps, R.const(1, "int64")) |
| 336 | lv: R.Tensor((), "float32") = optim_states[1] |
| 337 | beta1_prod: R.Tensor((), "float32") = R.multiply(lv, R.const(0.9, "float32")) |
| 338 | lv1: R.Tensor((), "float32") = optim_states[2] |
| 339 | beta2_prod: R.Tensor((), "float32") = R.multiply(lv1, R.const(0.999, "float32")) |
| 340 | x: R.Tensor((3, 3), "float32") = params[0] |
| 341 | x_grad: R.Tensor((3, 3), "float32") = gradients[0] |
| 342 | x_m: R.Tensor((3, 3), "float32") = optim_states[3] |
| 343 | x_v: R.Tensor((3, 3), "float32") = optim_states[5] |
| 344 | lv2: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.9, "float32"), x_m) |
| 345 | lv3: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.1, "float32"), x_grad) |
| 346 | x_m_new: R.Tensor((3, 3), "float32") = R.add(lv2, lv3) |
| 347 | lv4: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.999, "float32"), x_v) |
| 348 | lv5: R.Tensor((3, 3), "float32") = R.multiply(x_grad, x_grad) |
| 349 | lv6: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.001, "float32"), lv5) |
| 350 | x_v_new: R.Tensor((3, 3), "float32") = R.add(lv4, lv6) |
| 351 | lv7: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta1_prod) |
| 352 | x_m_hat: R.Tensor((3, 3), "float32") = R.divide(x_m_new, lv7) |
| 353 | lv8: R.Tensor((), "float32") = R.subtract(R.const(1, "float32"), beta2_prod) |
| 354 | x_v_hat: R.Tensor((3, 3), "float32") = R.divide(x_v_new, lv8) |
| 355 | lv9: R.Tensor((3, 3), "float32") = R.sqrt(x_v_hat) |
| 356 | lv10: R.Tensor((3, 3), "float32") = R.add(lv9, R.const(1e-08, "float32")) |
| 357 | lv11: R.Tensor((3, 3), "float32") = R.divide(x_m_hat, lv10) |
| 358 | lv12: R.Tensor((3, 3), "float32") = R.multiply(R.const(0.01, "float32"), lv11) |
nothing calls this directly
no test coverage detected
searching dependent graphs…