(
response: requests.Response,
processor: "VaeImageProcessor" | "VideoProcessor" | None = None,
output_type: Literal["mp4", "pil", "pt"] = "pil",
return_type: Literal["mp4", "pil", "pt"] = "pil",
partial_postprocess: bool = False,
)
| 92 | |
| 93 | |
def postprocess_decode(
    response: "requests.Response",
    processor: "VaeImageProcessor | VideoProcessor | None" = None,
    output_type: Literal["mp4", "pil", "pt"] = "pil",
    return_type: Literal["mp4", "pil", "pt"] = "pil",
    partial_postprocess: bool = False,
):
    """Convert a remote-decode HTTP response into the object the caller asked for.

    Annotations are quoted as forward references so that neither
    ``requests`` nor the processor classes are evaluated at definition time
    (``"A" | "B"`` on bare strings raises ``TypeError``).

    Args:
        response: Response from the remote decode endpoint. When a tensor is
            expected, its body holds the raw tensor bytes and its headers
            carry ``shape`` (a JSON list) and ``dtype`` (a key of
            ``DTYPE_MAP``).
        processor: Optional ``VaeImageProcessor`` / ``VideoProcessor`` used to
            post-process a raw tensor locally.
        output_type: What the server sent.
            ``"pt"`` is a raw tensor.
            ``"pil"`` is an encoded image when ``processor`` is None and a
            raw tensor otherwise.
            ``"mp4"`` is passed through as bytes.
        return_type: What to hand back: ``"pt"`` (tensor), ``"pil"`` (PIL
            image(s)) or ``"mp4"`` (raw bytes).
        partial_postprocess: Only consulted when ``output_type == "pt"``. If
            True, each element along the tensor's first dimension is turned
            into an image directly with ``Image.fromarray``. No permute or
            scaling is applied, and ``processor`` is bypassed.

    Returns:
        Depending on the combination:

        * ``output_type="mp4"``, ``return_type="mp4"``: ``response.content``.
        * ``output_type="pil"``, no processor, ``return_type="pil"``: one RGB
          ``PIL.Image`` whose ``format`` is set from ``detect_image_type``.
        * ``output_type="pil"`` with a processor: a list of PIL images for
          ``"pil"``, or the decoded tensor for ``"pt"``.
        * ``output_type="pt"`` with ``partial_postprocess``: a PIL image (a
          list when there is more than one) for ``"pil"``, or the tensor for
          ``"pt"``.
        * ``output_type="pt"`` otherwise: the tensor when ``processor`` is
          None or ``return_type="pt"``. Otherwise, the first item of the
          processor's ``"pil"`` output (a list of frames for
          ``VideoProcessor``).

    Raises:
        ValueError: If the argument combination has no conversion listed
            above.
    """
    # These payloads carry raw tensor bytes; rebuild the tensor from headers.
    if output_type == "pt" or (output_type == "pil" and processor is not None):
        raw = response.content
        headers = response.headers
        shape = json.loads(headers["shape"])
        torch_dtype = DTYPE_MAP[headers["dtype"]]
        # bytearray makes a writable copy. torch.frombuffer shares memory with
        # its buffer and warns when given a read-only one such as ``bytes``.
        output_tensor = torch.frombuffer(bytearray(raw), dtype=torch_dtype).reshape(shape)

    if output_type == "pt":
        if partial_postprocess:
            if return_type == "pil":
                # NOTE(review): assumes the server already produced arrays that
                # Image.fromarray accepts as-is (e.g. uint8 HWC) — confirm.
                images = [Image.fromarray(image.numpy()) for image in output_tensor]
                return images[0] if len(images) == 1 else images
            if return_type == "pt":
                return output_tensor
        elif processor is None or return_type == "pt":
            return output_tensor
        elif isinstance(processor, VideoProcessor):
            return cast(
                list[Image.Image],
                processor.postprocess_video(output_tensor, output_type="pil")[0],
            )
        else:
            return cast(
                Image.Image,
                processor.postprocess(output_tensor, output_type="pil")[0],
            )
    elif output_type == "pil" and processor is None:
        # The body is an encoded image file; decode it directly.
        if return_type == "pil":
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
            image.format = detect_image_type(response.content)
            return image
    elif output_type == "pil":
        # A processor was given, so the body was decoded as a raw tensor above.
        if return_type == "pil":
            # NOTE(review): presumably NCHW floats in [0, 1]. They are permuted
            # to NHWC and scaled to uint8 for PIL — confirm against the server.
            pixels = (output_tensor.permute(0, 2, 3, 1).float().numpy() * 255).round().astype("uint8")
            return [Image.fromarray(image) for image in pixels]
        if return_type == "pt":
            return output_tensor
    elif output_type == "mp4" and return_type == "mp4":
        return response.content

    # Previously this fell through to `return output` with `output` unbound.
    raise ValueError(
        f"Unsupported postprocess combination: output_type={output_type!r}, "
        f"return_type={return_type!r}, has_processor={processor is not None}, "
        f"partial_postprocess={partial_postprocess!r}"
    )
| 145 | |
| 146 | |
| 147 | def prepare_decode( |
no test coverage detected
searching dependent graphs…