MCP · copy
hub / github.com/XPixelGroup/DiffBIR / __init__

Method __init__

diffbir/utils/caption.py:57–120  ·  view source on GitHub ↗
(
        self, device: torch.device, llava_bit: Literal["16", "8", "4"]
    )

Source from the content-addressed store, hash-verified

55class LLaVACaptioner(Captioner):
56
57 def __init__(
58 self, device: torch.device, llava_bit: Literal["16", "8", "4"]
59 ) -> "LLaVACaptioner":
60 super().__init__(device)
61 if llava_bit == "16":
62 load_4bit, load_8bit = False, False
63 elif llava_bit == "8":
64 load_4bit, load_8bit = False, True
65 else:
66 load_4bit, load_8bit = True, False
67
68 model_path = "liuhaotian/llava-v1.5-7b"
69 model_name = get_model_name_from_path(model_path)
70 device_map = {"": device}
71 self.tokenizer, self.model, self.image_processor, context_len = (
72 load_pretrained_model(
73 model_path,
74 None,
75 model_name,
76 device=device,
77 device_map=device_map,
78 load_4bit=load_4bit,
79 load_8bit=load_8bit,
80 )
81 )
82 self.model.eval()
83
84 qs = "Please give me a very short description of this image."
85 image_token_se = (
86 DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
87 )
88 if IMAGE_PLACEHOLDER in qs:
89 if self.model.config.mm_use_im_start_end:
90 qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
91 else:
92 qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
93 else:
94 if self.model.config.mm_use_im_start_end:
95 qs = image_token_se + "\n" + qs
96 else:
97 qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
98
99 if "llama-2" in model_name.lower():
100 conv_mode = "llava_llama_2"
101 elif "mistral" in model_name.lower():
102 conv_mode = "mistral_instruct"
103 elif "v1.6-34b" in model_name.lower():
104 conv_mode = "chatml_direct"
105 elif "v1" in model_name.lower():
106 conv_mode = "llava_v1"
107 elif "mpt" in model_name.lower():
108 conv_mode = "mpt"
109 else:
110 conv_mode = "llava_v0"
111
112 conv = conv_templates[conv_mode].copy()
113 conv.append_message(conv.roles[0], qs)
114 conv.append_message(conv.roles[1], None)

Callers

nothing calls this directly

Calls 6

get_model_name_from_path — Function · 0.90
load_pretrained_model — Function · 0.90
copy — Method · 0.80
append_message — Method · 0.80
get_prompt — Method · 0.80
__init__ — Method · 0.45

Tested by

no test coverage detected