MCPcopy
hub / github.com/Picsart-AI-Research/Text2Video-Zero / Model

Class Model

model.py:26–495  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

24
25
26class Model:
def __init__(self, device, dtype, **kwargs):
    """Create an empty model holder; no pipeline is loaded yet.

    Args:
        device: Device used for the random generator and, later, for the
            pipeline loaded by ``set_model``.
        dtype: Dtype the loaded pipeline is cast to in ``set_model``.
        **kwargs: Accepted for call-site compatibility; not used here.
    """
    self.device, self.dtype = device, dtype
    self.generator = torch.Generator(device=device)

    # Every ControlNet flavour is served by the same pipeline class.
    controlnet_types = (
        ModelType.ControlNetCanny,
        ModelType.ControlNetCannyDB,
        ModelType.ControlNetPose,
        ModelType.ControlNetDepth,
    )
    # Maps each supported model type to the pipeline class that runs it.
    self.pipe_dict = {
        ModelType.Pix2Pix_Video: StableDiffusionInstructPix2PixPipeline,
        ModelType.Text2Video: TextToVideoPipeline,
        **dict.fromkeys(controlnet_types, StableDiffusionControlNetPipeline),
    }

    # One cross-frame attention processor per pipeline family.
    # NOTE(review): unet_chunk_size presumably matches how many batches the
    # UNet input is split into for each pipeline -- confirm against utils.
    make_attn_proc = utils.CrossFrameAttnProcessor
    self.controlnet_attn_proc = make_attn_proc(unet_chunk_size=2)
    self.pix2pix_attn_proc = make_attn_proc(unet_chunk_size=3)
    self.text2video_attn_proc = make_attn_proc(unet_chunk_size=2)

    # Nothing is loaded until set_model() is called.
    self.pipe = None
    self.model_type = None
    self.model_name = ""

    self.states = {}
51
def set_model(self, model_type: ModelType, model_id: str, **kwargs):
    """Load the pipeline for ``model_type`` from ``model_id``, replacing any current one.

    Args:
        model_type: Key into ``self.pipe_dict`` selecting the pipeline class.
        model_id: Identifier passed to the pipeline's ``from_pretrained``;
            also stored as ``self.model_name``.
        **kwargs: Forwarded to ``from_pretrained``. ``safety_checker`` is
            taken from here and defaults to ``None`` when absent.

    Raises:
        KeyError: If ``model_type`` is not a key of ``self.pipe_dict``.
    """
    # Drop our reference to the previous pipeline (if any) before loading
    # the next one, so both are not resident at the same time.
    self.pipe = None

    # Collect garbage *first*: objects kept alive only by reference cycles
    # are freed here, and only then can empty_cache() hand their cached
    # CUDA blocks back to the driver. The reverse order leaves them cached.
    gc.collect()
    torch.cuda.empty_cache()

    safety_checker = kwargs.pop('safety_checker', None)
    pipeline_cls = self.pipe_dict[model_type]
    pipe = pipeline_cls.from_pretrained(
        model_id, safety_checker=safety_checker, **kwargs)
    self.pipe = pipe.to(self.device).to(self.dtype)

    self.model_type = model_type
    self.model_name = model_id
63
def inference_chunk(self, frame_ids, **kwargs):
    """Run the loaded pipeline on the frames selected by ``frame_ids``.

    Args:
        frame_ids: Indices of the frames in this chunk. Used to index the
            per-frame inputs below, and ``len(frame_ids)`` becomes
            ``video_length`` when that key is present.
        **kwargs: Pipeline arguments. ``prompt`` is required.
            ``negative_prompt`` (default ``''``), ``latents`` and ``image``
            are narrowed to the selected frames; everything else is
            forwarded unchanged.

    Returns:
        Whatever ``self.pipe`` returns, or ``None`` when no pipeline has
        been loaded.

    Raises:
        KeyError: If ``prompt`` is missing from ``kwargs``.
    """
    if getattr(self, "pipe", None) is None:
        return None

    def per_frame(values):
        # A bare string (such as the '' default) becomes a 0-d array that
        # cannot be indexed by frame_ids; repeat it for every frame instead.
        values = np.array(values)
        if values.ndim == 0:
            return [values.item()] * len(frame_ids)
        return values[frame_ids].tolist()

    prompt = per_frame(kwargs.pop('prompt'))
    negative_prompt = per_frame(kwargs.pop('negative_prompt', ''))

    # An explicit latents=None means "no latents", same as omitting the key.
    latents = kwargs.pop('latents', None)
    if latents is not None:
        latents = latents[frame_ids]

    if 'image' in kwargs:
        kwargs['image'] = kwargs['image'][frame_ids]
    if 'video_length' in kwargs:
        kwargs['video_length'] = len(frame_ids)
    if self.model_type == ModelType.Text2Video:
        # Only the text-to-video pipeline is handed the frame indices.
        kwargs["frame_ids"] = frame_ids

    return self.pipe(prompt=prompt,
                     negative_prompt=negative_prompt,
                     latents=latents,
                     generator=self.generator,
                     **kwargs)

Callers 1

app.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected