Repository: github.com/apache/tvm

Method get_function

python/tvm/relax/training/optimizer.py:632–722

Uses a `BlockBuilder` to construct an optimizer function that executes updates of the parameters and the optimizer state. `init()` should be called before `get_function()`.

Returns
-------
func : Function
    The optimizer function.
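The ordering rule in the docstring appears to be enforced by the `self._check_init()` call at the top of the method. The pure-Python sketch below illustrates that contract with a hypothetical `SketchOptimizer` class. It is not TVM's class: parameters are plain strings, the `RuntimeError` is an assumed error type, and `get_function()` returns a description string rather than a Relax `Function`. The state count `2 * n + 3` is taken from the listing's `optim_states` layout.

```python
from typing import List, Optional


class SketchOptimizer:
    """Hypothetical stand-in for the init()-before-get_function() contract."""

    def __init__(self, name: str = "optimizer") -> None:
        self.name = name
        self.param_list: Optional[List[str]] = None

    def init(self, params: List[str]) -> "SketchOptimizer":
        self.param_list = list(params)
        return self  # assumed: returning self lets calls be chained

    def _check_init(self) -> None:
        # Assumed behavior: refuse to build anything before init() ran.
        if self.param_list is None:
            raise RuntimeError("init() must be called before get_function()")

    def get_function(self) -> str:
        self._check_init()
        n = len(self.param_list)
        # Mirrors the listing's inputs: (params, gradients, optim_states),
        # where optim_states holds 3 scalars plus two tensors per parameter.
        return f"{self.name}(params[{n}], gradients[{n}], optim_states[{2 * n + 3}])"
```

Calling `SketchOptimizer().get_function()` raises, while `SketchOptimizer().init(["w", "b"]).get_function()` succeeds and reports seven state entries for two parameters.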

Signature: get_function(self) -> Function


        return self

    def get_function(self) -> Function:
        """Use blockbuilder to construct an optimizer function that executes updates of the
        parameters and the optimizer state. `init()` should be called before `get_function()`.

        Returns
        -------
        func : Function
            The optimizer function.
        """
        self._check_init()
        plist = self.param_list
        len_param = len(plist)
        dtype = self.dtype

        # input variables
        param_var = Var("params", TupleStructInfo([p.struct_info for p in plist]))
        grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist]))
        state_var = Var(
            "optim_states",
            TupleStructInfo(
                [
                    TensorStructInfo((), "int64"),
                    TensorStructInfo((), dtype),
                    TensorStructInfo((), dtype),
                    *(p.struct_info for p in plist),
                    *(p.struct_info for p in plist),
                ]
            ),
        )

        # constants
        lr = const(self.lr, dtype)
        beta1 = const(self.beta1, dtype)
        beta2 = const(self.beta2, dtype)
        beta1_inv = const(_high_precision_subtract(1, self.beta1), dtype)
        beta2_inv = const(_high_precision_subtract(1, self.beta2), dtype)
        eps = const(self.eps, dtype)
        weight_decay = const(self.weight_decay, dtype)
        one_int = const(1, "int64")
        one_float = const(1, dtype)

        builder = BlockBuilder()
        with builder.function(self.name, [param_var, grad_var, state_var]):
            with builder.dataflow():
                param_list_new = []
                state_list_new = [None] * (len_param * 2 + 3)  # type: List[Optional[Var]]

                # handle num_steps
                num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps")
                num_steps_new = builder.emit(add(num_steps, one_int), "num_steps_new")
                state_list_new[0] = num_steps_new
                beta1_prod = builder.emit(multiply(TupleGetItem(state_var, 1), beta1), "beta1_prod")
                beta2_prod = builder.emit(multiply(TupleGetItem(state_var, 2), beta2), "beta2_prod")
                state_list_new[1] = beta1_prod
                state_list_new[2] = beta2_prod

                # computation logics
                for i in range(len_param):
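The constants `beta1_inv` and `beta2_inv` hold `1 - beta1` and `1 - beta2`, computed with `_high_precision_subtract`, whose body is not part of this listing. A plausible motivation, assumed here rather than confirmed, is avoiding binary floating-point rounding in that subtraction. The hypothetical `high_precision_subtract` below sketches one such approach using Python's `decimal` module; TVM's helper may be implemented differently.

```python
from decimal import Decimal


def high_precision_subtract(lhs: float, rhs: float) -> float:
    """Hypothetical stand-in for `_high_precision_subtract`.

    Subtracts the shortest decimal string forms of both operands, so that
    1 - 0.9 yields exactly the float 0.1 instead of a rounded neighbor.
    """
    return float(Decimal(str(lhs)) - Decimal(str(rhs)))
```

With plain floats, `1 - 0.9` evaluates to `0.09999999999999998`, whereas `high_precision_subtract(1, 0.9)` returns `0.1`, the value a user writing `beta1=0.9` would expect for `1 - beta1`.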

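The listing stops at the start of the per-parameter loop. The state it sets up (an `int64` step counter, running products of `beta1` and `beta2`, then one tensor per parameter in each of two groups) matches bias-corrected Adam, so the scalar sketch below shows one step under that reading. Its bookkeeping mirrors the `num_steps`, `beta1_prod`, and `beta2_prod` handling above. Everything else is an assumption, not TVM's truncated loop: the initial state `(0, 1.0, 1.0, zeros, zeros)`, the default hyperparameters, the textbook Adam update, and weight decay folded into the gradient as an L2 term (the convention of PyTorch's `torch.optim.Adam`).

```python
from typing import List, Sequence, Tuple

# Scalar stand-in for optim_states:
# (num_steps, beta1_prod, beta2_prod, m_0..m_{n-1}, v_0..v_{n-1})
State = Tuple[float, ...]


def init_state(num_params: int) -> State:
    """Assumed initial state: step 0, beta products of 1, zero moments."""
    return (0, 1.0, 1.0) + (0.0,) * num_params + (0.0,) * num_params


def adam_step(
    params: Sequence[float],
    grads: Sequence[float],
    state: State,
    lr: float = 0.001,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    weight_decay: float = 0.0,
) -> Tuple[List[float], State]:
    n = len(params)
    assert len(grads) == n and len(state) == 2 * n + 3

    # Bookkeeping, as in the listing: increment the step counter and
    # multiply the running products by beta1 and beta2.
    num_steps_new = state[0] + 1
    beta1_prod = state[1] * beta1
    beta2_prod = state[2] * beta2

    new_params: List[float] = []
    new_m: List[float] = []
    new_v: List[float] = []
    for i in range(n):
        # Assumption: L2-style weight decay added to the gradient.
        g = grads[i] + weight_decay * params[i]
        m = beta1 * state[3 + i] + (1 - beta1) * g
        v = beta2 * state[3 + n + i] + (1 - beta2) * g * g
        # Bias correction uses the running products beta^t.
        m_hat = m / (1 - beta1_prod)
        v_hat = v / (1 - beta2_prod)
        new_params.append(params[i] - lr * m_hat / (v_hat ** 0.5 + eps))
        new_m.append(m)
        new_v.append(v)

    new_state = (num_steps_new, beta1_prod, beta2_prod, *new_m, *new_v)
    return new_params, new_state
```

Keeping `beta1_prod` and `beta2_prod` in the state, rather than recomputing `beta ** num_steps`, lets each step update the bias-correction terms with a single multiply. On the first step, bias correction makes each parameter move by roughly `lr` against the sign of its gradient.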
Callers

nothing calls this directly

Calls 15

function · Method · 0.95
dataflow · Method · 0.95
emit · Method · 0.95
emit_output · Method · 0.95
emit_func_output · Method · 0.95
get · Method · 0.95
TupleStructInfo · Class · 0.85
TensorStructInfo · Class · 0.85
_high_precision_subtract · Function · 0.85
BlockBuilder · Class · 0.85
TupleGetItem · Class · 0.85
_check_init · Method · 0.80

Tested by

no test coverage detected