| 26 | |
| 27 | |
| 28 | class VisionTransformer: |
| 29 | |
| 30 | # weight class |
| 31 | pre_and_post_weight_class = ViTPreAndPostLayerWeight |
| 32 | transformer_weight_class = ViTTransformerLayerWeight |
| 33 | |
| 34 | # infer class |
| 35 | pre_layer_infer_class = ViTPreLayerInfer |
| 36 | transformer_layer_infer_class = ViTTransformerLayerInfer |
| 37 | post_layer_infer_class = ViTPostLayerInfer |
| 38 | |
def __init__(self, kvargs):
    """Store the model settings taken from ``kvargs`` and run the setup hooks.

    Args:
        kvargs: Settings mapping. ``"weight_dir"`` is required (read with
            ``kvargs["weight_dir"]``). The optional keys and their fallbacks are:
            ``"load_way"`` ("HF"), ``"mode"`` ([]), ``"weight_dict"`` (None),
            ``"data_type"`` ("float16"), ``"quant_type"`` (None),
            ``"quant_cfg"`` (None) and ``"max_batch_size"`` (1).

    After the settings are stored, the hooks ``_init_datatype``,
    ``_init_config``, ``_padding_hidden_size``, ``_init_quant``,
    ``_init_weights``, ``_init_infer_layer`` and ``_check_max_len_infer`` are
    called in that order.
    """
    # NOTE(review): the attribute is named tp_* but is filled from the
    # data-parallel world size -- confirm this is intentional.
    self.tp_world_size_ = get_dp_world_size()
    self.weight_dir_ = kvargs["weight_dir"]
    self.load_way = kvargs.get("load_way", "HF")

    # Rewrite legacy quantisation mode spellings. The substitutions are
    # applied in this exact order to every entry.
    legacy_mode_names = (("int4weight", "w4a16"), ("int8weight", "w8a16"))
    normalized_modes = []
    for mode_name in kvargs.get("mode", []):
        for old_name, new_name in legacy_mode_names:
            mode_name = mode_name.replace(old_name, new_name)
        normalized_modes.append(mode_name)
    self.mode = normalized_modes

    self.weight_dict = kvargs.get("weight_dict", None)
    self.data_type = kvargs.get("data_type", "float16")
    self.quant_type = kvargs.get("quant_type", None)
    self.quant_cfg_path = kvargs.get("quant_cfg", None)
    self.load_image_func = get_load_image_func(self.weight_dir_)
    self.max_batch_size = kvargs.get("max_batch_size", 1)

    # Each hook is looked up on ``self`` immediately before it is called,
    # so subclass overrides are honoured exactly as with direct calls.
    for hook_name in (
        "_init_datatype",
        "_init_config",
        "_padding_hidden_size",
        "_init_quant",
        "_init_weights",
        "_init_infer_layer",
        "_check_max_len_infer",
    ):
        getattr(self, hook_name)()
| 59 | |
@final
@torch.no_grad()
def _check_max_len_infer(self):
    """Run one worst-case forward pass on dummy data to catch failures at start-up.

    A random batch of shape
    ``(MAX_PATH_NUM * max_batch_size, 3, IMAGE_H, IMAGE_W)`` with dtype
    ``self.data_type`` is created on the current CUDA device and passed to
    ``self.forward`` with gradients disabled.

    The check is skipped when the ``DISABLE_CHECK_MAX_LEN_INFER`` environment
    variable is set to any value, including an empty string or "0".

    Raises:
        Exception: if the dummy forward pass raises ``RuntimeError``, which
            includes torch out-of-memory errors. The original error is kept
            as ``__cause__``.
    """
    # Only the presence of the variable matters, not its value.
    if os.getenv("DISABLE_CHECK_MAX_LEN_INFER", None) is not None:
        return

    try:
        # Allocate directly on the GPU. This avoids a potentially large host
        # buffer followed by a host-to-device copy.
        dummy_images = torch.randn(
            (self.MAX_PATH_NUM * self.max_batch_size, 3, self.IMAGE_H, self.IMAGE_W),
            dtype=self.data_type,
            device="cuda",
        )
        all_img_embeds = self.forward(dummy_images)
        del all_img_embeds, dummy_images
    except RuntimeError as e:
        # torch.OutOfMemoryError is a RuntimeError subclass, so catching
        # RuntimeError covers OOM without naming an attribute that older
        # torch releases do not define.
        logger.exception(str(e))
        exception_str = (
            "Vit check max len infer fail, you can try: "
            "1.Set the --visual_infer_batch_size to a smaller value."
        )
        logger.error(exception_str)
        raise Exception(exception_str) from e
    else:
        logger.info(f"vit check max_len {self.max_batch_size} infer ok")
| 82 | |
| 83 | def _init_config(self): |
| 84 | with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: |
| 85 | self.config = json.load(json_file) |
no outgoing calls