MCPcopy Index your code
hub / github.com/NVIDIA/Stable-Diffusion-WebUI-TensorRT / TrtUnet

Class TrtUnet

scripts/trt.py:32–104  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

30
31
class TrtUnet(sd_unet.SdUnet):
    """UNet implementation that delegates inference to ``self.engine``.

    The instance stores the model name, the list of available engine
    configurations, the index of the currently selected profile and the set
    of weight keys that were last refitted through :meth:`apply_loras`.

    NOTE(review): ``self.engine`` is initialised to ``None`` and nothing in
    the visible part of the class assigns it; ``self.activate()`` is called
    by :meth:`switch_engine` but defined elsewhere. Presumably activation
    creates/loads the engine and sets ``engine_vram_req`` -- confirm.
    """

    def __init__(self, model_name: str, configs: List[dict], *args, **kwargs):
        """Store the model name and engine configurations.

        Args:
            model_name: Name of the model this UNet belongs to.
            configs: Engine configuration dicts; :meth:`switch_engine` reads
                the ``"filepath"`` key of the selected entry.
            *args: Forwarded unchanged to ``sd_unet.SdUnet.__init__``.
            **kwargs: Forwarded unchanged to ``sd_unet.SdUnet.__init__``.
        """
        super().__init__(*args, **kwargs)

        # ``stream`` is kept for backward compatibility; ``forward`` stores
        # the stream handle it uses in ``cudaStream``.
        self.stream = None
        self.cudaStream = None
        self.model_name = model_name
        self.configs = configs

        # Index into ``configs`` of the profile to load, and the config dict
        # that was actually loaded by the last ``switch_engine`` call.
        self.profile_idx = 0
        self.loaded_config = None

        # Size in bytes of the scratch buffer allocated on every forward.
        self.engine_vram_req = 0
        # Keys of the weights passed to the most recent refit.
        self.refitted_keys = set()

        self.engine = None

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        context: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """Run one inference step through the engine.

        Args:
            x: Passed to the engine as ``"sample"`` after ``.float()``.
            timesteps: Passed as ``"timesteps"`` after ``.float()``.
            context: Passed as ``"encoder_hidden_states"`` after ``.float()``.
            *args: Ignored.
            **kwargs: Only ``"y"`` is read; when present it is passed to the
                engine as ``"y"`` after ``.float()``.

        Returns:
            The ``"latent"`` entry of the mapping returned by
            ``self.engine.infer``.
        """
        nvtx.range_push("forward")
        try:
            feed_dict = {
                "sample": x.float(),
                "timesteps": timesteps.float(),
                "encoder_hidden_states": context.float(),
            }
            if "y" in kwargs:
                feed_dict["y"] = kwargs["y"].float()

            # Only the raw address is handed to the execution context, so
            # ``tmp`` must stay referenced until inference has been issued.
            tmp = torch.empty(
                self.engine_vram_req, dtype=torch.uint8, device=devices.device
            )
            self.engine.context.device_memory = tmp.data_ptr()
            self.cudaStream = torch.cuda.current_stream().cuda_stream
            self.engine.allocate_buffers(feed_dict)

            return self.engine.infer(feed_dict, self.cudaStream)["latent"]
        finally:
            # Always close the range opened above, even when allocation or
            # inference raises, so push/pop calls stay balanced.
            nvtx.range_pop()

    def apply_loras(self, refit_dict: dict):
        """Refit the engine with the weights in ``refit_dict``.

        Args:
            refit_dict: Mapping whose keys identify the weights to refit; it
                is passed unchanged to ``self.engine.refit_from_dict``.
        """
        if not self.refitted_keys.issubset(refit_dict):
            # Need to ensure that weights that have been modified before and
            # are not present anymore are reset.
            self.refitted_keys = set()
            self.switch_engine()

        self.engine.refit_from_dict(refit_dict, is_fp16=True)
        self.refitted_keys = set(refit_dict)

    def switch_engine(self):
        """Reset the engine to the file of the selected profile and activate.

        Selects ``self.configs[self.profile_idx]``, resets the engine with the
        path built from ``TRT_MODEL_DIR`` and the config's ``"filepath"``,
        then calls ``self.activate()`` (defined outside the visible source).
        """
        self.loaded_config = self.configs[self.profile_idx]
        self.engine.reset(os.path.join(TRT_MODEL_DIR, self.loaded_config["filepath"]))
        self.activate()

Callers 1

create_unetMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected