| 24 | |
| 25 | |
| 26 | class Model: |
def __init__(self, device, dtype, **kwargs):
    """Create a wrapper with no pipeline loaded yet.

    Args:
        device: Stored on the instance and passed to ``torch.Generator``.
        dtype: Stored on the instance.
        **kwargs: Accepted but not used by this constructor.
    """
    self.device, self.dtype = device, dtype
    self.generator = torch.Generator(device=device)

    # Registry from model type to the pipeline class used for it.
    registry = {
        ModelType.Pix2Pix_Video: StableDiffusionInstructPix2PixPipeline,
        ModelType.Text2Video: TextToVideoPipeline,
    }
    # Every ControlNet variant is served by the same pipeline class.
    for controlnet_type in (
        ModelType.ControlNetCanny,
        ModelType.ControlNetCannyDB,
        ModelType.ControlNetPose,
        ModelType.ControlNetDepth,
    ):
        registry[controlnet_type] = StableDiffusionControlNetPipeline
    self.pipe_dict = registry

    # One separate cross-frame attention processor per pipeline family;
    # only the pix2pix one uses a chunk size of 3.
    make_attn_proc = utils.CrossFrameAttnProcessor
    self.controlnet_attn_proc = make_attn_proc(unet_chunk_size=2)
    self.pix2pix_attn_proc = make_attn_proc(unet_chunk_size=3)
    self.text2video_attn_proc = make_attn_proc(unet_chunk_size=2)

    # Nothing is loaded until a model is explicitly selected.
    self.pipe = None
    self.model_type = None
    self.model_name = ""

    self.states = {}
| 51 | |
def set_model(self, model_type: ModelType, model_id: str, **kwargs):
    """Replace the active pipeline with the one registered for ``model_type``.

    The reference to any previously loaded pipeline is dropped first. The
    class found in ``self.pipe_dict`` is then loaded with
    ``from_pretrained`` and moved to ``self.device`` and ``self.dtype``.

    Args:
        model_type: Key into ``self.pipe_dict`` selecting the pipeline class.
        model_id: Identifier forwarded to ``from_pretrained``; also stored
            as ``self.model_name``.
        **kwargs: Forwarded to ``from_pretrained``. ``safety_checker`` is
            passed as ``None`` unless the caller supplies one.

    Raises:
        KeyError: If ``model_type`` has no entry in ``self.pipe_dict``.
    """
    # Rebinding to None drops this object's reference to the old pipeline
    # and guarantees the attribute exists even if __init__ was bypassed.
    self.pipe = None

    # Collect garbage BEFORE emptying the CUDA cache: tensors that were only
    # reachable through reference cycles must be freed first, otherwise
    # their memory is not yet "unused" when the cache is released.
    # NOTE(review): whether these two calls originally sat inside the
    # "pipe is not None" check is unclear; running them unconditionally is
    # harmless in either case.
    gc.collect()
    torch.cuda.empty_cache()

    safety_checker = kwargs.pop('safety_checker', None)
    pipeline_cls = self.pipe_dict[model_type]
    loaded = pipeline_cls.from_pretrained(
        model_id, safety_checker=safety_checker, **kwargs)
    self.pipe = loaded.to(self.device).to(self.dtype)

    self.model_type = model_type
    self.model_name = model_id
| 63 | |
def inference_chunk(self, frame_ids, **kwargs):
    """Run the loaded pipeline on the frames selected by ``frame_ids``.

    Args:
        frame_ids: Indices of the frames in this chunk. Used to index
            ``prompt``, ``negative_prompt``, ``latents`` and ``image``, and
            its length replaces ``video_length`` when that key is present.
        **kwargs: Must contain ``prompt``. ``negative_prompt`` defaults to
            ``''``. ``latents``, ``image`` and ``video_length`` are adjusted
            to the chunk when present; everything else is forwarded to the
            pipeline unchanged.

    Returns:
        Whatever the pipeline call returns, or ``None`` when no pipeline is
        loaded.

    Raises:
        KeyError: If ``prompt`` is not supplied.
    """
    if getattr(self, "pipe", None) is None:
        return None

    def per_frame(values):
        # A single string (including the '' default) becomes a 0-d array,
        # which cannot be indexed by frame_ids; repeat it for every frame
        # in the chunk instead of raising IndexError.
        arr = np.array(values)
        if arr.ndim == 0:
            return [arr.item()] * len(frame_ids)
        return arr[frame_ids].tolist()

    prompt = per_frame(kwargs.pop('prompt'))
    negative_prompt = per_frame(kwargs.pop('negative_prompt', ''))

    latents = kwargs.pop('latents')[frame_ids] if 'latents' in kwargs else None
    if 'image' in kwargs:
        kwargs['image'] = kwargs['image'][frame_ids]
    if 'video_length' in kwargs:
        # The pipeline only sees this chunk, so report the chunk's length.
        kwargs['video_length'] = len(frame_ids)
    if self.model_type == ModelType.Text2Video:
        # Only the text-to-video pipeline is given the frame indices.
        kwargs["frame_ids"] = frame_ids

    return self.pipe(prompt=prompt,
                     negative_prompt=negative_prompt,
                     latents=latents,
                     generator=self.generator,
                     **kwargs)