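Create a random mask given input_ids, supporting both left padding and right padding. Process: (1) sample the valid-token length; (2) sample the left-padding length; (3) zero out the padding positions. Args: input_ids: shape (batch_size, seq_len). Returns: an int64 tensor shaped like input_ids, with 1 at valid-token positions and 0 at padding positions.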
(input_ids: torch.Tensor,
max_ratio_of_valid_token: float,
max_ratio_of_left_padding: float,
min_ratio_of_valid_token: float = 0)
| 151 | |
| 152 | |
def create_random_mask(input_ids: torch.Tensor,
                       max_ratio_of_valid_token: float,
                       max_ratio_of_left_padding: float,
                       min_ratio_of_valid_token: float = 0):
    """Create a random mask given input_ids. Support left padding and right padding.

    Each row of the returned mask has the layout
    ``[0] * left + [1] * valid + [0] * right``.

    Process:
        - Sample the valid-token length uniformly from
          ``[max(1, int(seq_len * min_ratio_of_valid_token)),
          int(seq_len * max_ratio_of_valid_token)]``.
        - Sample the left-padding length uniformly from
          ``[0, int(seq_len * max_ratio_of_left_padding)]``.
        - Zero out the left padding and every position after the valid span
          (right padding).

    Sampling uses NumPy's global RNG (``np.random``), so seeding it with
    ``np.random.seed`` makes the result reproducible.

    Args:
        input_ids: shape (batch_size, seq_len). Only its shape (and device,
            via ``torch.ones_like``) is used; the token values are ignored.
        max_ratio_of_valid_token: upper bound on the fraction of valid
            tokens per row, in (0, 1].
        max_ratio_of_left_padding: upper bound on the fraction of left
            padding per row, in [0, 1).
        min_ratio_of_valid_token: lower bound on the fraction of valid
            tokens per row; at least one valid token is always kept.

    Returns:
        torch.Tensor: int64 tensor with the same shape as ``input_ids``,
        1 at valid-token positions and 0 at padding positions.

    Raises:
        AssertionError: if a ratio is out of range, or if the maximum valid
            length plus the maximum left padding exceeds ``seq_len``.
    """
    assert 0 < max_ratio_of_valid_token <= 1., \
        "max_ratio_of_valid_token must be in (0, 1]"
    assert 0 <= max_ratio_of_left_padding < 1., \
        "max_ratio_of_left_padding must be in [0, 1)"
    assert min_ratio_of_valid_token <= max_ratio_of_valid_token, \
        "min_ratio_of_valid_token must not exceed max_ratio_of_valid_token"

    batch_size, sequence_length = input_ids.shape
    max_num_valid_tokens = int(sequence_length * max_ratio_of_valid_token)
    # Always keep at least one valid token per row.
    min_num_valid_tokens = max(1, int(sequence_length * min_ratio_of_valid_token))
    max_left_padding = int(sequence_length * max_ratio_of_left_padding)
    # Ensures the left padding plus the valid span always fits in the row,
    # so the slices below never overlap or run past the end.
    assert max_num_valid_tokens + max_left_padding <= sequence_length, \
        "max valid tokens plus max left padding exceeds the sequence length"
    assert 0 < max_num_valid_tokens <= sequence_length, \
        "max_ratio_of_valid_token yields zero valid tokens for this sequence length"

    masks = torch.ones_like(input_ids, dtype=torch.int64)
    for i in range(batch_size):
        # Keep this draw order (left padding first, then valid length) so
        # runs seeded via np.random.seed stay reproducible.
        num_left_padding = int(np.random.randint(low=0, high=max_left_padding + 1, dtype=np.int64))
        num_valid = int(np.random.randint(low=min_num_valid_tokens, high=max_num_valid_tokens + 1, dtype=np.int64))

        # Slice assignment replaces per-element writes; same result, far fewer
        # tensor operations (notably cheaper when masks lives on a GPU).
        masks[i, :num_left_padding] = 0
        masks[i, num_left_padding + num_valid:] = 0
    return masks
| 192 | |
| 193 | |
| 194 | def compute_position_id_with_mask(mask): |
no outgoing calls