Repository: github.com/hpcaitech/ColossalAI

Method add_seed

colossalai/legacy/context/random/seed_manager.py, lines 62–84

Adds a seed to the seed manager for `parallel_mode`. The CUDA RNG state produced by that seed is stored for the mode, and the caller's current CUDA RNG state is restored afterwards. When `overwrite` is `False` (the default), registering a mode that already has a seed fails an assertion; otherwise the existing seed is replaced and a warning is printed.

(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False)
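The `overwrite` flag selects one of three outcomes. The framework-free sketch below isolates that guard. `register_seed`, the plain `dict` registry, and the string mode names are hypothetical stand-ins for `SeedManager.add_seed`, `self._seed_states`, and `ParallelMode`, and no RNG is involved. It mirrors the source's `overwrite is False` test, so a falsy non-bool such as `0` or `None` takes the overwrite branch rather than the assertion.

```python
def register_seed(registry: dict, mode: str, seed: int, overwrite: bool = False) -> None:
    """Hypothetical stand-in for the overwrite guard in SeedManager.add_seed."""
    if overwrite is False:
        # Only the literal False triggers the duplicate check.
        assert mode not in registry, f"The seed for {mode} has been added"
    elif mode in registry:
        # Any other value (True, 0, None, ...) replaces an existing entry with a warning.
        print(f"Warning: {mode} seed has been overwritten.", flush=True)
    registry[mode] = seed


if __name__ == "__main__":
    seeds = {}
    register_seed(seeds, "data", 1)                  # new mode: accepted silently
    try:
        register_seed(seeds, "data", 2)              # duplicate with overwrite=False
    except AssertionError as err:
        print(f"AssertionError: {err}")              # AssertionError: The seed for data has been added
    register_seed(seeds, "data", 3, overwrite=True)  # Warning: data seed has been overwritten.
    register_seed(seeds, "data", 4, overwrite=0)     # 0 is not False, so this also overwrites
    print(seeds)                                     # {'data': 4}
```

Because the duplicate check is an `assert`, running Python with `-O` strips it; a repeated registration with `overwrite=False` then silently replaces the stored seed without printing the warning.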

Source:

```python
def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
    """Adds a seed to the seed manager for `parallel_mode`.

    Args:
        parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
        seed (int): The seed to be added.
        overwrite (bool, optional): Whether allows to overwrite the seed that has been set already

    Raises:
        AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of :class:`colossalai.legacy.context.ParallelMode`
            or the seed for `parallel_mode` has been added.
    """
    assert isinstance(parallel_mode, ParallelMode), "A valid ParallelMode must be provided"
    if overwrite is False:
        assert parallel_mode not in self._seed_states, f"The seed for {parallel_mode} has been added"
    elif parallel_mode in self._seed_states:
        print(f"Warning: {parallel_mode} seed has been overwritten.", flush=True)

    current_state = torch.cuda.get_rng_state()
    torch.cuda.manual_seed(seed)
    self._seed_states[parallel_mode] = torch.cuda.get_rng_state()
    self._seeds[parallel_mode] = seed
    torch.cuda.set_rng_state(current_state)
```
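The last five lines of `add_seed` bracket the seeding step so that registering a seed leaves the caller's CUDA RNG stream untouched: save the current state, seed, snapshot the seeded state for the mode, then restore. The sketch below reproduces that pattern with Python's standard `random` module so it runs without a GPU or ColossalAI; the global `random` generator stands in for the CUDA generator. `ToySeedManager` and its `draw` helper are hypothetical. `draw` only approximates how a stored state is consumed later; the real manager loads a stored state with `torch.cuda.set_rng_state` when the parallel mode is switched.

```python
import random


class ToySeedManager:
    """Hypothetical stdlib analogue of SeedManager (string modes, no CUDA)."""

    def __init__(self) -> None:
        self._seeds: dict = {}
        self._seed_states: dict = {}

    def add_seed(self, mode: str, seed: int) -> None:
        # 1. Save the caller's generator state (cf. torch.cuda.get_rng_state()).
        current_state = random.getstate()
        # 2. Seed the generator for this mode (cf. torch.cuda.manual_seed(seed)).
        random.seed(seed)
        # 3. Snapshot the freshly seeded state and remember the seed.
        self._seed_states[mode] = random.getstate()
        self._seeds[mode] = seed
        # 4. Restore the caller's state (cf. torch.cuda.set_rng_state(current_state)).
        random.setstate(current_state)

    def draw(self, mode: str, n: int) -> list:
        """Hypothetical helper: draw n values from `mode`'s stream, then restore the caller's."""
        caller_state = random.getstate()
        random.setstate(self._seed_states[mode])
        values = [random.random() for _ in range(n)]
        self._seed_states[mode] = random.getstate()  # advance this mode's stream
        random.setstate(caller_state)
        return values


if __name__ == "__main__":
    mgr = ToySeedManager()
    random.seed(123)
    mgr.add_seed("data", 42)    # does not shift the caller's stream seeded with 123
    print(random.random())      # same value as the first draw after random.seed(123)
    print(mgr.draw("data", 2))  # same values as the first two draws after random.seed(42)
```

The save/restore bracket is what lets several parallel modes be registered at startup without changing the random numbers the caller is already consuming, while each mode still gets a reproducible stream of its own.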

Callers (1): add_seed (function)

Calls (3): get_rng_state, manual_seed, set_rng_state (the torch.cuda functions used in the source)

Tested by: no test coverage detected.