Cleans up the generated code.
(
code: str,
dataset_type: str = None,
language_type: str = None,
)
| 371 | |
| 372 | |
def cleanup_code(
    code: str,
    dataset_type: str = None,
    language_type: str = None,
):
    """
    Cleans up the generated code.

    Trims trailing material (extra definitions, test asserts, ``main``
    functions, text after the final closing brace) from a generated
    completion so that only the solution body remains.

    Args:
        code: The raw generated completion.
        dataset_type: ``"mbpp"`` or ``"humanevalx"``. Any other value
            (including ``None``) leaves ``code`` untouched.
        language_type: Language of the completion, matched
            case-insensitively; only consulted for ``"humanevalx"``.
            Handled values are ``python``, ``java``, ``go``, ``cpp``,
            ``js`` and ``rust``. ``None`` or an unhandled value leaves
            ``code`` untouched.

    Returns:
        The cleaned-up code string.
    """
    if dataset_type == "mbpp":
        # Guard on presence: str.rfind returns -1 when the marker is absent,
        # and slicing with -1 would silently drop the last character.
        assert_pos = code.rfind("\nassert")
        if assert_pos != -1:
            code = code[:assert_pos]
        # Cut at the last "\ndef" only when there is more than one, so a
        # completion containing a single definition is not truncated.
        if code.count("\ndef") > 1:
            code = code[:code.rfind("\ndef")]
        # NOTE(review): first_block and stop_words are not defined inside this
        # function; presumably module-level names -- confirm.
        code = first_block(code, stop_words)
    elif dataset_type == "humanevalx":
        # language_type defaults to None; normalise so .lower() cannot fail.
        language = (language_type or "").lower()
        if language == "python":
            code = _cleanup_python_completion(code)
        elif language == "java":
            main_pos = code.find("public static void main")
            if main_pos != -1:
                code = code[:main_pos] + "}"
            code = _cut_after_last_brace(code)
            # Exactly one more '}' than '{': append one further closing brace.
            # NOTE(review): presumably the prompt already opened both a class
            # and a method, so the completion must close two scopes -- confirm.
            if code.count("{") + 1 == code.count("}"):
                code += "\n}"
        elif language == "go":
            if "\nfunc main(" in code:
                code = code[:code.rfind("func main(")]
            code = _cut_after_last_brace(code)
        elif language == "cpp":
            if "\nint main()" in code:
                code = code[:code.rfind("int main()")]
            code = _cut_after_last_brace(code)
        elif language in ("js", "rust"):
            code = _cut_after_last_brace(code)
    return code


def _cut_after_last_brace(code: str) -> str:
    """Drop everything after the last '}'; return code unchanged if none."""
    last_brace = code.rfind("}")
    if last_brace == -1:
        return code
    return code[:last_brace + 1]


def _cleanup_python_completion(code: str) -> str:
    """Truncate a Python completion at the first non-indented, non-blank line."""
    lines = code.split("\n")
    for index, line in enumerate(lines):
        # A non-blank line that does not start with a space or tab is treated
        # as the end of the completion; everything from it onwards is dropped.
        if line.strip() and line[0] not in (" ", "\t"):
            return "\n".join(lines[:index])
    # Fallback when every non-blank line is indented: cut at the last
    # occurrence of each end marker that is present.
    end_words = ["\ndef", "\nclass", "\n#", "\nassert", '\n"""', "\nprint", "\nif", "\n\n\n"]
    for marker in end_words:
        if marker in code:
            code = code[:code.rfind(marker)]
    return code
no test coverage detected