(model, tokenizer, generation_config, prefix=None, suffix=None, create_pr=None)
| 206 | |
| 207 | |
def push_to_hub(model, tokenizer, generation_config, prefix=None, suffix=None, create_pr=None):
    """Upload a model, its optional tokenizer/generation config, and a model card to the Hub.

    The target repo id is ``f"{ORGANIZATION}/{model_class_name}"``, where ``model_class_name`` is the
    model's class name, optionally preceded by ``f"{prefix}-"``, and the whole id is optionally followed
    by ``f"-{suffix}"``. If the repo does not exist, it is created and all artifacts are committed
    directly. If it already exists, the upload is skipped unless ``create_pr`` is truthy, in which case
    the artifacts are proposed in a single pull request.

    Args:
        model: Object exposing ``save_pretrained(directory)``. Its class name is used for the repo id,
            the model card, and the commit message.
        tokenizer: Optional object exposing ``save_pretrained(directory)``; skipped when None.
        generation_config: Optional object exposing ``save_pretrained(directory)``; skipped when None.
        prefix: Optional string prepended (with a ``-``) to the class name in the repo id.
        suffix: Optional string appended (with a ``-``) to the repo id.
        create_pr: Whether to open a PR when the repo already exists. When None, the value is read from
            the command line via ``_parse_args().create_pr``.

    Returns:
        None. Progress is reported with ``print``.
    """
    if create_pr is None:
        create_pr = _parse_args().create_pr

    model_class_name = model.__class__.__name__
    # The model card is rendered with the bare class name, before any prefix is applied.
    content = MODEL_CARD.format(model_class_name=model_class_name)
    model_card = ModelCard(content)
    if prefix is not None:
        model_class_name = f"{prefix}-{model_class_name}"
    repo_id = f"{ORGANIZATION}/{model_class_name}"
    if suffix is not None:
        repo_id += f"-{suffix}"

    exists = api.repo_exists(repo_id)
    if exists and not create_pr:
        print(f"Model {repo_id} already exists, skipping (pass --create-pr to open a PR)")
        return

    if not exists:
        api.create_repo(repo_id, exist_ok=True)

    # Save all artifacts to a temp dir and upload them in a single commit, so --create-pr opens one PR.
    with tempfile.TemporaryDirectory() as tmpdir:
        model.save_pretrained(tmpdir)
        if tokenizer is not None:
            tokenizer.save_pretrained(tmpdir)
        if generation_config is not None:
            generation_config.save_pretrained(tmpdir)
        model_card.save(os.path.join(tmpdir, "README.md"))

        operations = []
        for root, _, files in os.walk(tmpdir):
            for name in files:
                local_path = os.path.join(root, name)
                # Paths in a Hub repo always use "/"; os.path.relpath yields "\\"-separated paths on
                # Windows, which would corrupt the names of files saved in subdirectories.
                path_in_repo = os.path.relpath(local_path, tmpdir).replace(os.sep, "/")
                operations.append(CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=local_path))

        # The commit is created inside the `with`: the operations point at files under tmpdir,
        # which is deleted when the context exits.
        commit_info = api.create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=f"Upload {model.__class__.__name__}",
            # A freshly created repo gets a direct commit; an existing repo only reaches this point
            # when create_pr is truthy (otherwise we returned above), so it gets a PR.
            create_pr=bool(exists and create_pr),
        )
    if commit_info.pr_url:
        print(f"[push_to_hub] PR opened: {commit_info.pr_url}")
| 254 | |
| 255 | |
| 256 | def init_weights_tiny_model(model): |
no test coverage detected
searching dependent graphs…