MCPcopy
hub / github.com/ModelTC/LightLLM / VisionTransformer

Class VisionTransformer

lightllm/models/vit/model.py:28–204  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

26
27
28class VisionTransformer:
29
# Weight-container classes: one for the pre/post layers and one for the
# transformer layers.
# NOTE(review): presumably instantiated by _init_weights, which is not
# visible in this excerpt — confirm.
pre_and_post_weight_class = ViTPreAndPostLayerWeight
transformer_weight_class = ViTTransformerLayerWeight

# Inference-layer classes for the pre, transformer and post stages.
# NOTE(review): presumably consumed by _init_infer_layer — confirm.
pre_layer_infer_class = ViTPreLayerInfer
transformer_layer_infer_class = ViTTransformerLayerInfer
post_layer_infer_class = ViTPostLayerInfer
38
def __init__(self, kvargs):
    """Read the settings from ``kvargs`` and run the initialisation steps.

    Keys read from ``kvargs``:
        weight_dir (required): directory passed to ``get_load_image_func``
            and stored as ``self.weight_dir_``.
        load_way: defaults to ``"HF"``.
        mode: iterable of mode strings; the legacy aliases ``int4weight``
            and ``int8weight`` are rewritten to ``w4a16`` / ``w8a16``.
        weight_dict, quant_type, quant_cfg: default to ``None``.
        data_type: defaults to ``"float16"``.
        max_batch_size: defaults to ``1``.

    Raises:
        KeyError: if ``weight_dir`` is missing from ``kvargs``.
    """
    self.tp_world_size_ = get_dp_world_size()
    self.weight_dir_ = kvargs["weight_dir"]
    self.load_way = kvargs.get("load_way", "HF")

    # Translate the legacy quantisation aliases to their current names.
    normalized_modes = []
    for raw_mode in kvargs.get("mode", []):
        renamed = raw_mode.replace("int4weight", "w4a16")
        renamed = renamed.replace("int8weight", "w8a16")
        normalized_modes.append(renamed)
    self.mode = normalized_modes

    self.weight_dict = kvargs.get("weight_dict")
    self.data_type = kvargs.get("data_type", "float16")
    self.quant_type = kvargs.get("quant_type")
    self.quant_cfg_path = kvargs.get("quant_cfg")
    self.load_image_func = get_load_image_func(self.weight_dir_)
    self.max_batch_size = kvargs.get("max_batch_size", 1)

    # The order of the initialisation steps is kept exactly as written;
    # the final step runs the start-up dummy inference check.
    self._init_datatype()
    self._init_config()
    self._padding_hidden_size()
    self._init_quant()
    self._init_weights()
    self._init_infer_layer()
    self._check_max_len_infer()
59
@final
@torch.no_grad()
def _check_max_len_infer(self):
    """Run one dummy forward pass at the configured maximum batch size.

    Builds a random image batch of shape
    ``(MAX_PATH_NUM * max_batch_size, 3, IMAGE_H, IMAGE_W)``, moves it to
    the GPU and feeds it through ``self.forward`` so that a failure at the
    configured batch size shows up at start-up.

    The check is skipped when the environment variable
    ``DISABLE_CHECK_MAX_LEN_INFER`` is set to any value, including an
    empty string.

    Raises:
        Exception: if the dummy forward pass raises ``RuntimeError``
            (CUDA out-of-memory errors are a subclass of it). The
            original error is attached as ``__cause__``.
    """
    if os.getenv("DISABLE_CHECK_MAX_LEN_INFER") is not None:
        return

    try:
        # NOTE(review): assumes self.data_type already holds a torch.dtype
        # here (presumably converted by _init_datatype) and that
        # MAX_PATH_NUM / IMAGE_H / IMAGE_W are set elsewhere — confirm.
        dummy_images = torch.randn(
            (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W),
            dtype=self.data_type,
        ).cuda()
        all_img_embeds = self.forward(dummy_images)
        del all_img_embeds
        logger.info(f"vit check max_len {self.max_batch_size} infer ok")
    except RuntimeError as e:
        # torch's OutOfMemoryError derives from RuntimeError, so this one
        # clause covers it. Naming torch.OutOfMemoryError here would raise
        # AttributeError on torch builds that lack the top-level alias,
        # hiding the real failure.
        logger.exception(str(e))
        exception_str = (
            "Vit check max len infer fail, you can try: "
            "1.Set the --visual_infer_batch_size to a smaller value."
        )
        logger.error(exception_str)
        raise Exception(exception_str) from e
82
83 def _init_config(self):
84 with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file:
85 self.config = json.load(json_file)

Callers 2

exposed_init_model · Method · 0.90
tppart_model_infer · Function · 0.90

Calls

no outgoing calls

Tested by 1

tppart_model_infer · Function · 0.72