| 49 | |
@dataclass
class ModelConfig:
    """Model configuration plus the input-shape profile it accepts.

    ``profile`` maps an input name to a ``(min, opt, max)`` triple of shapes;
    the two ``is_compatible*`` methods test a requested shape against those
    bounds and, when it fits, return a heuristic distance (0 means the request
    sits exactly on the ``opt`` shape; larger values mean further away).
    """

    # input name -> (min_shape, opt_shape, max_shape); each shape is a sequence
    # of ints indexed per dimension.
    profile: dict
    static_shapes: bool
    fp32: bool
    inpaint: bool
    refit: bool
    lora: bool
    vram: int
    # Not read by the methods of this class.
    unet_hidden_dim: int = 4

    def is_compatible_from_dict(self, feed_dict: dict):
        """Check every input in ``feed_dict`` against its shape profile.

        Args:
            feed_dict: input name -> object exposing ``.shape`` (e.g. a
                tensor). Every name must be a key of ``self.profile``.

        Returns:
            ``(True, distance)`` when every input's rank matches its profile
            and each dimension lies within ``[min, max]``; otherwise
            ``(False, partial_distance)``, where the distance is only
            meaningful on success. ``distance`` is a 0-d torch tensor (or the
            int ``0`` for an empty ``feed_dict``).

        Raises:
            KeyError: if an input name is missing from ``self.profile``.
        """
        distance = 0
        for name, value in feed_dict.items():
            _min, _opt, _max = self.profile[name]
            shape = list(value.shape)
            # Without this check torch would broadcast mismatched ranks
            # (e.g. a 1-D shape against a 2-D profile) and could report an
            # incompatible input as compatible, or raise a RuntimeError.
            if not (len(shape) == len(_min) == len(_opt) == len(_max)):
                return (False, distance)

            v_tensor = torch.Tensor(shape)
            headroom_to_max = torch.Tensor(_max) - v_tensor  # negative => above max
            opt_gap = (torch.Tensor(_opt) - v_tensor).abs()
            excess_over_min = v_tensor - torch.Tensor(_min)  # negative => below min
            if torch.any(headroom_to_max < 0) or torch.any(excess_over_min < 0):
                return (False, distance)

            # Heuristic: distance to opt dominates; slack to min counts half,
            # slack to max a quarter.
            distance += opt_gap.sum() + 0.5 * (
                excess_over_min.sum() + 0.5 * headroom_to_max.sum()
            )
        return (True, distance)

    def is_compatible(
        self, width: int, height: int, batch_size: int, max_embedding: int
    ):
        """Check a generation request against the ``sample`` and
        ``encoder_hidden_states`` profiles.

        The request is converted before comparison: ``batch_size`` is doubled
        and compared with dimension 0 of ``sample``; ``height // 8`` and
        ``width // 8`` are compared with dimensions 2 and 3 of ``sample``;
        ``max_embedding`` is compared with dimension 1 of
        ``encoder_hidden_states``. (Presumably the doubling reflects
        conditional + unconditional passes and the // 8 a latent downscale;
        confirm against the caller.)

        Returns:
            ``(True, distance)`` if every compared dimension lies within its
            ``[min, max]`` range, else ``(False, 0)``. ``distance`` sums the
            absolute gaps to the ``opt`` shape for batch/height/width plus
            half of the height/width gaps to the ``max`` shape;
            ``max_embedding`` does not contribute to it.

        Raises:
            KeyError: if either profile entry is missing.
        """
        distance = 0
        sample = self.profile["sample"]
        embedding = self.profile["encoder_hidden_states"]

        batch_size *= 2
        width = width // 8
        height = height // 8

        _min, _opt, _max = sample
        for dim, requested in ((0, batch_size), (2, height), (3, width)):
            if _min[dim] > requested or _max[dim] < requested:
                return (False, distance)

        _min_em, _, _max_em = embedding
        if _min_em[1] > max_embedding or _max_em[1] < max_embedding:
            return (False, distance)

        distance = (
            abs(_opt[0] - batch_size)
            + abs(_opt[2] - height)
            + abs(_opt[3] - width)
            + 0.5 * (abs(_max[2] - height) + abs(_max[3] - width))
        )

        return (True, distance)
| 106 | |
| 107 | class ModelConfigEncoder(JSONEncoder): |
no outgoing calls
no test coverage detected