MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX2 / cleanup_code

Function cleanup_code

evaluation/utils.py:373–432  ·  view source on GitHub ↗

Cleans up the generated code.

(
    code: str,
    dataset_type: str = None,
    language_type: str = None,
)

Source from the content-addressed store, hash-verified

371
372
def _truncate_after_last_brace(code: str) -> str:
    """Return *code* cut just after its last ``}``; unchanged if it has none."""
    cut = code.rfind("}")
    return code if cut == -1 else code[:cut + 1]


def _cleanup_mbpp(code: str, stop_words) -> str:
    """Apply the MBPP truncation rules to *code*."""
    # Cut at the last "\nassert" -- but only if there is one: rfind returns -1
    # otherwise, and code[:-1] would silently drop the final character.
    assert_pos = code.rfind("\nassert")
    if assert_pos != -1:
        code = code[:assert_pos]
    # With more than one "\ndef", drop the last one and everything after it.
    if code.count("\ndef") > 1:
        code = code[:code.rfind("\ndef")]
    # The original referenced an undefined name `stop_words` here (NameError).
    # NOTE(review): if the module defines a global of that name, callers should
    # now pass it explicitly via the `stop_words` argument.
    if stop_words:
        code = first_block(code, stop_words)
    return code


def _cleanup_humanevalx_python(code: str) -> str:
    """Apply the HumanEval-X Python truncation rules to *code*."""
    lines = code.split("\n")
    for i, line in enumerate(lines):
        # The first non-blank line starting at column 0 (no space/tab indent)
        # ends the completion: keep only the lines before it.
        if line.strip() and line[0] not in " \t":
            return "\n".join(lines[:i])
    # NOTE: every end word below except "\n\n\n" begins a non-blank, unindented
    # line and would already have been caught by the loop above.
    for end_word in ["\ndef", "\nclass", "\n#", "\nassert", '\n"""', "\nprint", "\nif", "\n\n\n"]:
        if end_word in code:
            code = code[:code.rfind(end_word)]
    return code


def cleanup_code(
    code: str,
    dataset_type: str = None,
    language_type: str = None,
    stop_words: list = None,
):
    """
    Clean up generated code by truncating it with dataset/language heuristics.

    Args:
        code: The generated code to clean.
        dataset_type: "mbpp" or "humanevalx". Any other value returns `code`
            unchanged.
        language_type: Used only for "humanevalx". Matched case-insensitively
            against "python", "java", "go", "cpp", "js" and "rust". Any other
            value, including None, returns `code` unchanged.
        stop_words: Optional stop words forwarded to `first_block` for "mbpp".
            When None or empty, the `first_block` step is skipped.

    Returns:
        The cleaned code string.
    """
    if dataset_type == "mbpp":
        return _cleanup_mbpp(code, stop_words)
    if dataset_type != "humanevalx":
        return code

    # Normalize once; a None language no longer raises AttributeError.
    lang = (language_type or "").lower()

    if lang == "python":
        return _cleanup_humanevalx_python(code)

    if lang == "java":
        main_pos = code.find("public static void main")
        if main_pos != -1:
            code = code[:main_pos] + "}"
        code = _truncate_after_last_brace(code)
        # Exactly one more "}" than "{": append another closing brace --
        # presumably to close a class left open by the prompt.
        # NOTE(review): confirm against the HumanEval-X Java prompts.
        if code.count("{") + 1 == code.count("}"):
            code += "\n}"
        return code

    if lang == "go":
        if "\nfunc main(" in code:
            code = code[:code.rfind("func main(")]
        return _truncate_after_last_brace(code)

    if lang == "cpp":
        if "\nint main()" in code:
            code = code[:code.rfind("int main()")]
        return _truncate_after_last_brace(code)

    if lang in ("js", "rust"):
        return _truncate_after_last_brace(code)

    return code

Callers 1

process · Function · 0.90

Calls 1

first_block · Function · 0.85

Tested by

no test coverage detected