Generate a response from the Magma model given an image and a prompt
(image, user_prompt)
| 61 | |
| 62 | |
def generate_response(image, user_prompt):
    """Run the Magma model on one image/prompt pair.

    Builds a chat prompt from the module-level ``system_message`` and the
    user's text, samples a completion from ``model``, and interprets the
    output both as action tokens and as plain text.

    Args:
        image: The image handed to ``processor`` alongside the prompt.
        user_prompt: The user's text; it is placed after an ``<image>``
            placeholder in the user turn.

    Returns:
        A 3-tuple ``(normalized_actions, delta_values, response)`` where the
        first two entries are lists (via ``.tolist()``) and ``response`` is
        the decoded, stripped text of the newly generated tokens.
    """
    messages = [
        system_message,
        {"role": "user", "content": f"<image>\n{user_prompt}"},
    ]
    chat_prompt = processor.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Some model configs expect the image placeholder wrapped in explicit
    # start/end markers.
    if model.config.mm_use_image_start_end:
        chat_prompt = chat_prompt.replace(
            "<image>", "<image_start><image><image_end>"
        )

    model_inputs = processor(images=[image], texts=chat_prompt, return_tensors="pt")
    # Prepend one extra leading dimension to the image tensors
    # (presumably the layout the model expects -- confirm against the model).
    for key in ("pixel_values", "image_sizes"):
        model_inputs[key] = model_inputs[key].unsqueeze(0)
    model_inputs = model_inputs.to("cuda")
    model_inputs = model_inputs.to(dtype)

    with torch.inference_mode():
        output_ids = model.generate(
            **model_inputs,
            max_new_tokens=500,
            temperature=0.7,  # Non-zero temperature for diverse responses
            do_sample=True,  # Sampling rather than greedy decoding
            num_beams=1,
            use_cache=True,
        )

    # Action tokens: the seven ids that sit just before the final token of
    # the first sequence (the original note describes them as 6 DOF +
    # gripper; the skipped last token is presumably the end marker -- confirm).
    action_token_ids = output_ids[0, -8:-1].cpu().tolist()

    # Turn token ids into bin indices: count back from the vocabulary size,
    # shift by one, and clamp into the valid range of ``bin_centers``.
    token_array = np.array(action_token_ids).astype(np.int64)
    bin_indices = np.clip(
        processor.tokenizer.vocab_size - token_array - 1,
        a_min=0,
        a_max=bin_centers.shape[0] - 1,
    )
    normalized_actions = bin_centers[bin_indices]

    # Map the normalized actions to actual delta values.
    delta_values = denormalize_actions(normalized_actions)

    # Text response: decode only what was generated after the prompt.
    prompt_length = model_inputs["input_ids"].shape[-1]
    new_token_ids = output_ids[0, prompt_length:]
    response = processor.decode(new_token_ids, skip_special_tokens=True).strip()

    return normalized_actions.tolist(), delta_values.tolist(), response
| 111 | |
| 112 | |
| 113 | @app.on_event("startup") |
no test coverage detected