Async helper to set up data and resources on the server.
(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True)
| 498 | pass |
| 499 | |
| 500 | async def _async_set_up(self, data: Dict[str, Any], server_addresses: List[str], is_train: bool = True): |
| 501 | """Async helper to set up data and resources on the server.""" |
| 502 | self.clear_data_and_server() |
| 503 | if server_addresses != self.backend_llm_server_addresses: |
| 504 | self.backend_llm_server_addresses = server_addresses |
| 505 | if self.mode == "v1" and not self.llm_proxy.is_running(): |
| 506 | await self._update_proxy_server_v1() |
| 507 | self.is_train = is_train |
| 508 | |
| 509 | # 1. Update resources on the server for clients to use |
| 510 | if self.mode == "v0": |
| 511 | llm_resource = LLM( |
| 512 | endpoint=f"http://127.0.0.1:{self.proxy_port}/v1", |
| 513 | model=self.train_information.get("model", "default-model"), |
| 514 | sampling_parameters={ |
| 515 | "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0) |
| 516 | }, |
| 517 | ) |
| 518 | else: |
| 519 | llm_resource = self.llm_proxy.as_resource( |
| 520 | sampling_parameters={ |
| 521 | "temperature": self.train_information.get("temperature", 0.7 if is_train else 0.0) |
| 522 | }, |
| 523 | ) |
| 524 | |
| 525 | resources: NamedResources = {"main_llm": llm_resource} |
| 526 | |
| 527 | if self.mode == "v0": |
| 528 | resources_id = await self.server.update_resources(resources) |
| 529 | else: |
| 530 | resources_update = await self.store.add_resources(resources) |
| 531 | resources_id = resources_update.resources_id |
| 532 | |
| 533 | # 2. Queue tasks for agents to process |
| 534 | keys = list(data.keys()) |
| 535 | num_samples = len(data[keys[0]]) |
| 536 | rollouts_per_sample = self.train_rollout_n if is_train else 1 |
| 537 | |
| 538 | enqueue_rollout_requests: List[EnqueueRolloutRequest] = [] |
| 539 | data_id_to_original_sample: Dict[str, Dict[str, Any]] = {} |
| 540 | |
| 541 | for i in range(num_samples): |
| 542 | data_id = str(uuid.uuid4()) |
| 543 | original_sample = {key: data[key][i] for key in keys} |
| 544 | original_sample["data_id"] = data_id |
| 545 | data_id_to_original_sample[data_id] = original_sample |
| 546 | |
| 547 | # For training, each sample is rolled out multiple times |
| 548 | # Data ID is different from Rollout ID, as one data can have multiple rollouts. |
| 549 | for _ in range(rollouts_per_sample): |
| 550 | task_metadata = {"data_id": data_id, "is_train": is_train} |
| 551 | if self.mode == "v0": |
| 552 | # Queue immediately |
| 553 | rollout_id = await self.server.queue_task( |
| 554 | sample=_to_native(original_sample), |
| 555 | mode="train" if is_train else "val", |
| 556 | resources_id=resources_id, |
| 557 | metadata=task_metadata, |
no test coverage detected