(bwa: str = None)
| 256 | |
| 257 | |
def get_beam_width_array(bwa: str = None):
    """Parse a beam-width-array (BWA) value into padded integer lists.

    ``bwa`` is normally a Python-literal string evaluated with
    ``ast.literal_eval``. If that evaluation yields another string, it is
    evaluated once more, so a doubly-quoted literal such as ``"'[1, 2]'"``
    is accepted. A value that is already a list or tuple is used as-is.

    Two shapes are accepted:

    * a list of ints, e.g. ``"[2, 4]"``: a single BWA;
    * a tuple of such lists, e.g. ``"([2, 4], [1])"``: one BWA per request.

    Each BWA must hold between 1 and 8 ints (``bool`` is rejected). It is
    right-padded to exactly 8 entries by repeating its last element.

    Returns:
        ``(padded, max_beam_width)``. ``padded`` is a list of 8 ints for a
        single BWA, or a list of such lists for a tuple of BWAs.
        ``max_beam_width`` is the largest value in the BWA. For a tuple it is
        the largest value across all BWAs, never less than 0.

    Raises:
        ValueError: if ``bwa`` is not a list or tuple after parsing (including
            the default ``None``), the tuple is empty, or any BWA is not a
            non-empty list of at most 8 int32-range integers.
        SyntaxError: propagated from ``ast.literal_eval`` for malformed text.
    """
    max_len = 8  # every BWA is padded to this fixed width
    # Values used to be stored in an int32 tensor; keep that range limit.
    int32_min, int32_max = -2**31, 2**31 - 1

    # Only text needs parsing; already-parsed sequences (or None) fall through
    # to the shape checks below instead of failing inside literal_eval.
    if isinstance(bwa, str):
        bwa = ast.literal_eval(bwa)  # Short for "beam_width_array"
        if isinstance(bwa, str):
            bwa = ast.literal_eval(bwa)  # parse again for string

    def parse_one_bwa(row):
        """Validate one BWA and return ``(row padded to max_len, max(row))``."""
        # Explicit raises instead of `assert`: asserts vanish under `python -O`.
        if not isinstance(row, list):
            raise ValueError("Beam width array must be a list.")
        if not row:
            # Without this, row[-1] / max(row) below fail with opaque errors.
            raise ValueError("Beam width array must not be empty.")
        if len(row) > max_len:
            raise ValueError(
                "Length of beam width array must not be greater than 8 now.")
        for beam in row:
            # bool is a subclass of int; reject it so True/False aren't 1/0.
            if isinstance(beam, bool) or not isinstance(beam, int):
                raise ValueError(
                    "Numbers in beam width array must be integer.")
            if not int32_min <= beam <= int32_max:
                raise ValueError(
                    f"Beam width {beam} in beam width array does not fit in "
                    "int32.")
        # Build a new list so the caller's row is never mutated.
        padded = row + [row[-1]] * (max_len - len(row))
        return padded, max(row)

    if isinstance(bwa, list):  # Only one BWA
        return parse_one_bwa(bwa)

    if isinstance(bwa, tuple):  # BWA for respective requests
        if not bwa:
            raise ValueError(f"Invalid beam width array: {bwa}")
        padded_rows = []
        max_beam_width = 0
        for row in bwa:
            padded, beam_width = parse_one_bwa(row)
            padded_rows.append(padded)
            max_beam_width = max(max_beam_width, beam_width)
        return padded_rows, max_beam_width

    raise ValueError(f"Invalid beam width array: {bwa}")
| 290 | |
| 291 | |
| 292 | def add_common_args(parser): |
no test coverage detected