github.com/hkust-nlp/simpleRL-reason / create_random_mask

Function create_random_mask

verl/utils/model.py:153–191

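Create a random mask for input_ids, supporting both left and right padding. For each row it samples the number of valid tokens, samples the amount of left padding, and zeros out every position outside the resulting valid span. Args: input_ids, of shape (batch_size, seq_len). Returns: an int64 tensor of the same shape, with 1 for valid tokens and 0 for padding.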
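The three steps can be illustrated for a single row. This is a pure-Python sketch under stated assumptions: the helper name `random_mask_row` is hypothetical, and it uses the standard-library `random` module instead of NumPy, so its `randint` bounds are inclusive on both ends (unlike `np.random.randint`, whose `high` is exclusive).

```python
import random


def random_mask_row(seq_len, min_valid, max_valid, max_left, rng):
    """Hypothetical single-row version of the three-step process.

    Assumes 1 <= min_valid <= max_valid and max_valid + max_left <= seq_len,
    mirroring the asserts in create_random_mask.
    """
    # Step 1: sample the valid token length (random.randint includes both ends)
    num_valid = rng.randint(min_valid, max_valid)
    # Step 2: sample the left padding length
    num_left = rng.randint(0, max_left)
    # Step 3: generate the padding: zeros on the left, ones for the valid span,
    # zeros for whatever remains on the right
    num_right = seq_len - num_left - num_valid
    return [0] * num_left + [1] * num_valid + [0] * num_right


rng = random.Random(0)
print(random_mask_row(seq_len=8, min_valid=1, max_valid=4, max_left=2, rng=rng))
```

Because the two draws are independent, sampling the valid length before or after the left padding gives the same distribution; the repository code draws the left padding first.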

Signature: create_random_mask(input_ids: torch.Tensor, max_ratio_of_valid_token: float, max_ratio_of_left_padding: float, min_ratio_of_valid_token: float = 0)
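The ratio arguments are turned into integer bounds by truncating `sequence_length * ratio` with `int()`, and the minimum valid length is clamped to at least 1. The stdlib sketch below reproduces that arithmetic and the accompanying checks; `mask_bounds` is a hypothetical helper, and it raises `ValueError` where the original uses `assert`.

```python
def mask_bounds(sequence_length, max_ratio_of_valid_token,
                max_ratio_of_left_padding, min_ratio_of_valid_token=0):
    """Hypothetical helper: returns (min_valid, max_valid, max_left_padding)."""
    if not 0 < max_ratio_of_valid_token <= 1:
        raise ValueError("max_ratio_of_valid_token must be in (0, 1]")
    if not 0 <= max_ratio_of_left_padding < 1:
        raise ValueError("max_ratio_of_left_padding must be in [0, 1)")
    if min_ratio_of_valid_token > max_ratio_of_valid_token:
        raise ValueError("min_ratio_of_valid_token exceeds max_ratio_of_valid_token")

    # int() truncates toward zero, matching the original computation
    max_valid = int(sequence_length * max_ratio_of_valid_token)
    min_valid = max(1, int(sequence_length * min_ratio_of_valid_token))
    max_left = int(sequence_length * max_ratio_of_left_padding)

    # the longest valid span plus the longest left padding must fit in the row
    if max_valid + max_left > sequence_length:
        raise ValueError("valid tokens and left padding do not fit in the sequence")
    # truncation can leave no room for even one valid token
    if max_valid < 1:
        raise ValueError("max_ratio_of_valid_token is too small for this length")
    return min_valid, max_valid, max_left


# 8 positions, up to half valid, up to a quarter left padding
print(mask_bounds(8, 0.5, 0.25))  # (1, 4, 2)
```

With these bounds every row has between 1 and 4 valid tokens and at most 2 leading zeros. A combination such as `mask_bounds(4, 0.75, 0.5)` is rejected because 3 valid tokens plus 2 padding positions exceed 4.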


# module-level imports this function relies on
import numpy as np
import torch


def create_random_mask(input_ids: torch.Tensor,
                       max_ratio_of_valid_token: float,
                       max_ratio_of_left_padding: float,
                       min_ratio_of_valid_token: float = 0):
    """Create a random mask given input_ids. Supports left padding and right padding.

    Process:
        - Sample the valid token length
        - Sample the left padding length
        - Generate the padding

    Args:
        input_ids:
            shape (batch_size, seq_len)
        max_ratio_of_valid_token:
            upper bound on the fraction of valid tokens per row, in (0, 1]
        max_ratio_of_left_padding:
            upper bound on the fraction of left padding per row, in [0, 1)
        min_ratio_of_valid_token:
            lower bound on the fraction of valid tokens per row (at least 1 token)

    Returns:
        torch.Tensor of shape (batch_size, seq_len) and dtype int64, where 1 marks
        a valid token and 0 marks left or right padding.
    """
    assert max_ratio_of_valid_token > 0 and max_ratio_of_valid_token <= 1.
    assert max_ratio_of_left_padding >= 0 and max_ratio_of_left_padding < 1.
    assert min_ratio_of_valid_token <= max_ratio_of_valid_token

    batch_size, sequence_length = input_ids.shape
    max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)
    min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))
    max_left_padding = int(sequence_length * max_ratio_of_left_padding)
    assert max_num_valid_tokens + max_left_padding <= sequence_length
    # Note: the second operand compares the ratio (not max_num_valid_tokens) with
    # sequence_length, so it never triggers a failure; only the first operand matters.
    assert max_num_valid_tokens > 0 and max_ratio_of_valid_token <= sequence_length
    masks = torch.ones_like(input_ids, dtype=torch.int64)
    # TODO: we can make this faster
    for i in range(batch_size):
        # np.random.randint excludes `high`, hence the +1 on both upper bounds
        num_left_padding = np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64)
        num_valid = np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64)

        # zero out the left padding: positions [0, num_left_padding)
        for index in range(num_left_padding):
            masks[i, index] = 0

        # zero out the right padding: every position after the valid span
        for index in range(num_left_padding + num_valid, sequence_length):
            masks[i, index] = 0
    return masks
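The `# TODO: we can make this faster` comment refers to the per-element Python loops. One way to vectorize, sketched here as an assumption rather than the repository's code, is to draw all lengths at once and compare a broadcast position index against each row's bounds. The name `create_random_mask_np` is hypothetical; it uses NumPy's `Generator` API, so a given seed will not reproduce the masks of the original `np.random.randint` calls, and it returns a NumPy array rather than a tensor.

```python
import numpy as np


def create_random_mask_np(batch_size, sequence_length,
                          max_ratio_of_valid_token, max_ratio_of_left_padding,
                          min_ratio_of_valid_token=0.0, rng=None):
    """Hypothetical vectorized variant; returns an int64 array (batch_size, sequence_length)."""
    rng = np.random.default_rng() if rng is None else rng
    max_valid = int(sequence_length * max_ratio_of_valid_token)
    min_valid = max(1, int(sequence_length * min_ratio_of_valid_token))
    max_left = int(sequence_length * max_ratio_of_left_padding)
    assert max_valid > 0 and max_valid + max_left <= sequence_length

    # one draw per row; Generator.integers excludes `high` by default
    num_left = rng.integers(0, max_left + 1, size=(batch_size, 1))
    num_valid = rng.integers(min_valid, max_valid + 1, size=(batch_size, 1))

    # positions has shape (1, seq_len) and broadcasts against the (batch, 1) columns
    positions = np.arange(sequence_length)[None, :]
    valid = (positions >= num_left) & (positions < num_left + num_valid)
    return valid.astype(np.int64)


masks = create_random_mask_np(4, 8, 0.5, 0.25, rng=np.random.default_rng(0))
print(masks)
```

The comparison replaces both inner loops with a single boolean expression over the whole batch. If a tensor is needed, `torch.from_numpy(masks)` wraps the array without copying.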

Callers 6

test_seqlen_balancing (function) · 0.90
test_hf_casual_models (function) · 0.90
test_hf_value_models (function) · 0.90
test_hf_casual_fwd (function) · 0.90
test_hf_casual_fwd_bwd (function) · 0.90

Calls

no outgoing calls

Tested by 6

test_seqlen_balancing (function) · 0.72
test_hf_casual_models (function) · 0.72
test_hf_value_models (function) · 0.72
test_hf_casual_fwd (function) · 0.72
test_hf_casual_fwd_bwd (function) · 0.72
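Since the listed callers are tests, one property worth asserting is that every row produced by `create_random_mask` has the form zeros, then a non-empty run of ones, then zeros. The stdlib checker below is a hypothetical helper for that check on a row converted to a Python list (for example with `tensor.tolist()`); it is a sketch, not code taken from the repository's tests.

```python
def is_left_right_padded(row, min_valid=1, max_valid=None, max_left=None):
    """Hypothetical check: row must look like 0...0 1...1 0...0 with a non-empty 1-run."""
    row = list(row)
    # only 0/1 entries, and at least one valid token
    if any(v not in (0, 1) for v in row) or 1 not in row:
        return False
    first = row.index(1)
    num_valid = sum(row)
    # all ones must be contiguous, starting right after the left padding
    if row[first:first + num_valid] != [1] * num_valid:
        return False
    if num_valid < min_valid:
        return False
    if max_valid is not None and num_valid > max_valid:
        return False
    if max_left is not None and first > max_left:
        return False
    return True


print(is_left_right_padded([0, 1, 1, 0]))  # True
```

Passing the bounds derived from the ratio arguments (`max_valid`, `max_left`) turns the check into a test of both the mask shape and the sampling limits.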