Returns the fetch method from the model card; if the model card does not specify one, falls back to the default fetch configuration.
(self, model_card)
| 1216 | f"retrieve selected model from model repository.")) |
| 1217 | |
def fetch_resolve(self, model_card):

    """ Returns the fetch method from model card - if not found, then loads default.

    Lookup order:
        1. If the config flag ``apply_default_fetch_override`` is truthy, the default
           ``model_fetch`` config is used and the model card is ignored.
        2. Otherwise, the ``"fetch"`` entry of the model card supplies ``"module"``,
           ``"class"`` and ``"method"`` (each optional).
        3. If no module was found in the model card, the default ``model_fetch`` module
           is used; the default's ``"class"`` / ``"method"`` (when present) replace any
           values taken from the model card.

    Args:
        model_card: dict-like model card; may contain a ``"fetch"`` dict.

    Returns:
        Tuple ``(fetch_exec, fetch_method)``: ``fetch_exec`` is the resolved callable
        (a bound method on a newly constructed ``fetch_class`` instance when a class is
        configured, otherwise an attribute of the fetch module), or ``None`` if the
        class/method cannot be found or no method name is configured; ``fetch_method``
        is the method name that was looked up (may be ``None``).
    """

    config = LLMWareConfig()
    default_fetch = config.get_config("model_fetch")

    fetch_module = None
    fetch_class = None
    fetch_method = None

    if config.get_config("apply_default_fetch_override"):

        # if set to True, will over-ride the model card and use the default fetch mechanism
        fetch_module = default_fetch["module"]
        fetch_class = default_fetch.get("class")
        fetch_method = default_fetch.get("method")

    else:

        # primary (default) case - each model card provides configs for how to fetch the model
        card_fetch = model_card["fetch"] if "fetch" in model_card else {}
        fetch_module = card_fetch.get("module")
        fetch_method = card_fetch.get("method")
        fetch_class = card_fetch.get("class")

        if not fetch_module:

            # fallback case - model card names no module, so use the default fetch module;
            # default class/method (when present) replace any values from the model card
            fetch_module = default_fetch["module"]

            if "class" in default_fetch:
                fetch_class = default_fetch["class"]
            if "method" in default_fetch:
                fetch_method = default_fetch["method"]

    module = importlib.import_module(fetch_module)

    # hasattr/getattr raise TypeError for a None attribute name (e.g. a model card "fetch"
    # entry with a module but no method) - report a missing method name as "not found"
    if not fetch_method:
        return None, fetch_method

    fetch_exec = None

    if fetch_class:
        if hasattr(module, fetch_class):
            # the fetch class is instantiated with no arguments to obtain a bound method
            class_exec = getattr(module, fetch_class)()
            fetch_exec = getattr(class_exec, fetch_method, None)
    else:
        fetch_exec = getattr(module, fetch_method, None)

    return fetch_exec, fetch_method
no test coverage detected