MCPcopy Index your code
hub / github.com/tensorflow/models / __init__

Method __init__

official/modeling/multitask/multitask.py:36–68  ·  view source on GitHub ↗

MultiTask initialization. Args: tasks: a list or a flat dict of Task. task_weights: a dict of (task, task weight); the task weight can be applied directly during loss summation in a joint backward step, or it can be used to sample tasks among interleaved backward steps.

(self,
               tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
               task_weights: Optional[Dict[str, Union[float, int]]] = None,
               task_eval_steps: Optional[Dict[str, int]] = None,
               name: Optional[str] = None)

Source from the content-addressed store, hash-verified

34 """A multi-task class to manage multiple tasks."""
35
def __init__(self,
             tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
             task_weights: Optional[Dict[str, Union[float, int]]] = None,
             task_eval_steps: Optional[Dict[str, int]] = None,
             name: Optional[str] = None):
  """MultiTask initialization.

  Args:
    tasks: a list or a flat dict of Task. A list is converted into a dict
      keyed by each task's `name`; a dict is stored as-is (it is not copied).
    task_weights: a dict of (task, task weight). The task weight can be
      applied directly during loss summation in a joint backward step, or it
      can be used to sample tasks among interleaved backward steps. Any task
      without an entry here gets a weight of 1.0.
    task_eval_steps: a dict of (task, eval steps). Falls back to an empty
      dict when omitted or empty.
    name: the instance name of a MultiTask object.

  Raises:
    ValueError: if `tasks` is a list holding two tasks with the same name, or
      if `tasks` is neither a list nor a dict.
  """
  super().__init__(name=name)
  if isinstance(tasks, list):
    self._tasks = {}
    for task in tasks:
      # Names become the dict keys, so a repeated name would silently
      # overwrite an earlier task; fail loudly instead.
      if task.name in self._tasks:
        raise ValueError("Duplicated tasks found, task.name is %s" %
                         task.name)
      self._tasks[task.name] = task
  elif isinstance(tasks, dict):
    self._tasks = tasks
  else:
    raise ValueError("The tasks argument has an invalid type: %s" %
                     type(tasks))
  self.task_eval_steps = task_eval_steps or {}
  # Build the weight table from the keys of `self.tasks` (defined elsewhere in
  # the class; presumably the mapping stored above -- confirm). Only those
  # keys are kept, and a missing weight defaults to 1.0.
  provided_weights = task_weights or {}
  self._task_weights = {
      task_name: provided_weights.get(task_name, 1.0)
      for task_name in self.tasks
  }
69
70 @classmethod
71 def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):

Callers

nothing calls this directly

Calls 1

getMethod · 0.45

Tested by

no test coverage detected