Use `BlockBuilder` to construct an optimizer function that executes updates of the parameters and the optimizer state. `init()` should be called before `get_function()`. Returns: `func` (`Function`) — the optimizer function.
(self)
| 450 | return self |
| 451 | |
| 452 | def get_function(self) -> Function: |
| 453 | """Use blockbuilder to construct an optimizer function that executes updates of the |
| 454 | parameters and the optimizer state. `init()` should be called before `get_function()`. |
| 455 | |
| 456 | Returns |
| 457 | ------- |
| 458 | func : Function |
| 459 | The optimizer function. |
| 460 | """ |
| 461 | self._check_init() |
| 462 | plist = self.param_list |
| 463 | len_param = len(plist) |
| 464 | dtype = self.dtype |
| 465 | |
| 466 | # input variables |
| 467 | param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) |
| 468 | grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) |
| 469 | state_var = Var( |
| 470 | "optim_states", |
| 471 | TupleStructInfo([TensorStructInfo((), "int64"), *(p.struct_info for p in plist)]), |
| 472 | ) |
| 473 | |
| 474 | # constants |
| 475 | lr = const(self.lr, dtype) |
| 476 | momentum = const(self.momentum, dtype) |
| 477 | weight_decay = const(self.weight_decay, dtype) |
| 478 | dampening_inv = const(_high_precision_subtract(1, self.dampening), dtype) |
| 479 | one = const(1, "int64") |
| 480 | |
| 481 | builder = BlockBuilder() |
| 482 | with builder.function(self.name, [param_var, grad_var, state_var]): |
| 483 | with builder.dataflow(): |
| 484 | param_list_new, state_list_new = [], [] |
| 485 | |
| 486 | # handle num_steps |
| 487 | num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") |
| 488 | num_steps_new = builder.emit(add(num_steps, one), "num_steps_new") |
| 489 | state_list_new.append(num_steps_new) |
| 490 | |
| 491 | # computation logics |
| 492 | for i in range(len_param): |
| 493 | name = self.param_list[i].name_hint |
| 494 | p = builder.emit(TupleGetItem(param_var, i), name) |
| 495 | g = builder.emit(TupleGetItem(grad_var, i), name + "_grad") |
| 496 | v = builder.emit(TupleGetItem(state_var, i + 1), name + "_v") |
| 497 | if self.weight_decay: |
| 498 | g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new") |
| 499 | damp_g = multiply(dampening_inv, g) if self.dampening else g |
| 500 | v_new = builder.emit(add(multiply(momentum, v), damp_g), name + "_v_new") |
| 501 | g_new = ( |
| 502 | builder.emit(add(g, multiply(momentum, v_new)), name + "_g_nest") |
| 503 | if self.nesterov |
| 504 | else v_new |
| 505 | ) |
| 506 | p_new = builder.emit(subtract(p, multiply(lr, g_new)), name + "_new") |
| 507 | param_list_new.append(p_new) |
| 508 | state_list_new.append(v_new) |
| 509 |
nothing calls this directly
no test coverage detected