(
self,
base_model: str,
width: int,
height: int,
batch_size: int,
max_embedding: int,
)
| 211 | return valid_models, distances, idx |
| 212 | |
def get_valid_models(
    self,
    base_model: str,
    width: int,
    height: int,
    batch_size: int,
    max_embedding: int,
):
    """Select the models under ``base_model`` that accept the requested shape.

    Every entry in ``self.available_models()[base_model]`` is asked, through
    ``entry["config"].is_compatible(width, height, batch_size, max_embedding)``,
    whether it can serve the request.  That call is unpacked as a
    ``(valid, distance)`` pair.  Entries whose ``valid`` is truthy are kept.

    Returns:
        A 3-tuple of parallel lists ``(valid_models, distances, idx)``:
        the accepted entries, the distance each one reported, and each
        entry's original position in the ``base_model`` list.

    Raises:
        KeyError: if ``base_model`` is not a key of ``self.available_models()``.
    """
    candidates = self.available_models()[base_model]

    valid_models, distances, idx = [], [], []
    for position, entry in enumerate(candidates):
        accepted, dist = entry["config"].is_compatible(
            width, height, batch_size, max_embedding
        )
        # Skip incompatible entries; keep the three lists index-aligned.
        if not accepted:
            continue
        valid_models.append(entry)
        distances.append(dist)
        idx.append(position)

    return valid_models, distances, idx
| 235 | |
| 236 | |
# Module-level ModelManager instance, constructed when this module is executed.
# NOTE(review): presumably imported by other modules as a shared instance — confirm.
modelmanager = ModelManager()
no test coverage detected