Prepare Countdown Tasks for training
| 22 | |
| 23 | |
class CountdownTasksDataset(Dataset):
    """Prepare Countdown Tasks for training.

    The task table is read from the parquet data at ``<data_path>/data``.
    The last ``test_size`` rows form the test split and every row before
    them forms the training split.

    Each example is the table row converted to a dict (the code reads its
    ``nums`` and ``target`` entries) extended with the encoded prompt under
    the keys ``prefix``, ``prefix_tokens`` and ``prefix_token_ids``.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        data_path: str,
        split: str = "train",
        test_size: int = 100,
    ):
        """Load the table and keep only the requested split.

        Args:
            tokenizer: Tokenizer used to build and tokenize the prompt prefix.
            data_path: Directory that contains the ``data`` parquet dataset.
            split: ``"train"`` selects the training rows; any other value
                selects the test rows.
            test_size: Number of trailing rows reserved for testing.

        Raises:
            ValueError: If ``test_size`` is negative.
        """
        data = pd.read_parquet(Path(data_path) / "data")
        self.data = self._select_split(data, split, test_size)
        self.tokenizer = tokenizer

    @staticmethod
    def _select_split(
        data: pd.DataFrame, split: str, test_size: int
    ) -> pd.DataFrame:
        """Return the rows of ``data`` that belong to ``split``."""
        if test_size < 0:
            raise ValueError(f"test_size must be non-negative, got {test_size}")
        # Slice by an explicit cut position instead of a negative offset:
        # `-0 == 0`, so `iloc[:-0]` would yield an empty training split and
        # `iloc[-0:]` would hand the whole table to the test split.
        cut = max(len(data) - test_size, 0)
        # use the last `test_size` examples for testing
        return data.iloc[:cut] if split == "train" else data.iloc[cut:]

    def __len__(self):
        """Return the number of examples in the selected split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the row at position ``idx`` as a dict with its prefix added."""
        item = self.data.iloc[idx].to_dict()
        item.update(self.encode_prefix(item["nums"], item["target"]))
        return item

    def encode_prefix(self, numbers: List[int], target: int):
        """Prefix is the *actual* input to the model.

        Formats the user message from ``numbers`` and ``target``, wraps it
        together with the system message and the response prompt via the
        tokenizer's chat encoding, and tokenizes the result.

        Returns:
            A dict with ``prefix`` (the encoded chat prompt),
            ``prefix_tokens`` and ``prefix_token_ids`` (the ``tokens`` and
            ``ids`` attributes of the tokenizer output).
        """
        user_message = USER_TEMPLATE.format(numbers=numbers, target=target)
        prefix = self.tokenizer.encode_chat_with_response_prompt(
            [
                {"role": "system", "content": SYSTEM_MESSAGE},
                {"role": "user", "content": user_message},
            ],
            RESPONSE_PROMPT,
        )
        tokens = self.tokenizer.tokenize(prefix)
        return {
            "prefix": prefix,
            "prefix_tokens": tokens.tokens,
            "prefix_token_ids": tokens.ids,
        }

    @staticmethod
    def collate_fn(batch: List[Dict[str, Any]]) -> MiniBatch:
        """Collate examples into a batch.

        Each field of the resulting ``MiniBatch`` is a list holding that
        field's value for every example, in batch order.
        """
        numbers = [item["nums"] for item in batch]
        target = [item["target"] for item in batch]
        prefix = [item["prefix"] for item in batch]
        prefix_tokens = [item["prefix_tokens"] for item in batch]
        prefix_token_ids = [item["prefix_token_ids"] for item in batch]
        return MiniBatch(
            numbers=numbers,
            target=target,
            prefix=prefix,
            prefix_tokens=prefix_tokens,
            prefix_token_ids=prefix_token_ids,
        )
| 81 |