MCPcopy
hub / github.com/microsoft/agent-lightning / apo_algorithm

Function apo_algorithm

examples/apo/apo_custom_algorithm.py:42–99  ·  view source on GitHub ↗

An example of how a prompt optimization algorithm works.

(*, store: agl.LightningStore)

Source from the content-addressed store, hash-verified

40
41
async def apo_algorithm(*, store: agl.LightningStore):
    """Run a minimal prompt-optimization loop against ``store``.

    For each hard-coded candidate prompt template the function:

    1. publishes the template to the store as the ``"main_prompt"`` resource,
    2. enqueues a single ``mode="train"`` rollout with a fixed question,
    3. polls the store (up to 30 attempts, about one second apart) until
       ``wait_for_rollouts`` returns a non-empty result,
    4. queries the rollout's spans, logs them, and records the final reward.

    After all candidates are evaluated, every (prompt, reward) pair and the
    highest-reward prompt are printed to the console. Nothing is returned.

    Args:
        store: The store used to publish resources, enqueue rollouts, wait
            for their results and query their spans.

    Raises:
        RuntimeError: If no rollout result arrives within the polling window,
            if the rollout's status is not ``"succeeded"``, or if no final
            reward can be found in the rollout's spans.
    """
    prompt_candidates = [
        "You are a helpful assistant. {any_question}",
        "You are a knowledgeable AI. {any_question}",
        "You are a friendly chatbot. {any_question}",
    ]

    prompt_and_rewards: list[tuple[str, float]] = []

    algo_marker = "[bold red][Algo][/bold red]"

    for prompt in prompt_candidates:
        # 1. The optimization algorithm updates the prompt template
        console.print(f"\n{algo_marker} Updating prompt template to: '{prompt}'")
        resources: agl.NamedResources = {
            # The "main_prompt" can be replaced with any name you like
            # As long as the PromptTemplate type is used, the rollout function will recognize it
            "main_prompt": agl.PromptTemplate(template=prompt, engine="f-string")
        }
        # How the resource is used fully depends on the client implementation.
        await store.add_resources(resources)

        # 2. The algorithm queues up a task from a dataset
        console.print(f"{algo_marker} Queuing task for clients...")
        rollout = await store.enqueue_rollout(
            input="Explain why the sky appears blue using principles of light scattering in 100 words.", mode="train"
        )
        console.print(f"{algo_marker} Task '{rollout.rollout_id}' is now available for clients.")

        # 3. The algorithm waits for clients to process the task
        for _ in range(30):  # Wait for at most 30 seconds
            # Near-zero timeout: this call is only a quick check; the pacing
            # between attempts comes from the sleep below.
            rollouts = await store.wait_for_rollouts(rollout_ids=[rollout.rollout_id], timeout=0.01)
            if rollouts:
                break
            await asyncio.sleep(1.0)
        else:
            # The loop ran out of attempts without ever hitting `break`.
            raise RuntimeError("Expected a completed rollout from the client, but got none.")

        console.print(f"{algo_marker} Received Result: {rollouts[0]}")
        if rollouts[0].status != "succeeded":
            raise RuntimeError(f"Rollout {rollout.rollout_id} did not succeed. Status: {rollouts[0].status}")
        spans = await store.query_spans(rollout.rollout_id)

        # Logs LLM spans for debugging and inspection here
        await log_llm_span(spans)

        # 4. The algorithm records the final reward for sorting
        final_reward = agl.find_final_reward(spans)
        if final_reward is None:
            # Raise explicitly instead of using `assert`: assertions are removed
            # under `python -O`, which would let a None reward slip into the
            # ranking below.
            raise RuntimeError("Expected a final reward from the client.")
        console.print(f"{algo_marker} Final reward: {final_reward}")
        prompt_and_rewards.append((prompt, final_reward))

    console.print(f"\n{algo_marker} All prompts and their rewards: {prompt_and_rewards}")
    best_prompt = max(prompt_and_rewards, key=lambda x: x[1])
    console.print(f"{algo_marker} Best prompt found: '{best_prompt[0]}' with reward {best_prompt[1]}")

Callers 2

mainFunction · 0.85

Calls 5

log_llm_spanFunction · 0.85
add_resourcesMethod · 0.45
enqueue_rolloutMethod · 0.45
wait_for_rolloutsMethod · 0.45
query_spansMethod · 0.45

Tested by

no test coverage detected