Use `BlockBuilder` to construct an optimizer function that executes updates of the parameters and the optimizer state. `init()` should be called before `get_function()`. Returns: `func` (`Function`), the optimizer function.
(self)
| 630 | return self |
| 631 | |
| 632 | def get_function(self) -> Function: |
| 633 | """Use blockbuilder to construct an optimizer function that executes updates of the |
| 634 | parameters and the optimizer state. `init()` should be called before `get_function()`. |
| 635 | |
| 636 | Returns |
| 637 | ------- |
| 638 | func : Function |
| 639 | The optimizer function. |
| 640 | """ |
| 641 | self._check_init() |
| 642 | plist = self.param_list |
| 643 | len_param = len(plist) |
| 644 | dtype = self.dtype |
| 645 | |
| 646 | # input variables |
| 647 | param_var = Var("params", TupleStructInfo([p.struct_info for p in plist])) |
| 648 | grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist])) |
| 649 | state_var = Var( |
| 650 | "optim_states", |
| 651 | TupleStructInfo( |
| 652 | [ |
| 653 | TensorStructInfo((), "int64"), |
| 654 | TensorStructInfo((), dtype), |
| 655 | TensorStructInfo((), dtype), |
| 656 | *(p.struct_info for p in plist), |
| 657 | *(p.struct_info for p in plist), |
| 658 | ] |
| 659 | ), |
| 660 | ) |
| 661 | |
| 662 | # constants |
| 663 | lr = const(self.lr, dtype) |
| 664 | beta1 = const(self.beta1, dtype) |
| 665 | beta2 = const(self.beta2, dtype) |
| 666 | beta1_inv = const(_high_precision_subtract(1, self.beta1), dtype) |
| 667 | beta2_inv = const(_high_precision_subtract(1, self.beta2), dtype) |
| 668 | eps = const(self.eps, dtype) |
| 669 | weight_decay = const(self.weight_decay, dtype) |
| 670 | one_int = const(1, "int64") |
| 671 | one_float = const(1, dtype) |
| 672 | |
| 673 | builder = BlockBuilder() |
| 674 | with builder.function(self.name, [param_var, grad_var, state_var]): |
| 675 | with builder.dataflow(): |
| 676 | param_list_new = [] |
| 677 | state_list_new = [None] * (len_param * 2 + 3) # type: List[Optional[Var]] |
| 678 | |
| 679 | # handle num_steps |
| 680 | num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps") |
| 681 | num_steps_new = builder.emit(add(num_steps, one_int), "num_steps_new") |
| 682 | state_list_new[0] = num_steps_new |
| 683 | beta1_prod = builder.emit(multiply(TupleGetItem(state_var, 1), beta1), "beta1_prod") |
| 684 | beta2_prod = builder.emit(multiply(TupleGetItem(state_var, 2), beta2), "beta2_prod") |
| 685 | state_list_new[1] = beta1_prod |
| 686 | state_list_new[2] = beta2_prod |
| 687 | |
| 688 | # computation logics |
| 689 | for i in range(len_param): |
nothing calls this directly
no test coverage detected