Sample tasks proportional to task weights.
| 63 | |
| 64 | |
class ProportionalTaskSampler(TaskSampler):
  """Samples tasks with probability proportional to ``weight ** alpha``.

  The cumulative distribution is built once in the constructor and does not
  depend on the training step.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    """Precomputes the cumulative task distribution.

    Args:
      task_weights: Mapping from task name to its sampling weight.
      alpha: Exponent applied to every weight before normalization; with the
        default of 1.0 the probabilities are the normalized weights themselves.
    """
    super().__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    # Weights are taken in the iteration order of self._task_weights
    # (presumably stored by TaskSampler.__init__ — confirm), so entry i of the
    # cumulative distribution refers to the i-th task in that mapping.
    raw_weights = tf.constant(
        list(self._task_weights.values()), dtype=tf.float32)
    scaled_weights = tf.math.pow(raw_weights, self._alpha)
    probabilities = scaled_weights / tf.reduce_sum(scaled_weights)
    # NOTE(review): attribute name is misspelled ("porportional"); kept as-is
    # because other code may read it.
    self._porportional_cumulative = tf.math.cumsum(probabilities)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the precomputed cumulative distribution.

    Args:
      global_step: Ignored; this sampler's distribution is step-independent.

    Returns:
      The cumulative sum of the normalized, alpha-scaled task weights.
    """
    del global_step  # Unused: the distribution never changes during training.
    return self._porportional_cumulative
| 82 | |
| 83 | |
| 84 | class AnnealingTaskSampler(TaskSampler): |