Pulls down and caches models from the public llmware S3 repo.
(self, access_key=None,secret_key=None)
| 4204 | return downloaded_files |
| 4205 | |
def create_local_model_repo(self, access_key=None, secret_key=None):

    """Pulls down and caches models from the public llmware S3 repo.

    Every object in the bucket named by the ``llmware_public_models_bucket``
    config entry whose key has exactly the form ``<model_name>/<file_name>``
    is downloaded to ``<model_repo_path>/<model_name>/<file_name>``.  Keys of
    any other shape (top-level objects, deeper nesting, or keys with an empty
    segment such as folder markers ending in "/") are skipped.

    Args:
        access_key: optional AWS access key id, passed through to boto3
            (None lets boto3 use its own credential lookup).
        secret_key: optional AWS secret access key, passed through to boto3.

    Returns:
        list of model names (first key segment) for which at least one file
        was downloaded, in first-seen order, without duplicates.
    """

    # list of models retrieved from cloud repo (unique, first-seen order)
    models_retrieved = []

    # check for llmware path & create if not already set up
    if not os.path.exists(LLMWareConfig.get_llmware_path()):
        # if not explicitly set up by user, then create folder directory structure
        LLMWareConfig.setup_llmware_workspace()

    # confirm that local model repo path has been created
    local_model_repo_path = LLMWareConfig.get_model_repo_path()
    os.makedirs(local_model_repo_path, exist_ok=True)

    aws_repo_bucket = LLMWareConfig.get_config("llmware_public_models_bucket")

    bucket = boto3.resource('s3', aws_access_key_id=access_key,
                            aws_secret_access_key=secret_key).Bucket(aws_repo_bucket)

    for obj in bucket.objects.all():

        # S3 keys always use "/" as the delimiter, independent of the local OS --
        # splitting on os.sep would never match on Windows, where it is "\\"
        name_parts = obj.key.split("/")

        # only keys structured as [0] model name and [1] model component are retrieved
        if len(name_parts) != 2:
            continue

        logger.info(f"update: identified models in model_repo: {name_parts}")

        model_name, component = name_parts
        if not (model_name and component):
            continue

        model_folder = os.path.join(local_model_repo_path, model_name)
        os.makedirs(model_folder, exist_ok=True)

        # count every model whose files are downloaded -- not only newly created
        # folders -- so refreshing an existing repo still reports what was retrieved
        if model_name not in models_retrieved:
            models_retrieved.append(model_name)

        logger.info(f"update: downloading file from s3 bucket - "
                    f"{component} - {obj.key}")

        bucket.download_file(obj.key, os.path.join(model_folder, component))

    logger.info(f"update: created local model repository - {len(models_retrieved)} models retrieved - "
                f" model list - {models_retrieved}")

    return models_retrieved
| 4263 |
nothing calls this directly
no test coverage detected