Pulls the selected model from the llmware public S3 bucket into the local model repository.
(self, model_name=None, bucket_name=None, save_path=None)
| 4372 | return True |
| 4373 | |
def fetch_model_from_bucket(self, model_name=None, bucket_name=None, save_path=None):

    """Pulls the selected model from the llmware public S3 bucket into the local model repo.

    Each model in the bucket is expected to live under its own top-level "folder"
    (key prefix ``<model_name>/``) as a flat list of files, e.g.
    ``<model_name>/pytorch_model.bin``, ``<model_name>/config.json``.  Those files are
    downloaded into ``<save_path>/<model_name>/``.

    Args:
        model_name: name of the model folder in the bucket (required).
        bucket_name: S3 bucket to pull from; defaults to the configured
            "llmware_public_models_bucket" value.
        save_path: local directory under which the model folder is created; defaults
            to ``LLMWareConfig.get_model_repo_path()``.

    Returns:
        The bucket's full (lazy) object collection, ``bucket.objects.all()`` - the same
        value previously returned, kept for backward compatibility.

    Raises:
        ValueError: if no model_name is provided.
    """

    # a trailing "/" is tolerated so "my-model" and "my-model/" resolve to the same folder
    model_folder_name = (model_name or "").rstrip("/")
    if not model_folder_name:
        raise ValueError("fetch_model_from_bucket: a model_name must be provided")

    # if no bucket name selected, then use the configured llmware public models bucket
    if not bucket_name:
        bucket_name = LLMWareConfig().get_config("llmware_public_models_bucket")

    # check for llmware path & create if not already set up
    if not os.path.exists(LLMWareConfig.get_llmware_path()):
        # if not explicitly set up by user, then create folder directory structure
        LLMWareConfig().setup_llmware_workspace()

    if not save_path:
        save_path = LLMWareConfig.get_model_repo_path()

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(save_path, exist_ok=True)

    # requests are unsigned, i.e. made without AWS credentials
    bucket = boto3.resource('s3', config=Config(signature_version=UNSIGNED)).Bucket(bucket_name)

    # match on "<model_name>/" rather than a bare prefix: a bare startswith() would let
    # e.g. "bling-1b" also pull "bling-1b-0.1/..." files into the same local folder.
    # Filtering server-side also avoids listing the entire bucket.
    model_prefix = model_folder_name + "/"
    local_model_folder = os.path.join(save_path, model_folder_name)

    downloaded = 0
    for file in bucket.objects.filter(Prefix=model_prefix):

        # skip "folder" placeholder keys
        if file.key.endswith('/'):
            continue

        # create the local model folder only once a model file has actually been found
        if downloaded == 0:
            os.makedirs(local_model_folder, exist_ok=True)

        # simple model_repo structure - each model is a 'flat list' of files,
        # so the last component of the key is used as the local file name
        local_file_path = os.path.join(local_model_folder, file.key.split('/')[-1])
        bucket.download_file(file.key, local_file_path)
        downloaded += 1

    if downloaded:
        logger.info(f"update: successfully downloaded model - {model_name} - "
                    f"from aws s3 bucket for future access")
    else:
        # previously a success message was logged even when nothing matched
        logger.warning(f"update: no files found for model - {model_name} - "
                       f"in aws s3 bucket - {bucket_name}")

    # kept for backward compatibility: callers previously received the full object collection
    return bucket.objects.all()
| 4419 | |
| 4420 | |
| 4421 | class ParserState: |
nothing calls this directly
no test coverage detected