End2End Processor
| 31 | |
| 32 | |
class End2EndProcessor(ProcessorBase):
    """
    Chain the individual data-processing stages into one pipeline.

    The stages run in this order: utterance processing, coarse processing,
    input-id massaging, pseudo multi-round processing and, for each item
    produced by the multi-round stage, image modification. Collation is
    delegated to the array collation processor.
    """

    def __init__(self, args, tokenizer, image_preprocess):
        """
        Build one sub-processor per pipeline stage.

        Args:
            args: indexable container with at least seven entries.
                Entries 0-5 are handed to the stage processors in pipeline
                order; entry 6 must expose ``batch_size`` and
                ``load_args_from_api``.
            tokenizer: passed through to every stage that tokenizes.
            image_preprocess: passed through to the stages that handle
                images (id massaging and image modification).
        """
        super().__init__(args)
        self.utterance_process = UtteranceProcessor(args[0], tokenizer)
        self.coarse_processor = CoarseProcessor(args[1])
        self.input_ids_massage_processor = InputIdsMassageProcessor(
            args[2], tokenizer, image_preprocess
        )
        self.pseudo_multiround_processor = PseudoMultiRoundProcessor(args[3], tokenizer)
        self.image_modification_processor = ImageModificationProcessor(
            args[4], tokenizer, image_preprocess
        )
        self.array_collation_processor = ArrayCollationProcessor(args[5], tokenizer)
        self.batch_size = args[6].batch_size
        self.load_args_from_api = args[6].load_args_from_api

    def process(self, data, **kwargs):
        """
        Run ``data`` through every stage and return the per-item outputs.

        Args:
            data: raw sample handed to the utterance stage. When not
                training and ``load_args_from_api`` is set, it is also
                deep-copied and (minus its ``"context"`` entry) merged
                into ``kwargs``.
            **kwargs: forwarded unchanged to every stage.

        Returns:
            list: one image-modification result per item yielded by the
            pseudo multi-round stage.

        Raises:
            AssertionError: when not training and the id-massaging stage
                did not yield exactly one result.
        """
        # NOTE(review): ``is_training`` is not assigned in this class;
        # presumably it is provided by ProcessorBase -- confirm.

        # step1: utterance processing
        if not self.is_training:
            if self.load_args_from_api:
                # Everything in the request except the conversation itself
                # is treated as per-call keyword overrides.
                api_overrides = copy.deepcopy(data)
                if "context" in api_overrides:
                    del api_overrides["context"]
                kwargs.update(api_overrides)
        utterance_schema = self.utterance_process.process(data, **kwargs)

        # step2: coarse processing
        coarse_schema = self.coarse_processor.process(utterance_schema, **kwargs)

        # step3: ids massaging
        massaged = self.input_ids_massage_processor.process(coarse_schema, **kwargs)

        # step4: multiround processing
        if not self.is_training:
            massager = self.input_ids_massage_processor
            # With serialized output the length check needs the decoded form;
            # the decoded value is used only for this check.
            decoded = (
                json.loads(str(massaged))
                if massager.args.serialize_output
                else massaged
            )
            assert len(decoded) == 1
        rounds = self.pseudo_multiround_processor.process(massaged, **kwargs)

        # step5: image modification
        return [
            self.image_modification_processor.process(item, **kwargs)
            for item in rounds
        ]

    def collate(self, batch):
        """
        Collate ``batch`` by delegating to the array collation processor.
        """
        return self.array_collation_processor.collate(batch)
| 90 |