MCPcopy
hub / github.com/OpenNMT/OpenNMT-py / __init__

Method __init__

onmt/inference_engine.py:221–259  ·  view source on GitHub ↗
(self, opt)

Source from the content-addressed store, hash-verified

219 """
220
def __init__(self, opt):
    """Build a CTranslate2-backed inference engine from parsed options.

    Loads the converted model at ``opt.models[0]`` as a
    ``ctranslate2.Generator`` when ``opt.model_task`` is
    ``ModelTask.LANGUAGE_MODEL`` and as a ``ctranslate2.Translator``
    otherwise. It builds one vocabulary from the JSON file at
    ``opt.src_subword_vocab``, shared as both source and target vocab,
    and assembles the transform pipeline.

    Args:
        opt: parsed inference options. Attributes read directly here:
            ``log_file``, ``world_size``, ``gpu``, ``gpu_ranks``,
            ``_all_transform``, ``model_task``, ``models`` and
            ``src_subword_vocab``. ``make_transforms`` may read more.

    Raises:
        AssertionError: if ``opt.world_size`` is greater than 1; only a
            single device is supported.
    """
    # Imported inside the constructor, presumably so ctranslate2/pyonmttok
    # are only required when this engine is actually instantiated.
    import ctranslate2
    import pyonmttok

    super().__init__(opt)
    self.opt = opt
    self.logger = init_logger(opt.log_file)
    # Single-device inference only: world_size == 1 selects CUDA on
    # opt.gpu_ranks, anything below 1 falls back to CPU device 0.
    assert opt.world_size <= 1, "World size must be less than or equal to 1."
    self.device_id = opt.gpu
    if opt.world_size == 1:
        self.device_index = opt.gpu_ranks
        self.device = "cuda"
    else:
        self.device_index = 0
        self.device = "cpu"
    self.transforms_cls = get_transforms_cls(opt._all_transform)

    # Build translator: the language-model task uses a Generator, every
    # other task a Translator. Both take the same constructor arguments.
    engine_cls = (
        ctranslate2.Generator
        if opt.model_task == ModelTask.LANGUAGE_MODEL
        else ctranslate2.Translator
    )
    self.translator = engine_cls(
        opt.models[0], device=self.device, device_index=self.device_index
    )

    # Build vocab. The JSON payload is passed straight to
    # pyonmttok.build_vocab_from_tokens, so it is presumably a list of
    # token strings (verify against the file's producer). UTF-8 is
    # explicit because subword tokens are frequently non-ASCII and the
    # platform default encoding is not guaranteed to be UTF-8.
    with open(opt.src_subword_vocab, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    src_vocab = pyonmttok.build_vocab_from_tokens(vocab)
    self.vocabs = {
        # Source and target share the very same vocab object.
        "src": src_vocab,
        "tgt": src_vocab,
        # NOTE(review): hard-coded to "lm" even on the Translator
        # (non-LANGUAGE_MODEL) branch; confirm this is intended.
        "data_task": "lm",
        "decoder_start_token": "<s>",
    }

    # Build transform pipe
    transforms = make_transforms(opt, self.transforms_cls, self.vocabs)
    self.transform = TransformPipe.build_from(transforms.values())
260
261 def translate_batch(self, batch, opt):
262 input_tokens = []

Callers

nothing calls this directly

Calls 6

init_logger — Function · 0.90
get_transforms_cls — Function · 0.90
make_transforms — Function · 0.90
build_from — Method · 0.80
__init__ — Method · 0.45
load — Method · 0.45

Tested by

no test coverage detected