MCPcopy
hub / github.com/PaddlePaddle/ERNIE / End2EndProcessor

Class End2EndProcessor

data_processor/steps/end2end_processing/processor.py:33–107  ·  view source on GitHub ↗

End2End Processor

Source from the content-addressed store, hash-verified

31
32
class End2EndProcessor(ProcessorBase):
    """
    Chain the individual data-processing steps into one pipeline.

    One sub-processor is built per step, each configured from its own entry
    of ``args``. ``process`` pushes a sample through the steps in order and
    ``collate`` hands batching over to the array-collation step.

    NOTE(review): ``self.is_training`` is read in ``process`` but never
    assigned in this class; presumably ``ProcessorBase`` provides it.
    """

    def __init__(self, args, tokenizer, image_preprocess):
        """
        Build every sub-processor used by the pipeline.

        Args:
            args: indexable collection of per-step argument objects. Entries
                0-5 configure, in order, the utterance, coarse, input-ids
                massage, pseudo multi-round, image modification and array
                collation steps; entry 6 supplies ``batch_size`` and
                ``load_args_from_api``.
            tokenizer: handed to every step except the coarse one.
            image_preprocess: handed to the input-ids massage and image
                modification steps.
        """
        super().__init__(args)

        # Text-side steps.
        self.utterance_process = UtteranceProcessor(args[0], tokenizer)
        self.coarse_processor = CoarseProcessor(args[1])
        self.input_ids_massage_processor = InputIdsMassageProcessor(
            args[2], tokenizer, image_preprocess
        )
        self.pseudo_multiround_processor = PseudoMultiRoundProcessor(
            args[3], tokenizer
        )

        # Image-side and batching steps.
        self.image_modification_processor = ImageModificationProcessor(
            args[4], tokenizer, image_preprocess
        )
        self.array_collation_processor = ArrayCollationProcessor(
            args[5], tokenizer
        )

        # Pipeline-level settings live in the last argument group.
        pipeline_args = args[6]
        self.batch_size = pipeline_args.batch_size
        self.load_args_from_api = pipeline_args.load_args_from_api

    def process(self, data, **kwargs):
        """
        Run one sample through every step of the pipeline.

        Args:
            data: raw sample handed to the utterance step. Outside training,
                when ``load_args_from_api`` is set, a deep copy of it minus
                its ``"context"`` entry is also merged into ``kwargs``.
            **kwargs: extra options forwarded to every step.

        Returns:
            list: one image-modification result per item produced by the
            pseudo multi-round step.
        """
        # step1: utterance processing
        if not self.is_training and self.load_args_from_api:
            # Everything except the context is forwarded as keyword options;
            # the deep copy keeps ``data`` itself untouched.
            api_options = copy.deepcopy(data)
            api_options.pop("context", None)
            kwargs.update(api_options)
        utterance_schema = self.utterance_process.process(data, **kwargs)

        # step2: coarse processing
        coarse_schema = self.coarse_processor.process(utterance_schema, **kwargs)

        # step3: ids massaging
        massaged = self.input_ids_massage_processor.process(coarse_schema, **kwargs)

        # step4: multiround processing
        if not self.is_training:
            # Outside training exactly one massaged entry is expected; a
            # serialized output has to be decoded before it can be counted.
            serialized = self.input_ids_massage_processor.args.serialize_output
            decoded = json.loads(str(massaged)) if serialized else massaged
            assert len(decoded) == 1
        rounds = self.pseudo_multiround_processor.process(massaged, **kwargs)

        # step5: image modification
        return [
            self.image_modification_processor.process(single_round, **kwargs)
            for single_round in rounds
        ]

    def collate(self, batch):
        """
        Collate ``batch`` by delegating to the array-collation step.
        """
        collator = self.array_collation_processor
        return collator.collate(batch)
90

Callers 1

_init_mm_model · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected