github.com/policy-gradient/GRPO-Zero

Method init_kv_cache

qwen2_model.py:103–114
init_kv_cache(
        self,
        max_batch_size: int,
        max_seq_len: int,
        dtype: torch.dtype,
        device: torch.device,
    )
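
The cache size is fixed entirely by these arguments plus the layer's n_kv_heads and head_dim: each of the two buffers holds max_batch_size × max_seq_len × n_kv_heads × head_dim elements. As a rough sketch, the per-layer footprint can be estimated as follows. The helper name kv_cache_bytes, the dtype table, and the example sizes are illustrative assumptions, not taken from the repository.

```python
# Hypothetical helper (not part of GRPO-Zero): estimate the bytes held by the
# cache_k and cache_v buffers that init_kv_cache allocates for ONE layer.
# Each buffer has shape (max_batch_size, max_seq_len, n_kv_heads, head_dim).

BYTES_PER_ELEMENT = {"float32": 4, "float16": 2, "bfloat16": 2}


def kv_cache_bytes(
    max_batch_size: int,
    max_seq_len: int,
    n_kv_heads: int,
    head_dim: int,
    dtype: str = "bfloat16",
) -> int:
    elements_per_buffer = max_batch_size * max_seq_len * n_kv_heads * head_dim
    # Two buffers of identical shape: one for keys, one for values.
    return 2 * elements_per_buffer * BYTES_PER_ELEMENT[dtype]


# Illustrative sizes only; multiply by the number of layers for the whole model.
per_layer = kv_cache_bytes(max_batch_size=8, max_seq_len=1024, n_kv_heads=2, head_dim=128)
print(per_layer / 2**20, "MiB per layer")  # prints: 8.0 MiB per layer
```

Because both buffers are allocated at full size with torch.zeros, this memory is reserved up front, independent of how many tokens are actually generated, until the cache is released.
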

Source from the content-addressed store, hash-verified

        self.args = args

    def init_kv_cache(
        self,
        max_batch_size: int,
        max_seq_len: int,
        dtype: torch.dtype,
        device: torch.device,
    ):
        cache_shape = (max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim)
        cache_k = torch.zeros(cache_shape, dtype=dtype, device=device)
        cache_v = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.register_buffer("cache_k", cache_k, persistent=False)
        self.register_buffer("cache_v", cache_v, persistent=False)

    def del_kv_cache(self):
        self.cache_k = None
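
Both buffers are registered with persistent=False. In PyTorch this means they are tracked by the module (they appear in named_buffers() and follow .to() calls) but are left out of state_dict(), so checkpoints saved from the model do not contain the cache. The sketch below reproduces the method body inside a hypothetical TinyAttention module to check that behavior; the class, its proj layer, and the sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn


class TinyAttention(nn.Module):
    """Hypothetical stand-in for the module that owns init_kv_cache."""

    def __init__(self, n_kv_heads: int, head_dim: int):
        super().__init__()
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim
        self.proj = nn.Linear(4, 4)  # an ordinary parameter, so state_dict is not empty

    def init_kv_cache(self, max_batch_size, max_seq_len, dtype, device):
        # Same body as the method shown above.
        cache_shape = (max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim)
        cache_k = torch.zeros(cache_shape, dtype=dtype, device=device)
        cache_v = torch.zeros(cache_shape, dtype=dtype, device=device)
        # persistent=False: tracked as module buffers, excluded from state_dict().
        self.register_buffer("cache_k", cache_k, persistent=False)
        self.register_buffer("cache_v", cache_v, persistent=False)


attn = TinyAttention(n_kv_heads=2, head_dim=8)
attn.init_kv_cache(
    max_batch_size=3, max_seq_len=16, dtype=torch.float32, device=torch.device("cpu")
)

buffer_names = {name for name, _ in attn.named_buffers()}
state_keys = set(attn.state_dict().keys())
print(sorted(buffer_names))  # ['cache_k', 'cache_v']
print(sorted(state_keys))    # ['proj.bias', 'proj.weight']
```

Keeping the cache out of the state dict keeps checkpoints small, and a strict load_state_dict of pretrained weights does not expect cache_k or cache_v keys.
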

Callers (2)

rollout · function · 0.45
init_kv_cache · method · 0.45
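
The caller list shows a function named rollout and a method that is itself named init_kv_cache, which suggests a model-level wrapper that allocates a cache for each layer before generation. The sketch below illustrates that pattern under stated assumptions: TinyModel, CachedAttention and update_cache are hypothetical names, and the write-then-read slicing at start_pos follows the common llama-style decoding pattern rather than code shown on this page.

```python
import torch
import torch.nn as nn


class CachedAttention(nn.Module):
    """Hypothetical per-layer module using the same init_kv_cache body."""

    def __init__(self, n_kv_heads: int, head_dim: int):
        super().__init__()
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim

    def init_kv_cache(self, max_batch_size, max_seq_len, dtype, device):
        cache_shape = (max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False)

    def update_cache(self, k, v, start_pos):
        # Assumed llama-style pattern: write the new positions into the
        # preallocated buffers, then return every position seen so far.
        bsz, seqlen = k.shape[:2]
        self.cache_k[:bsz, start_pos:start_pos + seqlen] = k
        self.cache_v[:bsz, start_pos:start_pos + seqlen] = v
        end = start_pos + seqlen
        return self.cache_k[:bsz, :end], self.cache_v[:bsz, :end]


class TinyModel(nn.Module):
    """Hypothetical model whose init_kv_cache forwards to every layer."""

    def __init__(self, n_layers: int, n_kv_heads: int, head_dim: int):
        super().__init__()
        self.layers = nn.ModuleList([CachedAttention(n_kv_heads, head_dim) for _ in range(n_layers)])

    def init_kv_cache(self, max_batch_size, max_seq_len, dtype, device):
        for layer in self.layers:
            layer.init_kv_cache(max_batch_size, max_seq_len, dtype, device)


model = TinyModel(n_layers=2, n_kv_heads=2, head_dim=4)
model.init_kv_cache(max_batch_size=2, max_seq_len=10, dtype=torch.float32, device=torch.device("cpu"))

layer0 = model.layers[0]
# Prefill three positions, then decode one more token.
prefill = torch.ones(2, 3, 2, 4)
layer0.update_cache(prefill, prefill, start_pos=0)
step = torch.full((2, 1, 2, 4), 2.0)
keys, values = layer0.update_cache(step, step, start_pos=3)
print(tuple(keys.shape))  # (2, 4, 2, 4)
```

Preallocating to max_seq_len means each decoding step only writes one new slice instead of concatenating tensors, at the cost of reserving the full buffer from the start.
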

Calls

no outgoing calls to other functions in this repository (only torch.zeros and register_buffer)

Tested by

no test coverage detected