An example of how a prompt optimization algorithm works.
(*, store: agl.LightningStore)
| 40 | |
| 41 | |
async def apo_algorithm(*, store: agl.LightningStore) -> None:
    """Run a toy prompt-optimization loop over a fixed set of prompt templates.

    For every candidate template the algorithm:

    1. publishes the template to the store as the ``"main_prompt"`` resource,
    2. enqueues a single rollout with a fixed question,
    3. polls the store until it reports that rollout,
    4. reads the final reward from the rollout's spans.

    After all candidates have been tried, the template with the highest
    reward is printed.

    Args:
        store: Store used to publish resources, enqueue the rollout, wait for
            it, and query its spans.

    Raises:
        RuntimeError: If the store reports no rollout within the polling
            window, if the reported rollout's status is not ``"succeeded"``,
            or if no final reward is found in the rollout's spans.
    """
    prompt_candidates = [
        "You are a helpful assistant. {any_question}",
        "You are a knowledgeable AI. {any_question}",
        "You are a friendly chatbot. {any_question}",
    ]

    # (template, reward) pairs collected in the order the candidates are tried.
    prompt_and_rewards: list[tuple[str, float]] = []

    # Rich markup prefix that tags every console line printed by the algorithm.
    algo_marker = "[bold red][Algo][/bold red]"

    for prompt in prompt_candidates:
        # 1. The optimization algorithm updates the prompt template
        console.print(f"\n{algo_marker} Updating prompt template to: '{prompt}'")
        resources: agl.NamedResources = {
            # The "main_prompt" can be replaced with any name you like
            # As long as the PromptTemplate type is used, the rollout function will recognize it
            "main_prompt": agl.PromptTemplate(template=prompt, engine="f-string")
        }
        # How the resource is used fully depends on the client implementation.
        await store.add_resources(resources)

        # 2. The algorithm queues up a task from a dataset
        console.print(f"{algo_marker} Queuing task for clients...")
        rollout = await store.enqueue_rollout(
            input="Explain why the sky appears blue using principles of light scattering in 100 words.",
            mode="train",
        )
        console.print(f"{algo_marker} Task '{rollout.rollout_id}' is now available for clients.")

        # 3. The algorithm waits for clients to process the task
        # Poll with a very short store-side timeout and sleep one second
        # between attempts, i.e. wait for at most about 30 seconds.
        for _ in range(30):
            rollouts = await store.wait_for_rollouts(rollout_ids=[rollout.rollout_id], timeout=0.01)
            if rollouts:
                break
            await asyncio.sleep(1.0)
        else:
            # The loop ran out of attempts without ever hitting ``break``.
            raise RuntimeError("Expected a completed rollout from the client, but got none.")

        console.print(f"{algo_marker} Received Result: {rollouts[0]}")
        if rollouts[0].status != "succeeded":
            raise RuntimeError(f"Rollout {rollout.rollout_id} did not succeed. Status: {rollouts[0].status}")
        spans = await store.query_spans(rollout.rollout_id)

        # Logs LLM spans for debugging and inspection here
        await log_llm_span(spans)

        # 4. The algorithm records the final reward for sorting
        final_reward = agl.find_final_reward(spans)
        if final_reward is None:
            # Raise explicitly instead of using ``assert``: asserts are removed
            # under ``python -O``, which would let ``None`` reach ``max`` below.
            raise RuntimeError("Expected a final reward from the client.")
        console.print(f"{algo_marker} Final reward: {final_reward}")
        prompt_and_rewards.append((prompt, final_reward))

    console.print(f"\n{algo_marker} All prompts and their rewards: {prompt_and_rewards}")
    # Pick the (template, reward) pair with the highest reward.
    best_prompt = max(prompt_and_rewards, key=lambda x: x[1])
    console.print(f"{algo_marker} Best prompt found: '{best_prompt[0]}' with reward {best_prompt[1]}")
no test coverage detected