Sample all tasks uniformly.
| 48 | |
| 49 | |
class UniformTaskSampler(TaskSampler):
    """Sample all tasks uniformly.

    Every task receives the same probability ``1 / num_tasks``. Only the number
    of entries in ``task_weights`` is used; the weight values are ignored.
    """

    def __init__(self, task_weights: Dict[Text, Union[float, int]]):
        """Builds the fixed uniform cumulative distribution.

        Args:
          task_weights: Mapping from task name to weight. Only its length is
            used by this sampler; the weight values themselves are ignored.

        Raises:
          ValueError: If ``task_weights`` is empty, since a uniform
            distribution over zero tasks is undefined.
        """
        super(UniformTaskSampler, self).__init__(task_weights=task_weights)
        num_tasks = len(self._task_weights)
        if num_tasks == 0:
            raise ValueError("UniformTaskSampler requires at least one task.")
        # Compute each cumulative value directly as k / num_tasks rather than
        # summing num_tasks copies of 1 / num_tasks. Repeated float32 addition
        # accumulates rounding error (sequentially summing ten 0.1f values
        # gives 1.0000001, not 1.0). The direct form avoids that drift and
        # makes the final entry exactly 1.0.
        self._uniform_cumulative = tf.constant(
            [k / num_tasks for k in range(1, num_tasks + 1)],
            dtype=tf.float32)

    def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
        """Returns the cumulative task distribution.

        Args:
          global_step: Ignored; the uniform distribution does not depend on
            the training step.

        Returns:
          A 1-D float32 tensor of length ``num_tasks`` whose k-th entry
          (1-based) is ``k / num_tasks``; the final entry is 1.0.
        """
        del global_step  # Unused: the distribution is step-independent.
        return self._uniform_cumulative
| 63 | |
| 64 | |
| 65 | class ProportionalTaskSampler(TaskSampler): |