```python
import torch
import torch.nn.functional as F
from torch import nn

# `exists`, `top_k`, and `eval_decorator` are helpers defined or imported
# elsewhere in this module; they are not part of this excerpt.


class AutoregressiveWrapper(nn.Module):
    """
    AutoregressiveWrapper is a wrapper class that adds autoregressive generation functionality to a given neural network.

    Args:
        net (nn.Module): The neural network model. It must map a (batch, seq) tensor of token ids to (batch, seq, vocab) logits.
        max_seq_len (int): The maximum sequence length for generation. Defaults to 2048.
        pad_value (int): The padding value for generated sequences. Defaults to 0.
    """

    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.pad_value = pad_value
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs,
    ):
        """
        Generates autoregressive sequences based on the given start tokens.

        Args:
            start_tokens (torch.Tensor): Prompt token ids of shape (batch, prompt_len) to start generation from.
            seq_len (int): The number of new tokens to generate.
            eos_token (int, optional): The end-of-sequence token. If provided, generation stops once this token has been generated. Defaults to None.
            temperature (float, optional): Softmax temperature controlling the randomness of sampling; higher values produce more random output. Must be positive. Defaults to 1.0.
            filter_thres (float, optional): Threshold passed to `top_k` (as `thres`) to filter the logits before sampling. Defaults to 0.9.
            **kwargs: Additional keyword arguments passed to the underlying network.

        Returns:
            torch.Tensor: The generated sequence.
        """

        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # Logits for the last position only: (batch, vocab)
            logits = self.net(out, **kwargs)[:, -1, :]

            # Filter the logits with top_k, then apply temperature and normalize
            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            # Draw one token per batch element: (batch, 1)
            sample = torch.multinomial(probs, 1)

            # Append the sampled token to the running sequence
            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = out == eos_token
```
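`generate` is decorated with `@torch.no_grad()`, which disables gradient tracking, and with `@eval_decorator`, whose definition is not part of this excerpt. Here is a minimal sketch, assuming the decorator's job is to switch the module to eval mode for the duration of the call and then restore its previous training mode. This implementation is hypothetical and may differ from the module's own helper.

```python
from functools import wraps

from torch import nn


def eval_decorator(fn):
    """Hypothetical sketch: run a module method in eval mode, then restore
    the module's previous training flag (even if the method raises)."""

    @wraps(fn)
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        try:
            return fn(self, *args, **kwargs)
        finally:
            self.train(was_training)

    return inner


class Probe(nn.Module):
    """Tiny module used only to observe the training flag during a call."""

    @eval_decorator
    def training_flag_during_call(self) -> bool:
        return self.training


probe = Probe()
probe.train()
inside = probe.training_flag_during_call()  # False: eval mode inside the call
```

Restoring the flag matters because dropout and similar layers behave differently in train and eval mode, and a caller that is mid-training should get its module back unchanged.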
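`filter_thres` is forwarded to a `top_k` helper that this excerpt does not show. Here is a minimal sketch, assuming the convention used in several open-source autoregressive wrappers where `thres` is the fraction of the vocabulary to discard, so `0.9` keeps roughly the top 10% of logits. The helper below is hypothetical. Its `math.ceil` (rather than `int`) guarantees that at least one logit survives even for small vocabularies.

```python
import math

import torch


def top_k(logits: torch.Tensor, thres: float = 0.9) -> torch.Tensor:
    """Hypothetical top-k filter: keep the ceil((1 - thres) * vocab) largest
    logits in each row and set every other entry to -inf."""
    num_logits = logits.shape[-1]
    # Number of logits to keep; at least one so softmax stays well defined.
    k = max(1, math.ceil((1 - thres) * num_logits))
    val, ind = torch.topk(logits, k, dim=-1)
    filtered = torch.full_like(logits, float("-inf"))
    # Put the kept values back at their original vocabulary positions.
    filtered.scatter_(-1, ind, val)
    return filtered


logits = torch.tensor([[1.0, 5.0, 3.0, 2.0, 4.0, 0.0, -1.0, 6.0, 7.0, 8.0]])
kept_two = top_k(logits, thres=0.8)  # keeps the two largest logits (8.0 and 7.0)
kept_one = top_k(logits, thres=0.9)  # keeps only the largest logit (8.0)
```

After filtering, the `-inf` entries receive zero probability from `F.softmax`, so `torch.multinomial` can only pick among the kept tokens.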
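The `temperature` argument divides the filtered logits before `F.softmax`. The sketch below isolates that step; the helper name `temperature_probs` is illustrative and not part of the module. It shows that a temperature below 1 concentrates probability on the top token and a temperature above 1 spreads it out, while entries already masked to `-inf` keep zero probability.

```python
import torch
import torch.nn.functional as F


def temperature_probs(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Turn (possibly top-k filtered) logits into sampling probabilities."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Temperature < 1 sharpens the distribution; temperature > 1 flattens it.
    return F.softmax(logits / temperature, dim=-1)


logits = torch.tensor([[2.0, 1.0, float("-inf")]])
p_sharp = temperature_probs(logits, 0.5)  # top token gets about 0.88 of the mass
p_flat = temperature_probs(logits, 2.0)   # top token gets about 0.62 of the mass

torch.manual_seed(0)
sample = torch.multinomial(temperature_probs(logits, 1.0), 1)  # shape (1, 1), never index 2
```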
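The excerpt ends right after `is_eos_token` is computed, so the rest of the EOS handling is not shown. A common pattern, given here only as a hypothetical sketch rather than the module's confirmed code, is to stop once every row in the batch contains an EOS token and then replace everything after the first EOS with `pad_value`. Both helper names below are illustrative.

```python
import torch


def all_rows_finished(out: torch.Tensor, eos_token: int) -> bool:
    """True once every sequence in the batch contains at least one EOS token."""
    return bool((out == eos_token).any(dim=-1).all())


def mask_after_eos(out: torch.Tensor, eos_token: int, pad_value: int = 0) -> torch.Tensor:
    """Replace every token after the first EOS in each row with pad_value.
    The first EOS itself is kept."""
    is_eos = out == eos_token
    # Shift right by one position so the first EOS is not masked itself.
    shifted = torch.cat([torch.zeros_like(is_eos[:, :1]), is_eos[:, :-1]], dim=-1)
    # Mask every position that comes after at least one EOS.
    mask = shifted.long().cumsum(dim=-1) >= 1
    return out.masked_fill(mask, pad_value)


out = torch.tensor([[5, 2, 9, 3, 4], [9, 3, 9, 4, 1], [1, 1, 1, 1, 1]])
masked = mask_after_eos(out, eos_token=9, pad_value=0)
```

Inside the generation loop, one would check `all_rows_finished(out, eos_token)` after each step and, once it returns `True`, apply `mask_after_eos(out, eos_token, self.pad_value)` and break. Rows that never emit EOS are left untouched.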
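The wrapper only relies on `net` mapping a `(batch, seq)` tensor of token ids to `(batch, seq, vocab)` logits, and each loop iteration grows the sequence by one column. The sketch below uses a hypothetical toy model `ToyLM` and a stripped-down loop `sample_loop` to show that contract. `sample_loop` has no top-k filtering or EOS handling, and neither name comes from the original module.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyLM(nn.Module):
    """Hypothetical stand-in for `net`: embeds token ids and projects them
    back to vocabulary logits of shape (batch, seq, vocab)."""

    def __init__(self, vocab_size: int = 16, dim: int = 32):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.proj = nn.Linear(dim, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(self.emb(x))


@torch.no_grad()
def sample_loop(net: nn.Module, start_tokens: torch.Tensor, seq_len: int,
                temperature: float = 1.0) -> torch.Tensor:
    """Same loop structure as `generate`, without filtering or EOS handling.
    Returns the prompt followed by seq_len sampled tokens."""
    out = start_tokens
    for _ in range(seq_len):
        logits = net(out)[:, -1, :]                   # (batch, vocab)
        probs = F.softmax(logits / temperature, dim=-1)
        sample = torch.multinomial(probs, 1)          # (batch, 1)
        out = torch.cat((out, sample), dim=-1)        # (batch, seq + 1)
    return out


torch.manual_seed(0)
net = ToyLM(vocab_size=16)
prompt = torch.zeros(2, 3, dtype=torch.long)
result = sample_loop(net, prompt, seq_len=5)  # shape (2, 8): 3 prompt + 5 new tokens
```

Because the whole prefix is re-run through `net` at every step, the cost of this naive loop grows with the sequence length. Implementations that need speed usually add key/value caching, which this sketch deliberately leaves out.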