| 58 | |
| 59 | |
def sanitize(
    old_code: str,
    entry_point: str,
    rm_prefix_lines: Optional[str] = None,
    eofs: Optional[List[str]] = None,
):
    """Extract the definition of ``entry_point`` (plus surrounding code) from raw text.

    The steps, in order:

    1. Optionally drop every line that starts with ``rm_prefix_lines``.
    2. If the text contains Markdown code fences, keep only the fenced chunk
       that defines ``entry_point``.
    3. Split the text on ``def <entry_point>(`` and keep only the
       implementations whose text before the next ``\\ndef`` contains
       ``" return "``.
    4. Pass the result through ``to_four_space_indents``, truncate it at each
       marker in ``eofs``, and pass it through ``remove_unindented_lines``.
    5. Re-attach whatever preceded the first definition, then drop every
       other function for which ``syntax_check`` is falsy.

    Args:
        old_code: Raw text to clean up.
        entry_point: Name of the function that must be kept.
        rm_prefix_lines: If given, lines starting with this prefix are removed.
        eofs: Markers at which the extracted implementation is truncated;
            ``None`` means no truncation.

    Returns:
        The cleaned-up code with leading/trailing whitespace stripped.
    """
    new_code = old_code
    if rm_prefix_lines is not None:
        new_code = "\n".join(
            line
            for line in old_code.splitlines()
            if not line.startswith(rm_prefix_lines)
        )

    # Leading newline so a fence / def on the very first line is matched by
    # the "\n..." patterns used below.
    new_code = "\n" + new_code
    def_left = "def " + entry_point
    # Escape the name so it is always matched literally, and require the
    # opening parenthesis so `foo` does not match `foo_helper`.
    signature_re = re.compile(re.escape(def_left) + r"\s*\(")

    # basic handling of chat output
    new_code = new_code.replace("\n```python\n", "\n```\n")
    fenced_chunks = new_code.split("\n```\n")
    # Prefer a chunk holding the real signature; fall back to the looser
    # substring test so inputs without an exact match behave as before.
    selected = next(
        (chunk for chunk in fenced_chunks if signature_re.search(chunk)), None
    )
    if selected is None:
        selected = next(
            (chunk for chunk in fenced_chunks if def_left in chunk), None
        )
    if selected is not None:
        new_code = selected

    chunks = signature_re.split(new_code)
    # TODO: having return does not mean this is complete
    bodies = [chunk for chunk in chunks[1:] if " return " in chunk.split("\ndef")[0]]
    def_left = def_left + "("
    # fn + impl: every kept body gets the normalized "def <entry_point>(" prefix.
    new_code = def_left + def_left.join(bodies) if bodies else ""
    new_code = to_four_space_indents(new_code)

    for eof in eofs or []:
        new_code = new_code.split(eof)[0]

    # remove lines starting from the first unindented line after def_left
    # NOTE: the `execeptions` keyword is kept exactly as spelled because it
    # must match the helper's parameter name, which is defined outside this block.
    new_code = remove_unindented_lines(
        new_code,
        protect_before=def_left,
        execeptions=["def ", "import ", "from "],
        trim_tails=['"""', "if", "print"],
    )
    # Re-attach everything that preceded the first entry-point definition.
    new_code = chunks[0] + new_code

    # cut all functions that are not syntactically correct && not the entry point
    parts = new_code.split("\ndef ")
    entry_prefixes = (entry_point + " ", entry_point + "(")
    includes = [parts[0]]
    includes.extend(
        fn
        for fn in parts[1:]
        # The entry point is always kept; syntax_check only runs for others.
        if fn.strip().startswith(entry_prefixes) or syntax_check("\ndef " + fn)
    )
    return "\ndef ".join(includes).strip()
| 117 | |