Create a ScheduleSampler from a library of pre-defined samplers.

:param name: the name of the sampler.
:param diffusion: the diffusion object to sample for.
(name, diffusion)
| 6 | |
| 7 | |
def create_named_schedule_sampler(name, diffusion):
    """
    Build a ScheduleSampler chosen by name from the pre-defined samplers.

    Recognised names:
      - "uniform": returns UniformSampler(diffusion).
      - "loss-second-moment": returns LossSecondMomentResampler(diffusion).

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    :return: the sampler constructed for `diffusion`.
    :raises NotImplementedError: if `name` matches none of the names above.
    """
    # Each factory is wrapped in a lambda so a sampler class is only looked
    # up and constructed when its name actually matches.
    known_samplers = (
        ("uniform", lambda: UniformSampler(diffusion)),
        ("loss-second-moment", lambda: LossSecondMomentResampler(diffusion)),
    )
    for sampler_name, build in known_samplers:
        if name == sampler_name:
            return build()
    raise NotImplementedError(f"unknown schedule sampler: {name}")
| 21 | |
| 22 | |
| 23 | class ScheduleSampler(ABC): |
no test coverage detected