| 30 | |
| 31 | |
| 32 | class TrtUnet(sd_unet.SdUnet): |
def __init__(self, model_name: str, configs: List[dict], *args, **kwargs):
    """Record the model identity and reset all engine-related state.

    Args:
        model_name: Name stored on the instance as ``model_name``.
        configs: Engine configuration dicts; ``profile_idx`` indexes
            into this list.
        *args: Forwarded unchanged to the parent constructor.
        **kwargs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(*args, **kwargs)

    # Identity of the model and the configurations available for it.
    self.model_name = model_name
    self.configs = configs

    # Which entry of ``configs`` is selected, and the entry actually loaded.
    self.profile_idx = 0
    self.loaded_config = None

    # Nothing is loaded yet: no engine, no stream, no scratch-memory size.
    self.engine = None
    self.stream = None
    self.engine_vram_req = 0

    # Keys of the weights most recently passed to a refit.
    self.refitted_keys = set()
| 47 | |
def forward(
    self,
    x: torch.Tensor,
    timesteps: torch.Tensor,
    context: torch.Tensor,
    *args,
    **kwargs,
) -> torch.Tensor:
    """Run one inference pass through the loaded engine.

    The inputs are cast to float32 and handed to the engine under the
    names ``sample``, ``timesteps`` and ``encoder_hidden_states``. When a
    ``y`` keyword argument is supplied it is cast and forwarded as ``y``.

    Args:
        x: Tensor fed to the engine as ``sample``.
        timesteps: Tensor fed to the engine as ``timesteps``.
        context: Tensor fed to the engine as ``encoder_hidden_states``.
        *args: Accepted for interface compatibility; unused.
        **kwargs: Only the optional ``y`` entry is read.

    Returns:
        The ``latent`` entry of the engine's inference result.
    """
    nvtx.range_push("forward")
    try:
        feed_dict = {
            "sample": x.float(),
            "timesteps": timesteps.float(),
            "encoder_hidden_states": context.float(),
        }
        if "y" in kwargs:
            feed_dict["y"] = kwargs["y"].float()

        # Scratch buffer of ``engine_vram_req`` bytes; its address is handed
        # to the engine's execution context as device memory for this call.
        # NOTE(review): ``tmp`` is only referenced for the duration of this
        # method -- presumably ``infer`` has finished with the buffer (or the
        # allocator is stream-ordered) by the time it is released; confirm.
        tmp = torch.empty(
            self.engine_vram_req, dtype=torch.uint8, device=devices.device
        )
        self.engine.context.device_memory = tmp.data_ptr()
        self.cudaStream = torch.cuda.current_stream().cuda_stream
        self.engine.allocate_buffers(feed_dict)

        return self.engine.infer(feed_dict, self.cudaStream)["latent"]
    finally:
        # Always close the NVTX range, even when allocation or inference
        # raises, so the push/pop pairs stay balanced.
        nvtx.range_pop()
| 76 | |
def apply_loras(self, refit_dict: dict):
    """Refit the engine with the weights in ``refit_dict``.

    If any key refitted by the previous call is missing from
    ``refit_dict``, ``switch_engine()`` is called first so that weights
    modified earlier and no longer requested are reset.

    Args:
        refit_dict: Mapping passed to ``engine.refit_from_dict`` with
            ``is_fp16=True``; its keys are remembered for the next call.
    """
    requested_keys = set(refit_dict)

    # Keys that were refitted last time but are not requested any more.
    stale_keys = self.refitted_keys - requested_keys
    if stale_keys:
        self.refitted_keys = set()
        self.switch_engine()

    self.engine.refit_from_dict(refit_dict, is_fp16=True)
    # Recomputed after the refit, in case the callee altered the mapping.
    self.refitted_keys = set(refit_dict)
| 85 | |
def switch_engine(self):
    """Reset the engine from the file named by the selected configuration.

    The entry of ``configs`` at ``profile_idx`` becomes ``loaded_config``;
    the engine is reset with that entry's ``filepath`` joined onto
    ``TRT_MODEL_DIR``, and ``activate()`` is then called.
    """
    selected_config = self.configs[self.profile_idx]
    self.loaded_config = selected_config

    engine_path = os.path.join(TRT_MODEL_DIR, selected_config["filepath"])
    self.engine.reset(engine_path)
    self.activate()