| 56 | |
| 57 | |
| 58 | class SFTDataset(Dataset): |
def __init__(self, jsonl_path, tokenizer, max_length=1024):
    """Load SFT conversations from a JSONL file and cache assistant-turn markers.

    Args:
        jsonl_path: JSONL data file(s) forwarded to
            ``load_dataset('json', data_files=...)``.
        tokenizer: tokenizer used to render, tokenize and pad samples, and to
            encode the assistant-turn marker strings below.
        max_length: sequence length that samples are truncated / padded to.
    """
    super().__init__()
    self.tokenizer = tokenizer
    self.max_length = max_length

    # Explicit schema: every conversation is a list of messages whose fields
    # are all declared as strings. tools / tool_calls therefore stay raw text
    # here; create_chat_prompt json-decodes them when they are strings.
    message_fields = {
        'role': Value('string'),
        'content': Value('string'),
        'reasoning_content': Value('string'),
        'tools': Value('string'),
        'tool_calls': Value('string'),
    }
    schema = Features({'conversations': [message_fields]})
    self.samples = load_dataset('json', data_files=jsonl_path, split='train', features=schema)

    # Token-id sequences for the assistant-turn header and the turn terminator;
    # generate_labels scans for them to decide which spans get supervised.
    assistant_header = f'{tokenizer.bos_token}assistant\n'
    self.bos_id = tokenizer(assistant_header, add_special_tokens=False).input_ids
    turn_end = f'{tokenizer.eos_token}\n'
    self.eos_id = tokenizer(turn_end, add_special_tokens=False).input_ids
| 67 | |
def __len__(self):
    """Return how many conversation samples were loaded."""
    sample_count = len(self.samples)
    return sample_count
| 70 | |
def create_chat_prompt(self, conversations):
    """Render a list of messages through the tokenizer's chat template.

    Each message is shallow-copied, so the caller's objects are left untouched.
    Two fields may arrive as JSON text and are decoded here:

    * ``tools`` on a ``system`` message. It becomes the ``tools=`` argument of
      the template. If several system messages carry tools, the last one wins.
    * ``tool_calls`` on any message, when it is a non-empty string.

    Args:
        conversations: iterable of message mappings (``role``, ``content``, ...).

    Returns:
        Whatever ``apply_chat_template`` returns with ``tokenize=False`` and no
        generation prompt. The caller treats this as prompt text.
    """
    parsed_messages = []
    tool_schema = None
    for raw in conversations:
        entry = dict(raw)  # copy: tool_calls may be replaced below
        is_system = entry.get("role") == "system"
        if is_system and entry.get("tools"):
            declared = entry["tools"]
            if isinstance(declared, str):
                tool_schema = json.loads(declared)
            else:
                tool_schema = declared
        calls = entry.get("tool_calls")
        if calls and isinstance(calls, str):
            entry["tool_calls"] = json.loads(calls)
        parsed_messages.append(entry)
    return self.tokenizer.apply_chat_template(
        parsed_messages,
        tokenize=False,
        add_generation_prompt=False,
        tools=tool_schema,
    )
| 87 | |
def generate_labels(self, input_ids):
    """Build per-token labels that supervise only assistant replies.

    A supervised span starts right after each occurrence of ``self.bos_id``
    (the tokenized assistant header). It runs up to and including the next
    ``self.eos_id`` sequence, or to the end of ``input_ids`` if no terminator
    follows. Positions inside a span keep their token id; every other position
    is -100 (presumably the loss's ignore_index; confirm in the training loop).
    Labels are never written at or beyond ``self.max_length``.

    Args:
        input_ids: list of token ids for one sample.

    Returns:
        A list of labels with the same length as ``input_ids``.
    """
    labels = [-100] * len(input_ids)
    bos_len = len(self.bos_id)
    eos_len = len(self.eos_id)
    # Cap at the real sequence length as well as max_length. Capping only at
    # max_length raised IndexError for inputs shorter than max_length whose
    # last reply had no terminator.
    limit = min(len(input_ids), self.max_length)
    # NOTE(review): if the caller pads before labelling and the final reply was
    # truncated (no eos_id), the span runs into the padding. Those pad ids are
    # then labelled too. Confirm this is intended.
    i = 0
    while i < len(input_ids):
        if input_ids[i:i + bos_len] != self.bos_id:
            i += 1
            continue
        start = i + bos_len
        end = start
        while end < len(input_ids) and input_ids[end:end + eos_len] != self.eos_id:
            end += 1
        # Include the terminator tokens themselves in the supervised span.
        for j in range(start, min(end + eos_len, limit)):
            labels[j] = input_ids[j]
        # Resume scanning after the terminator, or stop if none was found.
        i = end + eos_len if end < len(input_ids) else len(input_ids)
    return labels
| 105 | |
| 106 | def __getitem__(self, index): |
| 107 | sample = self.samples[index] |
| 108 | conversations = pre_processing_chat(sample['conversations']) |
| 109 | prompt = self.create_chat_prompt(conversations) |
| 110 | prompt = post_processing_chat(prompt) |
| 111 | input_ids = self.tokenizer(prompt).input_ids[:self.max_length] |
| 112 | input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids)) |
| 113 | labels = self.generate_labels(input_ids) |
| 114 | # # === 调试打印 === |
| 115 | # print(f"\n--- Sample {index} ---") |
no outgoing calls
no test coverage detected