hub / github.com/microsoft/Magma / forward

Method forward

magma/modeling_magma.py:1167–1361

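Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.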
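The docstring for `labels` says that positions set to `-100` are ignored when the loss is computed. The following is a minimal pure-Python sketch of that convention, not the Magma implementation: the function name `masked_cross_entropy` and the toy logits are assumptions made for illustration. In PyTorch the same effect normally comes from the `ignore_index=-100` default of `torch.nn.functional.cross_entropy`.

```python
import math

IGNORE_INDEX = -100  # label value that excludes a position from the loss


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean cross-entropy over positions whose label is not `ignore_index`.

    logits: list of per-position score lists (one score per vocabulary entry)
    labels: list of int class ids, or `ignore_index` to skip a position
    """
    total, counted = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # masked position: contributes nothing to the loss
        # Numerically stable log-sum-exp for the softmax normaliser.
        peak = max(scores)
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
        total += log_norm - scores[label]  # negative log-probability of the target
        counted += 1
    # Returning 0.0 when every position is masked is a choice of this sketch.
    return total / counted if counted else 0.0


# Toy example: three positions, vocabulary of size three.
logits = [[2.0, 0.5, 0.1], [0.1, 0.2, 3.0], [1.0, 1.0, 1.0]]
loss_all = masked_cross_entropy(logits, [0, 2, 1])
loss_masked = masked_cross_entropy(logits, [0, 2, IGNORE_INDEX])
```

Here `loss_masked` averages over the first two positions only, so it equals the loss obtained by dropping the third position entirely. This is how prompt or padding tokens are typically kept out of the training signal.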

(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    )

Source from the content-addressed store, hash-verified

    @add_start_docstrings_to_model_forward(MAGMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=MagmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MagmaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, MagmaForConditionalGeneration

        >>> model = MagmaForConditionalGeneration.from_pretrained("microsoft/magma-8b-hf")
        >>> processor = AutoProcessor.from_pretrained("microsoft/magma-8b-hf")

        >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_config['vision_feature_layer']
        )

        if inputs_embeds is None:
            # 1. Extract the input embeddings
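The first statements of the method body resolve each optional argument against the model config using an `x if x is not None else default` test. The sketch below isolates that pattern. The helper `resolve` and the `SimpleNamespace` config with its values are hypothetical stand-ins, not part of Magma; only the attribute names mirror the listing.

```python
from types import SimpleNamespace


def resolve(value, default):
    """Return `value` unless it is None, in which case fall back to `default`."""
    return value if value is not None else default


# Hypothetical config object; the values are illustrative only.
config = SimpleNamespace(
    output_attentions=True,
    output_hidden_states=False,
    use_return_dict=True,
    vision_config={"vision_feature_layer": -2},  # dict access, as in the listing
)

# An explicit False from the caller overrides a True config default.
output_attentions = resolve(False, config.output_attentions)

# An omitted argument (None) falls back to the config value.
vision_feature_layer = resolve(None, config.vision_config["vision_feature_layer"])

# A plain `or` would silently discard the caller's explicit False.
naive = False or config.output_attentions
```

The explicit `is not None` check matters because `False` and `0` are legitimate caller values. With `or`, they would be replaced by the config default.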

Callers

nothing calls this directly

Calls 4

get_input_embeddings · Method · 0.95
flatten · Method · 0.80

Tested by

no test coverage detected