Repository: github.com/SesameAILabs/csm

Function _index_causal_mask

Defined in models.py, lines 59–69.

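Selects rows of a precomputed causal mask at the positions given in input_pos, producing one mask row per query position. Args: mask, shape (max_seq_len, max_seq_len); input_pos, shape (batch_size, seq_len). Returns: a tensor of shape (batch_size, seq_len, max_seq_len).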

Signature: _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor)
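The one-line body relies on PyTorch advanced indexing: indexing dimension 0 of a (max_seq_len, max_seq_len) tensor with an integer tensor of shape (batch_size, seq_len) yields a (batch_size, seq_len, max_seq_len) tensor. The minimal sketch below checks that against an explicit loop. It assumes a boolean lower-triangular mask (the listing does not show how the mask is built), and index_causal_mask_loop is a hypothetical reference helper, not part of csm.

```python
import torch


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    # Same body as the listing: integer-tensor indexing along dim 0.
    return mask[input_pos, :]


def index_causal_mask_loop(mask: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference version: out[b, s, :] = mask[input_pos[b, s], :]
    batch_size, seq_len = input_pos.shape
    out = torch.empty(batch_size, seq_len, mask.shape[1], dtype=mask.dtype, device=mask.device)
    for b in range(batch_size):
        for s in range(seq_len):
            out[b, s] = mask[input_pos[b, s]]
    return out


max_seq_len = 8
# Assumption: boolean lower-triangular mask, True means "may attend".
mask = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
# Two sequences of three tokens each, at different absolute positions.
input_pos = torch.tensor([[0, 1, 2], [5, 6, 7]])

fast = _index_causal_mask(mask, input_pos)
slow = index_causal_mask_loop(mask, input_pos)
print(fast.shape)               # torch.Size([2, 3, 8])
print(torch.equal(fast, slow))  # True
```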


```python
import torch


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    """
    Args:
        mask: (max_seq_len, max_seq_len)
        input_pos: (batch_size, seq_len)

    Returns:
        (batch_size, seq_len, max_seq_len)
    """
    # Integer-tensor (advanced) indexing on dim 0: for every (b, s),
    # take row input_pos[b, s] of the mask.
    r = mask[input_pos, :]
    return r
```
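A short worked example of the two shapes input_pos typically takes during generation: a multi-token prompt (prefill) and a single new token (decode). It again assumes a boolean lower-triangular mask; the names causal, prefill_pos and decode_pos are illustrative only.

```python
import torch


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    return mask[input_pos, :]


max_seq_len = 6
# Assumption: boolean lower-triangular mask; row i allows keys 0..i.
causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))

# Prefill: a 3-token prompt occupies slots 0, 1, 2.
prefill_pos = torch.arange(3).unsqueeze(0)              # shape (1, 3)
prefill_mask = _index_causal_mask(causal, prefill_pos)  # shape (1, 3, 6)

# Decode: one new token at slot 3.
decode_pos = torch.tensor([[3]])                        # shape (1, 1)
decode_mask = _index_causal_mask(causal, decode_pos)    # shape (1, 1, 6)

print(prefill_mask[0].int().tolist())
# [[1, 0, 0, 0, 0, 0], [1, 1, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0]]
print(decode_mask[0, 0].int().tolist())
# [1, 1, 1, 1, 0, 0]
```

Every returned row stays max_seq_len wide regardless of how many query positions are processed, so the same mask can line up with a fixed-size key/value buffer in both steps; in this example, slots 4 and 5 (not yet written) remain masked out.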

Callers (1)

generate_frame (method) · 0.85
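The page lists generate_frame as the only caller but does not show how it consumes the result. As a hypothetical sketch only, assuming a boolean mask (True = may attend) and a preallocated key/value cache of length max_seq_len, the indexed mask can be broadcast over attention heads and passed to torch.nn.functional.scaled_dot_product_attention. The helper attend_step and all tensor shapes here are illustrative, not taken from csm.

```python
import torch
import torch.nn.functional as F


def _index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor):
    return mask[input_pos, :]


def attend_step(q, k_cache, v_cache, causal, input_pos):
    """Hypothetical helper, not part of csm.

    q:         (batch, heads, seq_len, head_dim)
    k_cache:   (batch, heads, max_seq_len, head_dim)
    v_cache:   (batch, heads, max_seq_len, head_dim)
    causal:    (max_seq_len, max_seq_len) bool, True = may attend
    input_pos: (batch, seq_len)
    """
    attn_mask = _index_causal_mask(causal, input_pos)  # (batch, seq_len, max_seq_len)
    attn_mask = attn_mask.unsqueeze(1)                  # add a heads dim for broadcasting
    return F.scaled_dot_product_attention(q, k_cache, v_cache, attn_mask=attn_mask)


torch.manual_seed(0)
batch, heads, head_dim, max_seq_len = 2, 4, 16, 10
causal = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool))
k_cache = torch.randn(batch, heads, max_seq_len, head_dim)
v_cache = torch.randn(batch, heads, max_seq_len, head_dim)

# One decode step: each sequence's new token sits at a different slot.
input_pos = torch.tensor([[4], [7]])
q = torch.randn(batch, heads, 1, head_dim)
out = attend_step(q, k_cache, v_cache, causal, input_pos)
print(out.shape)  # torch.Size([2, 4, 1, 16])
```

Because each sequence gets its own row, sequences at different positions in the same batch only see their own filled slots; cache contents beyond a sequence's position do not affect its output.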

Calls

no outgoing calls

Tested by

no test coverage detected