Repository: github.com/kyegomez/BitNet

Class AutoregressiveWrapper

bitnet/at.py:35–107

AutoregressiveWrapper wraps a neural network and adds autoregressive generation to it. Args: net (nn.Module), the wrapped model; max_seq_len (int), the maximum sequence length for generation (default 2048); pad_value (int), the padding value for generated sequences (default 0).
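
To make the wrapping idea concrete, here is a minimal, framework-free sketch of the same autoregressive loop. It is not the bitnet implementation. `toy_model` is a hypothetical stand-in for `net` that returns next-token logits for a plain list of token ids. Sampling uses Python's `random.choices` in place of `torch.multinomial`, and returning only the new tokens is a choice made for this sketch.

```python
import math
import random
from typing import Callable, List, Optional

# Hypothetical stand-in for `net`: takes the whole token sequence so far and
# returns logits for the next token only (like `self.net(out)[:, -1, :]`).
NextTokenLogits = Callable[[List[int]], List[float]]


def generate(
    next_token_logits: NextTokenLogits,
    start_tokens: List[int],
    seq_len: int,
    eos_token: Optional[int] = None,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Sample up to `seq_len` new tokens, feeding each one back in as input."""
    if rng is None:
        rng = random.Random(0)
    out = list(start_tokens)
    for _ in range(seq_len):
        logits = next_token_logits(out)
        scaled = [x / temperature for x in logits]
        m = max(scaled)
        weights = [math.exp(s - m) for s in scaled]  # softmax up to normalization
        token = rng.choices(range(len(weights)), weights=weights, k=1)[0]
        out.append(token)
        if eos_token is not None and token == eos_token:
            break  # stop early once EOS is sampled
    # In this sketch only the newly generated tokens are returned.
    return out[len(start_tokens):]


def toy_model(tokens: List[int]) -> List[float]:
    """Hypothetical 5-token model that always predicts (last token + 1) mod 5."""
    vocab = 5
    nxt = (tokens[-1] + 1) % vocab
    return [0.0 if i == nxt else float("-inf") for i in range(vocab)]


print(generate(toy_model, [0], 3))                # [1, 2, 3]
print(generate(toy_model, [3], 10, eos_token=1))  # [4, 0, 1]
```

Because `toy_model` puts all probability mass on one token, the output is deterministic. With a real model, each step samples from a spread-out distribution instead.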


```python
import torch
import torch.nn.functional as F
from torch import nn

# `exists`, `top_k`, and `eval_decorator` are helpers defined or imported
# earlier in bitnet/at.py; they are not part of this listing.


class AutoregressiveWrapper(nn.Module):
    """
    AutoregressiveWrapper is a wrapper class that adds autoregressive generation functionality to a given neural network.

    Args:
        net (nn.Module): The neural network model.
        max_seq_len (int): The maximum sequence length for generation. Defaults to 2048.
        pad_value (int): The padding value for generated sequences. Defaults to 0.
    """

    def __init__(self, net, max_seq_len=2048, pad_value=0):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.pad_value = pad_value
        self.net = net

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        start_tokens,
        seq_len,
        eos_token=None,
        temperature=1.0,
        filter_thres=0.9,
        **kwargs,
    ):
        """
        Generates autoregressive sequences based on the given start tokens.

        Args:
            start_tokens (torch.Tensor): The initial tokens to start generation from, of shape (batch, prompt_len).
            seq_len (int): The number of new tokens to generate.
            eos_token (int, optional): The end-of-sequence token. If provided, generation stops when this token is generated. Defaults to None.
            temperature (float, optional): The softmax temperature controlling the randomness of sampling. Higher values produce more random output. Defaults to 1.0.
            filter_thres (float, optional): The threshold passed to `top_k`, which keeps only the highest-scoring logits before sampling. Defaults to 0.9.
            **kwargs: Additional keyword arguments passed to the underlying network.

        Returns:
            torch.Tensor: The generated sequence.
        """

        b, t, device = *start_tokens.shape, start_tokens.device

        out = start_tokens

        for _ in range(seq_len):
            # Run the whole sequence so far through the network and keep only
            # the logits at the last position: shape (batch, vocab).
            logits = self.net(out, **kwargs)[:, -1, :]

            filtered_logits = top_k(logits, thres=filter_thres)
            probs = F.softmax(filtered_logits / temperature, dim=-1)

            # Draw one token per batch row and append it to the sequence.
            sample = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)

            if exists(eos_token):
                is_eos_token = out == eos_token
                # (listing ends here; the source continues to line 107)
```
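
The `filter_thres` argument is passed straight to `top_k`, whose definition is not part of this listing. The sketch below assumes the convention common in similar autoregressive wrappers: `thres` is the fraction of the vocabulary to discard. Under that convention, the `ceil((1 - thres) * vocab_size)` largest logits are kept and the rest are set to `-inf`, so the default of 0.9 keeps roughly the top 10%. The sketch then applies the temperature-scaled softmax from the listing. Plain Python lists stand in for tensors. The function names and the clamp to at least one kept logit are assumptions of this sketch.

```python
import math
from typing import List


def top_k_filter(logits: List[float], thres: float = 0.9) -> List[float]:
    """Keep the k largest logits, k = ceil((1 - thres) * vocab); mask the rest to -inf.

    Assumption: mirrors a common `top_k(logits, thres)` convention; the helper
    used by bitnet/at.py is not shown. Clamping k to >= 1 is added here.
    """
    k = max(1, math.ceil((1 - thres) * len(logits)))
    keep = set(sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k])
    return [x if i in keep else float("-inf") for i, x in enumerate(logits)]


def softmax_with_temperature(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Softmax of logits / temperature; masked (-inf) entries get probability 0."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]


filtered = top_k_filter([0.5, 2.0, 1.0, 3.0], thres=0.5)
print(filtered)  # [-inf, 2.0, -inf, 3.0]
print(softmax_with_temperature(filtered, temperature=0.7))
```

Lowering the temperature below 1.0 sharpens the distribution toward the largest surviving logit. Raising it flattens the distribution across the logits that `top_k` kept.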

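The listing stops inside the `eos_token` branch, so the rest of `generate` is not shown. The sketch below shows one common way to finish such a branch; it is an assumption, not a copy of source lines 93–107. Generation stops once every row in the batch has produced `eos_token`, and everything after each row's first EOS is overwritten with `pad_value`. The helper names are hypothetical, and lists of ints stand in for the `(batch, length)` tensor `out`.

```python
from typing import List


def all_rows_finished(seqs: List[List[int]], eos_token: int) -> bool:
    """True once every sequence in the batch contains at least one EOS token."""
    return all(eos_token in row for row in seqs)


def mask_after_eos(seqs: List[List[int]], eos_token: int, pad_value: int = 0) -> List[List[int]]:
    """Replace every token after the first `eos_token` in each row with `pad_value`.

    The EOS token itself is kept; rows without EOS are copied unchanged.
    """
    masked = []
    for row in seqs:
        if eos_token in row:
            cut = row.index(eos_token) + 1  # index just past the first EOS
            masked.append(row[:cut] + [pad_value] * (len(row) - cut))
        else:
            masked.append(list(row))
    return masked


batch = [[5, 2, 7, 8], [5, 6, 2, 9]]
if all_rows_finished(batch, eos_token=2):
    print(mask_after_eos(batch, eos_token=2))  # [[5, 2, 0, 0], [5, 6, 2, 0]]
```

Padding instead of truncating keeps every row the same length, which a batched tensor requires.
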
Callers (2): train.py (File · 0.90), __init__ (Method · 0.90)

Calls: no outgoing calls detected

Tested by: no test coverage detected