MultiTask initialization. Args: tasks: a list or a flat dict of Task. task_weights: a dict of (task, task weight); a task weight can be applied directly during loss summation in a joint backward step, or it can be used to sample tasks across interleaved backward steps.
(self,
tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
task_weights: Optional[Dict[str, Union[float, int]]] = None,
task_eval_steps: Optional[Dict[str, int]] = None,
name: Optional[str] = None)
| 34 | """A multi-task class to manage multiple tasks.""" |
| 35 | |
def __init__(self,
             tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
             task_weights: Optional[Dict[str, Union[float, int]]] = None,
             task_eval_steps: Optional[Dict[str, int]] = None,
             name: Optional[str] = None):
  """MultiTask initialization.

  Args:
    tasks: a list or a flat dict of Task. A list is converted to a dict keyed
      by each task's `name`; a dict is stored as given (not copied).
    task_weights: a dict of (task name, task weight). A task weight can be
      applied directly during loss summation in a joint backward step, or it
      can be used to sample tasks across interleaved backward steps. Tasks
      missing from this dict get a weight of 1.0; keys that do not name one of
      the tasks are ignored.
    task_eval_steps: a dict of (task name, eval steps). Defaults to an empty
      dict when None (or empty) is passed.
    name: the instance name of a MultiTask object, forwarded to the parent
      class constructor.

  Raises:
    ValueError: if `tasks` is a list containing two tasks with the same name,
      or if `tasks` is neither a list nor a dict.
  """
  super().__init__(name=name)
  if isinstance(tasks, list):
    self._tasks = {}
    for task in tasks:
      if task.name in self._tasks:
        raise ValueError("Duplicated tasks found, task.name is %s" %
                         task.name)
      self._tasks[task.name] = task
  elif isinstance(tasks, dict):
    self._tasks = tasks
  else:
    raise ValueError("The tasks argument has an invalid type: %s" %
                     type(tasks))
  self.task_eval_steps = task_eval_steps or {}
  # Give every task an explicit weight, defaulting to 1.0. Iteration goes over
  # `self.tasks` (defined outside this view; presumably exposes `self._tasks`
  # — TODO confirm), so weights for unknown task names are dropped here.
  # `task_name` is used instead of `name` to avoid shadowing the argument.
  weights = task_weights or {}
  self._task_weights = {
      task_name: weights.get(task_name, 1.0) for task_name in self.tasks
  }
| 69 | |
| 70 | @classmethod |
| 71 | def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None): |