hub / github.com/apache/tvm / get_function

Method get_function

python/tvm/relax/training/optimizer.py:452–514

Use blockbuilder to construct an optimizer function that executes updates of the parameters and the optimizer state. `init()` should be called before `get_function()`.

Returns
-------
func : Function
    The optimizer function.

(self)
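Going by the source listing on this page, the generated function takes three tuples: `params`, `gradients`, and `optim_states`. The state tuple holds a scalar `int64` step counter in slot 0, followed by one tensor per parameter in slot `i + 1`. The plain-Python sketch below only illustrates that layout; `split_states` and `pack_states` are hypothetical helpers, not part of TVM, and the zero initial values are an assumption.

```python
def split_states(optim_states):
    """Split the flat state tuple into (num_steps, per-parameter slots).

    Mirrors the layout in the listing: slot 0 is the step counter,
    slot i + 1 belongs to parameter i.
    """
    num_steps, *slots = optim_states
    return num_steps, slots


def pack_states(num_steps, slots):
    """Rebuild the flat state tuple in the same order."""
    return (num_steps, *slots)


# Two scalar "parameters" stand in for tensors here.
params = (1.0, 2.0)
# Assumed initial state: counter at 0, one zero slot per parameter.
optim_states = (0, 0.0, 0.0)

num_steps, slots = split_states(optim_states)
# The counter is incremented once per call; slots keep their order.
new_states = pack_states(num_steps + 1, slots)
```

The state tuple is therefore always one element longer than the parameter tuple.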

Source from the content-addressed store, hash-verified

        return self

    def get_function(self) -> Function:
        """Use blockbuilder to construct an optimizer function that executes updates of the
        parameters and the optimizer state. `init()` should be called before `get_function()`.

        Returns
        -------
        func : Function
            The optimizer function.
        """
        self._check_init()
        plist = self.param_list
        len_param = len(plist)
        dtype = self.dtype

        # input variables
        param_var = Var("params", TupleStructInfo([p.struct_info for p in plist]))
        grad_var = Var("gradients", TupleStructInfo([p.struct_info for p in plist]))
        state_var = Var(
            "optim_states",
            TupleStructInfo([TensorStructInfo((), "int64"), *(p.struct_info for p in plist)]),
        )

        # constants
        lr = const(self.lr, dtype)
        momentum = const(self.momentum, dtype)
        weight_decay = const(self.weight_decay, dtype)
        dampening_inv = const(_high_precision_subtract(1, self.dampening), dtype)
        one = const(1, "int64")

        builder = BlockBuilder()
        with builder.function(self.name, [param_var, grad_var, state_var]):
            with builder.dataflow():
                param_list_new, state_list_new = [], []

                # handle num_steps
                num_steps = builder.emit(TupleGetItem(state_var, 0), "num_steps")
                num_steps_new = builder.emit(add(num_steps, one), "num_steps_new")
                state_list_new.append(num_steps_new)

                # computation logics
                for i in range(len_param):
                    name = self.param_list[i].name_hint
                    p = builder.emit(TupleGetItem(param_var, i), name)
                    g = builder.emit(TupleGetItem(grad_var, i), name + "_grad")
                    v = builder.emit(TupleGetItem(state_var, i + 1), name + "_v")
                    if self.weight_decay:
                        g = builder.emit(add(multiply(weight_decay, p), g), name + "_grad_new")
                    damp_g = multiply(dampening_inv, g) if self.dampening else g
                    v_new = builder.emit(add(multiply(momentum, v), damp_g), name + "_v_new")
                    g_new = (
                        builder.emit(add(g, multiply(momentum, v_new)), name + "_g_nest")
                        if self.nesterov
                        else v_new
                    )
                    p_new = builder.emit(subtract(p, multiply(lr, g_new)), name + "_new")
                    param_list_new.append(p_new)
                    state_list_new.append(v_new)
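The per-parameter arithmetic emitted inside the loop can be read as an ordinary momentum SGD update. The following is a minimal plain-Python sketch of that arithmetic on scalars, not the Relax code itself. The function name `momentum_sgd_update` is hypothetical, and `1 - dampening` is computed directly here, whereas the listing goes through `_high_precision_subtract`.

```python
def momentum_sgd_update(p, g, v, lr, momentum=0.0, weight_decay=0.0,
                        dampening=0.0, nesterov=False):
    """Return (p_new, v_new) for one parameter, following the loop body above."""
    # Optional L2-style weight decay folded into the gradient.
    if weight_decay:
        g = weight_decay * p + g
    # Dampening scales the gradient's contribution to the velocity.
    damp_g = (1 - dampening) * g if dampening else g
    # Velocity update.
    v_new = momentum * v + damp_g
    # The Nesterov variant looks ahead along the new velocity.
    g_new = g + momentum * v_new if nesterov else v_new
    # Parameter step.
    p_new = p - lr * g_new
    return p_new, v_new


# One step from p = 1.0 with gradient 0.5 and zero initial velocity.
p_new, v_new = momentum_sgd_update(1.0, 0.5, 0.0, lr=0.1, momentum=0.9)
```

As in the listing, the weight-decayed gradient feeds both the velocity and the Nesterov look-ahead term, and the returned `v_new` is what gets stored back into the state slot for that parameter.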

Callers

nothing calls this directly

Calls 15

function · Method · 0.95
dataflow · Method · 0.95
emit · Method · 0.95
emit_output · Method · 0.95
emit_func_output · Method · 0.95
get · Method · 0.95
TupleStructInfo · Class · 0.85
TensorStructInfo · Class · 0.85
_high_precision_subtract · Function · 0.85
BlockBuilder · Class · 0.85
TupleGetItem · Class · 0.85
_check_init · Method · 0.80

Tested by

no test coverage detected