Repository: github.com/tensorflow/models

Class UniformTaskSampler

official/modeling/multitask/task_sampler.py, lines 50–62

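Samples all tasks uniformly: each of the N tasks in task_weights is drawn with probability 1/N, regardless of the weight values or the training step.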
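To make the constructor's arithmetic concrete, here is a minimal pure-Python sketch (no TensorFlow) that mirrors what UniformTaskSampler computes. The function name uniform_cumulative is hypothetical. The sketch assumes, as the source code suggests, that only the number of entries in task_weights matters and their values are never read.

```python
from itertools import accumulate
from typing import Dict, List, Union


def uniform_cumulative(task_weights: Dict[str, Union[float, int]]) -> List[float]:
    """Hypothetical pure-Python mirror of UniformTaskSampler's distribution.

    Only len(task_weights) is used; the weight values are ignored.
    """
    n = len(task_weights)
    # Equal probability 1/n per task, then a running sum: [1/n, 2/n, ..., ~1.0].
    return list(accumulate([1.0 / n] * n))


# Different weights, same number of tasks -> the same distribution.
print(uniform_cumulative({"a": 1, "b": 1, "c": 1, "d": 1}))
print(uniform_cumulative({"a": 10, "b": 0.5, "c": 3, "d": 7}))
# both print [0.25, 0.5, 0.75, 1.0]
```

Because the weights are discarded, two samplers built from dictionaries with the same number of keys behave identically.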


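```python
from typing import Dict, Text, Union

import tensorflow as tf

# TaskSampler is the base class defined earlier in task_sampler.py.


class UniformTaskSampler(TaskSampler):
  """Sample all tasks uniformly."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    super(UniformTaskSampler, self).__init__(task_weights=task_weights)
    # Equal probability 1/N for each of the N tasks; cumsum turns it into the
    # cumulative distribution [1/N, 2/N, ...] ending at about 1.0.
    # Only the number of tasks is used, not the weight values.
    self._uniform_cumulative = tf.math.cumsum(
        tf.constant(
            [1.0 / len(self._task_weights)] * len(self._task_weights),
            dtype=tf.float32))

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    # The distribution is fixed, so the training step is ignored.
    del global_step
    return self._uniform_cumulative
```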
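The class only returns a cumulative distribution; how a trainer turns it into a task choice is not shown on this page. One common approach, assumed here rather than taken from the Model Garden source, is inverse-CDF sampling: draw u uniformly from [0, 1) and pick the first task whose cumulative value exceeds u. The helper sample_task_index and the task names below are hypothetical.

```python
import bisect
import random
from collections import Counter
from itertools import accumulate
from typing import List


def sample_task_index(cumulative: List[float], u: float) -> int:
    """Inverse-CDF lookup: index of the first cumulative entry greater than u."""
    i = bisect.bisect_right(cumulative, u)
    # Guard against float rounding leaving the last entry slightly below 1.0.
    return min(i, len(cumulative) - 1)


task_names = ["task_a", "task_b", "task_c", "task_d"]  # hypothetical task names
cumulative = list(accumulate([1.0 / len(task_names)] * len(task_names)))

rng = random.Random(0)
counts = Counter(
    task_names[sample_task_index(cumulative, rng.random())] for _ in range(40_000)
)
print(counts)  # each task should land near 10,000 draws
```

The clamp matters for the TensorFlow version in particular: summing 1/N in float32 can leave the last cumulative entry a hair below 1.0, and without the guard a draw in that gap would index past the last task.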

Callers (1): get_task_sampler (function, 0.85)

Calls: no outgoing calls

Tested by: no test coverage detected