MCP · copy
hub / github.com/microsoft/Magma / generate_response

Function generate_response

server/main.py:63–110  ·  view source on GitHub ↗

Generate a response from the Magma model given an image and a prompt

(image, user_prompt)

Source from the content-addressed store, hash-verified

61
62
def generate_response(image, user_prompt):
    """Generate a response from the Magma model given an image and a prompt.

    The single generated sequence is read two ways:

    * the seven tokens ending just before the last token of the sequence are
      mapped back to action bins (the original author describes these as
      6 DOF + gripper), and
    * every token after the prompt length is decoded as the text reply.

    Args:
        image: Passed unchanged to ``processor(images=[image], ...)``; the
            accepted formats are whatever that processor supports.
        user_prompt: User text, placed after an ``<image>`` placeholder.

    Returns:
        ``(normalized_actions, delta_values, response)``. The first item is
        the ``bin_centers`` lookup as a list. The second is
        ``denormalize_actions`` of that lookup as a list. The third is the
        decoded, stripped text.
    """
    chat = [system_message, {"role": "user", "content": f"<image>\n{user_prompt}"}]
    prompt_text = processor.tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    # Some model configs expect the image placeholder wrapped in explicit markers.
    if model.config.mm_use_image_start_end:
        prompt_text = prompt_text.replace(
            "<image>", "<image_start><image><image_end>"
        )

    model_inputs = processor(images=[image], texts=prompt_text, return_tensors="pt")
    # Add a leading axis to both image tensors before moving to the GPU.
    for image_key in ("pixel_values", "image_sizes"):
        model_inputs[image_key] = model_inputs[image_key].unsqueeze(0)
    model_inputs = model_inputs.to("cuda").to(dtype)

    with torch.inference_mode():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=500,
            temperature=0.7,  # some temperature for diverse responses
            do_sample=True,  # sampling enabled
            num_beams=1,
            use_cache=True,
        )

    # Action tokens: the 7 positions ending one before the final token.
    action_token_ids = np.array(output_ids[0, -8:-1].cpu().tolist()).astype(np.int64)

    # Token id -> bin index (vocab_size - id - 1), clamped to the valid bin range.
    last_bin = bin_centers.shape[0] - 1
    bin_indices = np.clip(
        processor.tokenizer.vocab_size - action_token_ids - 1, a_min=0, a_max=last_bin
    )
    normalized = bin_centers[bin_indices]

    # Map normalized actions to delta values.
    deltas = denormalize_actions(normalized)

    # Text reply: everything generated after the prompt tokens.
    prompt_len = model_inputs["input_ids"].shape[-1]
    reply = processor.decode(
        output_ids[0, prompt_len:], skip_special_tokens=True
    ).strip()

    return normalized.tolist(), deltas.tolist(), reply
111
112
113@app.on_event("startup")

Callers (2)

predict (function) · 0.85
predict_from_file (function) · 0.85

Calls (2)

denormalize_actions (function) · 0.85
decode (method) · 0.80

Tested by

no test coverage detected