| 220 | ) |
| 221 | |
async def apply(self):
    """Yield prompts chosen by the RL model, interleaved with tool messages.

    Runs up to ``max(self.max_prompts, 1)`` rounds. Each round asks
    ``self.rl_model.select_next_prompts`` for the next batch of prompts,
    seeded with the most recently yielded prompt (initially
    ``"What is AI?"``). After every prompt is yielded, all items currently
    queued in ``self.tools_inbox`` are drained: each item's ``"message"`` is
    yielded and the ``set()`` method of its ``"ready"`` entry is called.

    Yields:
        Each selected prompt, followed by any pending tool-inbox messages.

    The generator stops early, after logging an error, as soon as a round
    returns no prompts.
    """
    current_prompt = "What is AI?"
    # NOTE(review): nothing in this method ever changes passed_guard, so the
    # model is always queried with passed_guard=False -- confirm intent.
    passed_guard = False

    for _ in range(max(self.max_prompts, 1)):
        # Run the selection in a worker thread so it does not block the
        # event loop; to_thread forwards the arguments to the callable.
        prompts = await asyncio.to_thread(
            self.rl_model.select_next_prompts,
            current_prompt,
            passed_guard=passed_guard,
        )

        if not prompts:
            logger.error("No prompts retrieved from the API.")
            return

        total = len(prompts)
        logger.info(f"Retrieved {total} prompts.")

        for position, prompt in enumerate(prompts, start=1):
            logger.info(f"Processing prompt {position}/{total}: {prompt}")
            yield prompt
            current_prompt = prompt

            # Drain whatever tool messages were queued while the consumer
            # was handling the prompt above.
            while not self.tools_inbox.empty():
                ref = await self.tools_inbox.get()
                logger.debug(f"Tool inbox item: {ref}")
                message, ready = ref["message"], ref["ready"]
                try:
                    yield message
                finally:
                    # Signal completion even if the consumer closes or
                    # throws into the generator at the yield, so that the
                    # "ready" flag is never left unset for this item.
                    ready.set()