| 53 | |
| 54 | |
| 55 | class LLaVACaptioner(Captioner): |
| 56 | |
| 57 | def __init__( |
| 58 | self, device: torch.device, llava_bit: Literal["16", "8", "4"] |
| 59 | ) -> "LLaVACaptioner": |
| 60 | super().__init__(device) |
| 61 | if llava_bit == "16": |
| 62 | load_4bit, load_8bit = False, False |
| 63 | elif llava_bit == "8": |
| 64 | load_4bit, load_8bit = False, True |
| 65 | else: |
| 66 | load_4bit, load_8bit = True, False |
| 67 | |
| 68 | model_path = "liuhaotian/llava-v1.5-7b" |
| 69 | model_name = get_model_name_from_path(model_path) |
| 70 | device_map = {"": device} |
| 71 | self.tokenizer, self.model, self.image_processor, context_len = ( |
| 72 | load_pretrained_model( |
| 73 | model_path, |
| 74 | None, |
| 75 | model_name, |
| 76 | device=device, |
| 77 | device_map=device_map, |
| 78 | load_4bit=load_4bit, |
| 79 | load_8bit=load_8bit, |
| 80 | ) |
| 81 | ) |
| 82 | self.model.eval() |
| 83 | |
| 84 | qs = "Please give me a very short description of this image." |
| 85 | image_token_se = ( |
| 86 | DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN |
| 87 | ) |
| 88 | if IMAGE_PLACEHOLDER in qs: |
| 89 | if self.model.config.mm_use_im_start_end: |
| 90 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) |
| 91 | else: |
| 92 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) |
| 93 | else: |
| 94 | if self.model.config.mm_use_im_start_end: |
| 95 | qs = image_token_se + "\n" + qs |
| 96 | else: |
| 97 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs |
| 98 | |
| 99 | if "llama-2" in model_name.lower(): |
| 100 | conv_mode = "llava_llama_2" |
| 101 | elif "mistral" in model_name.lower(): |
| 102 | conv_mode = "mistral_instruct" |
| 103 | elif "v1.6-34b" in model_name.lower(): |
| 104 | conv_mode = "chatml_direct" |
| 105 | elif "v1" in model_name.lower(): |
| 106 | conv_mode = "llava_v1" |
| 107 | elif "mpt" in model_name.lower(): |
| 108 | conv_mode = "mpt" |
| 109 | else: |
| 110 | conv_mode = "llava_v0" |
| 111 | |
| 112 | conv = conv_templates[conv_mode].copy() |
no outgoing calls
no test coverage detected