(self, opt)
| 219 | """ |
| 220 | |
def __init__(self, opt):
    """Set up an inference engine backed by CTranslate2.

    Args:
        opt: inference options. This constructor reads ``log_file``,
            ``world_size``, ``gpu``, ``gpu_ranks``, ``_all_transform``,
            ``model_task``, ``models`` and ``src_subword_vocab``.

    Raises:
        AssertionError: if ``opt.world_size`` is greater than 1.
    """
    # Imported here rather than at module level -- presumably because
    # these are optional dependencies (TODO confirm).
    import ctranslate2
    import pyonmttok

    super().__init__(opt)
    self.opt = opt
    self.logger = init_logger(opt.log_file)
    assert (
        self.opt.world_size <= 1
    ), "World size must be less than or equal to 1."
    self.device_id = opt.gpu
    if opt.world_size == 1:
        # world_size == 1 selects CUDA; gpu_ranks is forwarded unchanged as
        # the device index.
        self.device_index = opt.gpu_ranks
        self.device = "cuda"
    else:
        # Any other accepted world_size runs on CPU, device 0.
        self.device_index = 0
        self.device = "cpu"
    self.transforms_cls = get_transforms_cls(self.opt._all_transform)

    # Build translator: language-model tasks use a Generator, every other
    # task uses a Translator. Only the first entry of opt.models is loaded.
    if opt.model_task == ModelTask.LANGUAGE_MODEL:
        self.translator = ctranslate2.Generator(
            opt.models[0], device=self.device, device_index=self.device_index
        )
    else:
        self.translator = ctranslate2.Translator(
            self.opt.models[0], device=self.device, device_index=self.device_index
        )

    # Build vocab. The file is JSON, so decode it explicitly as UTF-8 rather
    # than relying on the platform's default locale encoding.
    vocab_path = opt.src_subword_vocab
    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
    # The same vocabulary object is used for both the source and target sides.
    self.vocabs = {
        "src": src_vocab,
        "tgt": src_vocab,
        "data_task": "lm",
        "decoder_start_token": "<s>",
    }

    # Build transform pipe
    transforms = make_transforms(opt, self.transforms_cls, self.vocabs)
    self.transform = TransformPipe.build_from(transforms.values())
| 260 | |
| 261 | def translate_batch(self, batch, opt): |
| 262 | input_tokens = [] |
nothing calls this directly
no test coverage detected