Adds a seed to the seed manager for `parallel_mode`. Args: parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. seed (int): The seed to be added. overwrite (bool, optional): Whether to allow overwriting a seed that has already been set.
(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False)
| 60 | torch.cuda.set_rng_state(self._seed_states[parallel_mode]) |
| 61 | |
| 62 | def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False): |
| 63 | """Adds a seed to the seed manager for `parallel_mode`, restoring the caller's current CUDA RNG state afterwards. |
| 64 | |
| 65 | Args: |
| 66 | parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode. |
| 67 | seed (int): The seed to be added. |
| 68 | overwrite (bool, optional): Whether to allow overwriting a seed that has already been set for `parallel_mode`; a warning is printed when an existing seed is overwritten. Defaults to False. |
| 69 | |
| 70 | Raises: |
| 71 | AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode` |
| 72 | or if `overwrite` is False and a seed for `parallel_mode` has already been added. |
| 73 | """ |
| 74 | assert isinstance(parallel_mode, ParallelMode), "A valid ParallelMode must be provided" |
| 75 | if overwrite is False: |
| 76 | assert parallel_mode not in self._seed_states, f"The seed for {parallel_mode} has been added" |
| 77 | elif parallel_mode in self._seed_states: |
| 78 | print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True) |
| 79 | |
| 80 | current_state = torch.cuda.get_rng_state() |
| 81 | torch.cuda.manual_seed(seed) |
| 82 | self._seed_states[parallel_mode] = torch.cuda.get_rng_state() |
| 83 | self._seeds[parallel_mode] = seed |
| 84 | torch.cuda.set_rng_state(current_state) |
| 85 | |
| 86 | def reset(self): |
| 87 | self._current_mode = None |
no test coverage detected