(
data_file: str = None,
dataset_type: str = "humanevalx",
)
| 208 | |
| 209 | |
def _split_ds1000_prompt(prompt: str) -> tuple:
    """Split a prompt into the text before and after its ``[insert]`` line(s).

    Every line containing ``[insert]`` is dropped. The remaining lines are
    re-joined with ``"\\n"``; when both parts are non-empty the suffix keeps
    the leading newline that separated it from the prefix.
    """
    before, after = [], []
    seen_insert = False
    for line in prompt.split("\n"):
        if "[insert]" in line:
            seen_insert = True
            continue
        (after if seen_insert else before).append(line)

    prefix = "\n".join(before)
    suffix = "\n".join(after)
    if before and after:
        # Only the very first kept line of the whole prompt is emitted without
        # a leading newline, so a suffix that follows a prefix starts with one.
        suffix = "\n" + suffix
    return prefix, suffix


def read_dataset(
    data_file: str = None,
    dataset_type: str = "humanevalx",
) -> Dict:
    """Load a benchmark and return it as a mapping from task id to task record.

    Args:
        data_file: Path handed to ``stream_jsonl`` for the ``humanevalx`` and
            ``mbpp`` datasets, or used as ``source_dir`` for ``DS1000Dataset``.
        dataset_type: Name of the dataset; matched case-insensitively by
            substring against ``"humanevalx"``, ``"mbpp"`` and ``"ds1000"``
            (checked in that order).

    Returns:
        A dict keyed by task id. For ``mbpp`` every record gains a ``"prompt"``
        entry built from its description and first test case.

    Raises:
        ValueError: If ``dataset_type`` matches none of the supported names.
    """
    kind = dataset_type.lower()

    if "humanevalx" in kind:
        dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
    elif "mbpp" in kind:
        problems = {task["task_id"]: task for task in stream_jsonl(data_file)}
        # Only the 11th through 510th task ids (in sorted order) are used.
        # NOTE(review): presumably the standard MBPP evaluation split - confirm.
        task_ids = sorted(problems.keys())[10:510]
        dataset = {}
        for task_id in task_ids:
            sample = problems[task_id]
            description = sample["text"]
            test_example = sample["test_list"][0]
            # The prompt is a docstring holding the task text and one example test.
            sample["prompt"] = f'"""\n{description}\n{test_example}\n"""\n'
            dataset[task_id] = sample
    elif "ds1000" in kind:
        # install ds1000 from https://github.com/HKUNLP/DS-1000
        from ds1000 import DS1000Dataset

        ds1000 = DS1000Dataset(source_dir=data_file, libs="all", mode="Completion")
        # Previously this branch computed prefix/suffix but never stored them,
        # so the final ``return dataset`` failed with UnboundLocalError.
        # NOTE(review): the record layout below is a best guess, since the
        # original discarded these values - confirm against downstream consumers.
        dataset = {}
        for lib in ds1000.libs:
            lib_problems = ds1000[lib]
            for problem_id in range(len(lib_problems)):
                prompt = lib_problems[problem_id]["prompt"]
                prefix, suffix = _split_ds1000_prompt(prompt)
                task_id = f"{lib}/{problem_id}"
                dataset[task_id] = {
                    "task_id": task_id,
                    "lib": lib,
                    "problem_id": problem_id,
                    "prompt": prompt,
                    "prefix": prefix,
                    "suffix": suffix,
                }
    else:
        # Raising a bare string is itself a TypeError in Python 3; raise a real
        # exception so the message actually reaches the caller.
        raise ValueError(f"Dataset: {dataset_type} not supported.")

    return dataset
| 255 | |
| 256 | |
| 257 | def read_translation_dataset( |
no test coverage detected