MCPcopy
hub / github.com/jingyaogong/minimind / SFTDataset

Class SFTDataset

dataset/lm_dataset.py:58–119  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

56
57
58class SFTDataset(Dataset):
def __init__(self, jsonl_path, tokenizer, max_length=1024):
    """Load a supervised fine-tuning corpus of chat conversations.

    Args:
        jsonl_path: Forwarded to ``load_dataset('json', data_files=...)``;
            the ``'train'`` split is kept as ``self.samples``.
        tokenizer: Tokenizer exposing ``bos_token`` / ``eos_token``. It must be
            callable (returning an object with ``.input_ids``), and the other
            methods of this class also call its ``apply_chat_template``.
        max_length: Sequence-length budget. It is stored on the instance and
            used when building inputs and labels.
    """
    super().__init__()
    self.tokenizer = tokenizer
    self.max_length = max_length

    # Explicit schema for each chat message. Optional fields (reasoning,
    # tools, tool calls) are declared as strings; records lacking them are
    # presumably filled with None by the loader -- confirm against the data.
    message_fields = {
        'role': Value('string'),
        'content': Value('string'),
        'reasoning_content': Value('string'),
        'tools': Value('string'),
        'tool_calls': Value('string'),
    }
    schema = Features({'conversations': [message_fields]})
    self.samples = load_dataset('json', data_files=jsonl_path, split='train', features=schema)

    # Token-id sequences that open and close an assistant turn in the rendered
    # prompt; generate_labels scans input_ids for these exact sequences.
    assistant_header = f'{tokenizer.bos_token}assistant\n'
    self.bos_id = tokenizer(assistant_header, add_special_tokens=False).input_ids
    turn_terminator = f'{tokenizer.eos_token}\n'
    self.eos_id = tokenizer(turn_terminator, add_special_tokens=False).input_ids
67
68 def __len__(self):
69 return len(self.samples)
70
def create_chat_prompt(self, conversations):
    """Render a conversation into a chat-template string (not tokenized).

    Each message is shallow-copied, so the caller's records are not modified.
    JSON-encoded fields are decoded before rendering:

    * ``tools`` on a ``system`` message becomes the ``tools`` argument of
      ``apply_chat_template``. If several system messages carry tools, the
      last one wins. A non-string value is passed through unchanged.
    * A non-empty string ``tool_calls`` on any message is replaced by its
      decoded JSON value.

    Args:
        conversations: Iterable of message mappings (keys such as ``role``,
            ``content``, ``tools``, ``tool_calls``).

    Returns:
        Whatever ``self.tokenizer.apply_chat_template`` returns when called
        with ``tokenize=False`` and ``add_generation_prompt=False``.

    Raises:
        json.JSONDecodeError: if a string ``tools``/``tool_calls`` is not valid JSON.
    """
    def _decode_if_text(value):
        # Tool specs may arrive either JSON-encoded or already decoded.
        return json.loads(value) if isinstance(value, str) else value

    tool_specs = None
    rendered_turns = []
    for turn in conversations:
        entry = dict(turn)
        if entry.get("role") == "system" and entry.get("tools"):
            tool_specs = _decode_if_text(entry["tools"])
        raw_calls = entry.get("tool_calls")
        if raw_calls and isinstance(raw_calls, str):
            entry["tool_calls"] = json.loads(raw_calls)
        rendered_turns.append(entry)

    return self.tokenizer.apply_chat_template(
        rendered_turns,
        tokenize=False,
        add_generation_prompt=False,
        tools=tool_specs,
    )
87
def generate_labels(self, input_ids):
    """Build labels that supervise only the assistant replies in ``input_ids``.

    A reply starts right after each occurrence of the ``self.bos_id`` token
    sequence. It extends through the next occurrence of ``self.eos_id``, and
    the terminator tokens themselves are labelled. If no terminator follows,
    the reply extends to the end of the sequence. All other positions get
    ``-100`` (presumably the loss's ignore index -- confirm in the training
    loop). Labels are aligned with ``input_ids`` (not shifted). No label is
    written at or beyond ``self.max_length``.

    Args:
        input_ids: Sequence of token ids, e.g. a padded prompt.

    Returns:
        A new list with the same length as ``input_ids``.
    """
    bos, eos = self.bos_id, self.eos_id
    bos_len, eos_len = len(bos), len(eos)
    total = len(input_ids)
    # Clip to the real sequence length as well as max_length. Without the
    # former, an unterminated reply in a sequence shorter than max_length
    # indexed past the end of ``labels``.
    limit = min(total, self.max_length)

    labels = [-100] * total
    i = 0
    while i < total:
        if input_ids[i:i + bos_len] != bos:
            i += 1
            continue
        start = i + bos_len
        end = start
        while end < total and input_ids[end:end + eos_len] != eos:
            end += 1
        stop = min(end + eos_len, limit)
        # Equal-length slices (or an empty no-op when stop <= start), so the
        # length of ``labels`` never changes.
        labels[start:stop] = input_ids[start:stop]
        # Resume scanning after the terminator, or stop if none was found.
        i = end + eos_len if end < total else total
    return labels
105
106 def __getitem__(self, index):
107 sample = self.samples[index]
108 conversations = pre_processing_chat(sample['conversations'])
109 prompt = self.create_chat_prompt(conversations)
110 prompt = post_processing_chat(prompt)
111 input_ids = self.tokenizer(prompt).input_ids[:self.max_length]
112 input_ids += [self.tokenizer.pad_token_id] * (self.max_length - len(input_ids))
113 labels = self.generate_labels(input_ids)
114 # # === 调试打印 ===
115 # print(f"\n--- Sample {index} ---")

Callers 3

train_lora.py — File · 0.90
train_full_sft.py — File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected