MCPcopy
hub / github.com/policy-gradient/GRPO-Zero / CountdownTasksDataset

Class CountdownTasksDataset

countdown_task.py:24–80  ·  view source on GitHub ↗

Prepare Countdown Tasks for training

Source from the content-addressed store, hash-verified

22
23
class CountdownTasksDataset(Dataset):
    """Prepare Countdown Tasks for training.

    Rows are loaded from a parquet dataset located at ``<data_path>/data``.
    The last ``test_size`` rows form the test split and every row before
    them forms the train split. Each example is returned as the row's own
    fields plus the encoded chat prefix produced by ``encode_prefix``.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        data_path: str,
        split: str = "train",
        test_size: int = 100,
    ):
        """Load the parquet data and keep only the requested split.

        Args:
            tokenizer: Tokenizer used by ``encode_prefix`` to build and
                tokenize the chat prefix.
            data_path: Directory that contains the ``data`` parquet dataset.
            split: ``"train"`` selects the train rows; any other value
                selects the test rows.
            test_size: Number of trailing rows reserved for testing.

        Raises:
            ValueError: If ``test_size`` is negative.
        """
        data = pd.read_parquet(Path(data_path) / "data")
        self.data = self._select_split(data, split, test_size)
        self.tokenizer = tokenizer

    @staticmethod
    def _select_split(data: pd.DataFrame, split: str, test_size: int) -> pd.DataFrame:
        """Return the train or test rows of ``data``.

        The last ``test_size`` rows are the test split. The boundary is an
        explicit row index rather than a negative slice: ``-0 == 0``, so
        ``iloc[:-0]`` is empty and ``iloc[-0:]`` is everything, which would
        swap the two splits when ``test_size`` is 0.
        """
        if test_size < 0:
            raise ValueError(f"test_size must be non-negative, got {test_size}")
        # Clamp so a test_size larger than the dataset yields an empty train
        # split and a test split holding every row (same as negative slicing).
        boundary = max(len(data) - test_size, 0)
        if split == "train":
            return data.iloc[:boundary]
        return data.iloc[boundary:]

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return row ``idx`` as a dict extended with its encoded prefix.

        The row must provide ``"nums"`` and ``"target"``; the keys
        ``"prefix"``, ``"prefix_tokens"`` and ``"prefix_token_ids"`` are
        added from ``encode_prefix``.
        """
        item = self.data.iloc[idx].to_dict()
        item.update(self.encode_prefix(item["nums"], item["target"]))
        return item

    def encode_prefix(self, numbers: List[int], target: int):
        """Prefix is the *actual* input to the model.

        Formats ``USER_TEMPLATE`` with ``numbers`` and ``target``, encodes
        it together with ``SYSTEM_MESSAGE`` and ``RESPONSE_PROMPT`` through
        the tokenizer's chat encoder, then tokenizes the resulting prefix.

        Returns:
            A dict with the prefix (``"prefix"``), its tokens
            (``"prefix_tokens"``) and its token ids (``"prefix_token_ids"``).
        """
        user_message = USER_TEMPLATE.format(numbers=numbers, target=target)
        prefix = self.tokenizer.encode_chat_with_response_prompt(
            [
                {"role": "system", "content": SYSTEM_MESSAGE},
                {"role": "user", "content": user_message},
            ],
            RESPONSE_PROMPT,
        )
        tokens = self.tokenizer.tokenize(prefix)
        return {
            "prefix": prefix,
            "prefix_tokens": tokens.tokens,
            "prefix_token_ids": tokens.ids,
        }

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]]) -> MiniBatch:
        """Collate examples into a batch.

        Each field of the returned ``MiniBatch`` is a list holding that
        field's value for every example, in the order given by ``batch``.
        """
        numbers = [item["nums"] for item in batch]
        target = [item["target"] for item in batch]
        prefix = [item["prefix"] for item in batch]
        prefix_tokens = [item["prefix_tokens"] for item in batch]
        prefix_token_ids = [item["prefix_token_ids"] for item in batch]
        return MiniBatch(
            numbers=numbers,
            target=target,
            prefix=prefix,
            prefix_tokens=prefix_tokens,
            prefix_token_ids=prefix_token_ids,
        )
81

Callers 2

evaluateFunction · 0.90
mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected