MCPcopy Index your code
hub / github.com/tensorflow/models / ProportionalTaskSampler

Class ProportionalTaskSampler

official/modeling/multitask/task_sampler.py:65–81  ·  view source on GitHub ↗

Sample tasks proportional to task weights.

Sourced from the content-addressed store (hash-verified).

63
64
class ProportionalTaskSampler(TaskSampler):
  """Sample tasks proportional to task weights.

  Each task is sampled with probability ``weight ** alpha`` divided by the
  sum of ``weight ** alpha`` over all tasks. With ``alpha == 1`` this is plain
  proportional sampling. For positive weights, ``0 < alpha < 1`` flattens the
  distribution toward uniform (``alpha == 0`` is exactly uniform), and
  ``alpha > 1`` sharpens it toward the heaviest tasks.

  The distribution does not depend on the training step. It is computed once
  in the constructor and returned unchanged by
  ``task_cumulative_distribution``.

  Args:
    task_weights: Mapping from task name to weight. Weights are read in the
      iteration order of the mapping stored by the base class, and entry ``i``
      of the cumulative distribution corresponds to the ``i``-th task in that
      order. NOTE(review): weights are presumably expected to be positive; an
      all-zero set makes the normalization 0 / 0 and yields NaNs.
    alpha: Exponent applied to every weight before normalization.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    super().__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    task_weight_dict_ordered_list = tf.constant(
        list(self._task_weights.values()), dtype=tf.float32)
    task_sizes = tf.math.pow(task_weight_dict_ordered_list, self._alpha)
    # Normalize the running sum by its own final element, rather than taking
    # the cumsum of an already-normalized vector. x / x is exactly 1.0 in
    # IEEE-754, so the last CDF entry is guaranteed to be 1.0 rather than
    # something like 0.99999994. Otherwise a caller that draws u ~ U[0, 1) and
    # looks for the bucket containing u could find none. Dividing by a
    # positive constant also keeps the sequence non-decreasing.
    cumulative_sizes = tf.math.cumsum(task_sizes)
    # The attribute's (misspelled) name is kept unchanged in case other code
    # references it.
    self._porportional_cumulative = cumulative_sizes / cumulative_sizes[-1]

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Returns the fixed cumulative distribution over tasks.

    Args:
      global_step: Ignored; the proportional distribution is the same at
        every training step.

    Returns:
      A 1-D float32 tensor with one entry per task, in the order of the task
      weights. For non-negative weights the entries are non-decreasing, and
      the last entry is exactly 1.0.
    """
    del global_step  # Unused: the distribution is step-independent.
    return self._porportional_cumulative
82
83
84class AnnealingTaskSampler(TaskSampler):

Callers 1

get_task_sampler (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected