(
self, train_dataset: Optional[agl.Dataset[Any]] = None, val_dataset: Optional[agl.Dataset[Any]] = None
)
| 149 | self.concurrency = concurrency |
| 150 | |
async def run(
    self, train_dataset: Optional[agl.Dataset[Any]] = None, val_dataset: Optional[agl.Dataset[Any]] = None
):
    """Dispatch to the algorithm variant selected by ``self.mode``.

    ``train_dataset`` and ``val_dataset`` are accepted but not read by this
    method; the workload is taken from ``self.total_tasks`` and the
    mode-specific settings on the instance.

    Modes:
        ``"batch"``: awaits ``algorithm_batch(total_tasks, batch_size)``.
        ``"batch_partial"``: awaits
            ``algorithm_batch_with_completion_threshold(total_tasks, batch_size, remaining_tasks)``.
        ``"single"``: awaits ``algorithm_batch_single(total_tasks, concurrency)``.

    Raises:
        AssertionError: A setting required by the selected mode is ``None``.
        ValueError: ``self.mode`` is not one of the modes listed above.
    """
    # The required-setting checks are explicit raises rather than `assert`
    # statements: asserts are removed under `python -O`, which would let a
    # None setting flow into the algorithm coroutine. AssertionError is kept
    # as the exception type so existing callers observe the same failure.
    if self.mode == "batch":
        if self.batch_size is None:
            raise AssertionError("batch_size must be set when mode is 'batch'")
        await self.algorithm_batch(self.total_tasks, self.batch_size)
    elif self.mode == "batch_partial":
        if self.batch_size is None:
            raise AssertionError("batch_size must be set when mode is 'batch_partial'")
        if self.remaining_tasks is None:
            raise AssertionError("remaining_tasks must be set when mode is 'batch_partial'")
        await self.algorithm_batch_with_completion_threshold(
            self.total_tasks, self.batch_size, self.remaining_tasks
        )
    elif self.mode == "single":
        if self.concurrency is None:
            raise AssertionError("concurrency must be set when mode is 'single'")
        await self.algorithm_batch_single(self.total_tasks, self.concurrency)
    else:
        raise ValueError(f"Invalid mode: {self.mode}")
| 168 | |
| 169 | async def algorithm_batch(self, total_tasks: int, batch_size: int): |
| 170 | """ |
no test coverage detected