Repository: github.com/SesameAILabs/csm

Function _create_causal_mask

models.py:55–56
(seq_len: int, device: torch.device)


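```python
import torch


def _create_causal_mask(seq_len: int, device: torch.device):
    # Lower-triangular boolean mask of shape (seq_len, seq_len):
    # entry [i, j] is True when position i may attend to position j (j <= i).
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
```

Followed in models.py by `_index_causal_mask(mask: torch.Tensor, input_pos: torch.Tensor)`.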
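A minimal sketch of how a mask built this way is typically consumed, assuming PyTorch 2.0 or later for `torch.nn.functional.scaled_dot_product_attention`. The shapes and random inputs are illustrative only and are not taken from the csm model. The snippet prints the mask for `seq_len=4` and checks that passing it to SDPA as a boolean `attn_mask` matches a manual softmax in which disallowed scores are filled with `-inf`.

```python
import torch
import torch.nn.functional as F


def _create_causal_mask(seq_len: int, device: torch.device):
    # Same one-liner as in models.py: True on and below the diagonal.
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


mask = _create_causal_mask(4, torch.device("cpu"))
print(mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)

# Illustrative attention inputs: (batch, heads, seq_len, head_dim).
torch.manual_seed(0)
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# In SDPA a boolean attn_mask marks positions that ARE allowed to attend,
# the same convention this mask uses; (4, 4) broadcasts over batch and heads.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

# Manual equivalent: disallowed scores become -inf before the softmax.
scores = (q @ k.transpose(-2, -1)) / (q.shape[-1] ** 0.5)
scores = scores.masked_fill(~mask, float("-inf"))
manual = torch.softmax(scores, dim=-1) @ v
print(torch.allclose(out, manual, atol=1e-5))
```

Because True means "allowed" both here and in SDPA's boolean mask, the mask can be passed as is; code that uses `masked_fill` needs the inverted `~mask`.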

Callers 1

setup_caches (method) · 0.85
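The only recorded caller is `setup_caches`. The sketch below shows one plausible pattern, assuming the caller builds the mask once for a maximum sequence length and then selects rows by absolute position during cached decoding. `ToyDecoder`, `step_mask`, and `max_seq_len=6` are hypothetical; the row lookup is a guess at what the neighbouring `_index_causal_mask(mask, input_pos)` might do, since its body is not shown on this page.

```python
import torch


def _create_causal_mask(seq_len: int, device: torch.device):
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))


class ToyDecoder:
    """Hypothetical stand-in for a model that owns a KV cache."""

    def __init__(self, device: torch.device = torch.device("cpu")):
        self.device = device
        self.causal_mask = None

    def setup_caches(self, max_seq_len: int) -> None:
        # Build the full (max_seq_len, max_seq_len) mask once.
        self.causal_mask = _create_causal_mask(max_seq_len, self.device)

    def step_mask(self, input_pos: torch.Tensor) -> torch.Tensor:
        # Hypothetical row lookup: input_pos is (batch, seq_len) of absolute
        # positions; advanced indexing returns (batch, seq_len, max_seq_len).
        return self.causal_mask[input_pos, :]


model = ToyDecoder()
model.setup_caches(max_seq_len=6)

# Prefill positions 0..2, then decode position 3 against the cache.
prefill = model.step_mask(torch.tensor([[0, 1, 2]]))
decode = model.step_mask(torch.tensor([[3]]))
print(prefill.shape)  # torch.Size([1, 3, 6])
print(decode.int())   # tensor([[[1, 1, 1, 1, 0, 0]]], dtype=torch.int32)
```

Precomputing the full mask avoids rebuilding a seq_len × seq_len tensor at every decoding step; each step then only gathers the rows it needs.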

Calls

no outgoing calls to other functions in this repository (the body calls only torch.tril and torch.ones)

Tested by

no test coverage detected