(self, input_ids, infer_state, input_ids1=None, infer_state1=None)
| 176 | return graph_model_output, graph_model_output1 |
| 177 | |
def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
    """Dispatch one replay call to the overlap or single-batch implementation.

    When ``self.enable_decode_microbatch_overlap`` is truthy, both
    micro-batches (``input_ids``/``infer_state`` and
    ``input_ids1``/``infer_state1``) are forwarded to
    ``self._replay_overlap`` and its result is returned.

    Otherwise only the first micro-batch is used: ``input_ids1`` and
    ``infer_state1`` must both be ``None``, and the result of
    ``self._replay(input_ids, infer_state)`` is returned.

    Raises:
        AssertionError: in non-overlap mode, if either second-micro-batch
            argument is not ``None`` (the check is skipped under ``python -O``).
    """
    if not self.enable_decode_microbatch_overlap:
        # Single-batch mode: a second micro-batch is not accepted here.
        assert input_ids1 is None and infer_state1 is None
        return self._replay(input_ids, infer_state)
    return self._replay_overlap(input_ids, infer_state, input_ids1, infer_state1)
| 184 | |
| 185 | @torch.no_grad() |
| 186 | def warmup(self, model): |