(args, problems)
| 30 | |
| 31 | |
def _build_contract_code(dataset, problem):
    """Return the canonical solution source with its input-validity contract inserted.

    Args:
        dataset: dataset name, ``"humaneval"`` or ``"mbpp"``.
        problem: problem dict; reads ``prompt``, ``contract``,
            ``canonical_solution`` and (for MBPP) ``entry_point``.

    Raises:
        ValueError: if ``dataset`` is not a supported dataset name.
    """
    if dataset == "humaneval":
        # HumanEval contracts are plain text placed between prompt and body.
        return problem["prompt"] + problem["contract"] + problem["canonical_solution"]
    if dataset == "mbpp":
        # MBPP contracts must be spliced into the solution's function body.
        return problem["prompt"] + insert_contract_into_code(
            entry_point=problem["entry_point"],
            code=problem["canonical_solution"],
            contract=problem["contract"],
        )
    raise ValueError(f"unsupported dataset: {dataset!r}")


def _write_inputs_record(file, task_id, inputs):
    """Write one ``{"task_id": ..., "inputs": ...}`` record as a JSON line to ``file``."""
    record = {"task_id": task_id, "inputs": inputs}
    file.write(json.dumps(record, cls=SetEncoder) + "\n")


def input_generation(args, problems):
    """Generate extra test inputs for every problem and write them as JSON lines.

    For each problem, seed inputs come from
    ``ChatGPTGen(...).generate(args.chatgpt_len)``. If any seeds were produced,
    they are extended with ``TypedMutGen(...).generate(args.mut_len)``. One
    ``{"task_id": ..., "inputs": ...}`` JSON object is written per problem,
    one per line, to ``args.output``.

    Args:
        args: parsed options; reads ``output``, ``dataset``
            (``"humaneval"`` or ``"mbpp"``), ``chatgpt_len`` and ``mut_len``.
        problems: mapping whose values are problem dicts with keys
            ``task_id``, ``prompt``, ``canonical_solution``, ``contract``,
            ``entry_point`` and ``base_input``.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    if args.dataset not in ("humaneval", "mbpp"):
        # Fail before opening (and truncating) the output file.
        raise ValueError(f"unsupported dataset: {args.dataset!r}")

    with open(args.output, "w") as file:
        for problem in problems.values():
            task_id = problem["task_id"]
            print(f"generating inputs for {task_id} ...")
            # by default we do not include constraints in the prompt (code)
            code = problem["prompt"] + problem["canonical_solution"]
            # but we use c_code to include contract which checks input validity at execution time
            c_code = _build_contract_code(args.dataset, problem)

            # first generate chatgpt seeds
            input_gen = ChatGPTGen(
                problem["base_input"], problem["entry_point"], c_code, code
            ).generate(args.chatgpt_len)

            if not input_gen:
                # No seeds (None or empty): nothing to mutate. An empty dict is
                # written here to keep the existing output format unchanged.
                _write_inputs_record(file, task_id, {})
                continue

            # then extend the seeds with type-aware mutations
            input_gen.extend(
                TypedMutGen(input_gen, problem["entry_point"], c_code).generate(
                    args.mut_len
                )
            )
            print(f"generated {len(input_gen)} inputs")

            if args.dataset == "mbpp":
                # Use the serialized form; previously it was computed and then
                # immediately overwritten by the raw list.
                inputs = mbpp_serialize_inputs(task_id, input_gen)
            else:
                inputs = input_gen
            _write_inputs_record(file, task_id, inputs)
| 77 | |
| 78 | |
| 79 | def main(): |
no test coverage detected