Utility to create a task sampler from a configuration and task weights.
(config: configs.TaskSamplingConfig,
task_weights: Dict[Text, float])
| 111 | |
| 112 | |
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
    """Builds the task sampler selected by `config.type`.

    Args:
        config: Task sampling configuration. Its `type` field picks the
            sampler, and the object returned by `config.get()` supplies the
            sampler-specific settings (`alpha`, or `steps_per_epoch` and
            `total_steps`).
        task_weights: Mapping from text keys to float weights, forwarded
            unchanged to the chosen sampler.

    Returns:
        A `UniformTaskSampler`, `ProportionalTaskSampler` or
        `AnnealingTaskSampler`, depending on `config.type`.

    Raises:
        RuntimeError: If `config.type` is not one of 'uniform',
            'proportional' or 'annealing'.
    """
    # Fetched up front regardless of the type, even though the uniform
    # sampler does not read any setting from it.
    sampler_params = config.get()
    sampler_type = config.type

    if sampler_type == 'uniform':
        return UniformTaskSampler(task_weights=task_weights)

    if sampler_type == 'proportional':
        return ProportionalTaskSampler(
            task_weights=task_weights, alpha=sampler_params.alpha)

    if sampler_type == 'annealing':
        return AnnealingTaskSampler(
            task_weights=task_weights,
            steps_per_epoch=sampler_params.steps_per_epoch,
            total_steps=sampler_params.total_steps)

    raise RuntimeError('Task sampler type not supported')
nothing calls this directly
no test coverage detected