( # noqa: C901
repo,
git_remote: str,
exp_names: Optional[Union[Iterable[str], str]] = None,
all_commits=False,
rev: Optional[Union[list[str], str]] = None,
num=1,
force: bool = False,
pull_cache: bool = False,
**kwargs,
)
@locked
@scm_context
def pull(
    repo,
    git_remote: str,
    exp_names: Optional[Union[Iterable[str], str]] = None,
    all_commits=False,
    rev: Optional[Union[list[str], str]] = None,
    num=1,
    force: bool = False,
    pull_cache: bool = False,
    **kwargs,
) -> Iterable[str]:
    """Pull experiment refs from a Git remote into the local repository.

    Exactly one selection mode is used, checked in this order:

    * ``all_commits`` -- every experiment ref returned by ``exp_refs`` for
      ``git_remote``;
    * ``exp_names`` -- the given experiment names, resolved with
      ``resolve_name`` against ``git_remote``;
    * otherwise -- experiments whose baseline is one of the commits yielded by
      ``iter_revs`` for ``rev`` (default ``"HEAD"``) and ``num``.

    Args:
        repo: Repository object; ``repo.scm`` is used for all Git lookups.
        git_remote: Git remote to pull experiments from.
        exp_names: A single experiment name or an iterable of names.
        all_commits: Pull every experiment ref found on the remote.
        rev: A single revision or a list of revisions used as baselines.
        num: Forwarded to ``iter_revs`` together with ``rev``.
        force: Forwarded to ``_pull``; the user-facing warning below tells the
            user to re-run with ``--force`` to override diverged experiments.
        pull_cache: If true, also call ``_pull_cache`` for refs that ``_pull``
            reported as successfully pulled or already up to date.
        **kwargs: Forwarded unchanged to ``_pull_cache``.

    Returns:
        Names of the refs that ``_pull`` reported with ``SyncStatus.SUCCESS``.

    Raises:
        UnresolvedExpNamesError: If any name in ``exp_names`` could not be
            resolved.
    """
    exp_ref_set = _collect_pull_refs(
        repo, git_remote, exp_names, all_commits, rev, num
    )

    pull_result = _pull(repo, git_remote, exp_ref_set, force)

    diverged_names = [ref.name for ref in pull_result[SyncStatus.DIVERGED]]
    if diverged_names:
        ui.warn(_diverged_warning(diverged_names))

    if pull_cache:
        # Fetch cached outputs for everything now present locally, whether it
        # was just pulled or was already in sync.
        pull_cache_ref = (
            pull_result[SyncStatus.UP_TO_DATE] + pull_result[SyncStatus.SUCCESS]
        )
        _pull_cache(repo, pull_cache_ref, **kwargs)

    return [ref.name for ref in pull_result[SyncStatus.SUCCESS]]


def _collect_pull_refs(repo, git_remote, exp_names, all_commits, rev, num):
    """Return the set of experiment refs that ``pull`` should fetch.

    Implements the selection-mode precedence documented on ``pull``: all
    remote refs, then explicit names, then experiments by baseline revision.
    Raises ``UnresolvedExpNamesError`` listing every name that
    ``resolve_name`` mapped to ``None``.
    """
    exp_ref_set: set[ExpRefInfo] = set()

    if all_commits:
        exp_ref_set.update(exp_refs(repo.scm, git_remote))
    elif exp_names:
        if isinstance(exp_names, str):
            exp_names = [exp_names]
        exp_ref_dict = resolve_name(repo.scm, exp_names, git_remote)

        # Report every unresolved name at once rather than failing on the
        # first one.
        unresolved_exp_names = [
            name for name, ref in exp_ref_dict.items() if ref is None
        ]
        if unresolved_exp_names:
            raise UnresolvedExpNamesError(unresolved_exp_names)

        exp_ref_set.update(exp_ref_dict.values())
    else:
        rev = rev or "HEAD"
        if isinstance(rev, str):
            rev = [rev]
        rev_dict = iter_revs(repo.scm, rev, num)
        ref_info_dict = exp_refs_by_baseline(repo.scm, set(rev_dict), git_remote)
        for ref_info_list in ref_info_dict.values():
            exp_ref_set.update(ref_info_list)

    return exp_ref_set


def _diverged_warning(names: list[str]) -> str:
    """Build the user warning for local experiments that diverged from remote.

    Each name is quoted individually. The sentence is singular or plural to
    match ``len(names)``, so a list repr never leaks into the message.
    """
    if len(names) == 1:
        subject = (
            f"Local experiment '{names[0]}' has diverged from remote "
            "experiment with the same name."
        )
    else:
        quoted = ", ".join(f"'{name}'" for name in names)
        subject = (
            f"Local experiments {quoted} have diverged from remote "
            "experiments with the same names."
        )
    return f"{subject} To override the local experiment re-run with '--force'."
no test coverage detected