Function `extract_context_feature(hidden_states: list[torch.Tensor], layer_ids: Optional[list[int]])`
Source (from the content-addressed store, hash-verified):
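```python
from typing import Optional

import torch


def extract_context_feature(
    hidden_states: list[torch.Tensor],
    layer_ids: Optional[list[int]],
) -> torch.Tensor:
    # layer_ids is typed Optional, but the selection below needs a concrete list;
    # fail with a clear message instead of a TypeError from iterating over None.
    if layer_ids is None:
        raise ValueError("layer_ids must be a list of layer indices, got None")
    # Shift every requested layer index by one (index 0 of hidden_states is
    # presumably the embedding output rather than a layer output).
    offset = 1
    selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids]
    # Concatenate the selected layers along the feature (last) dimension.
    return torch.cat(selected_states, dim=-1)
```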
Tested by: no test coverage detected.
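No tests cover this function, so here is a minimal usage sketch. It is not a test from the repository. The function body is copied as listed so the block runs on its own. The synthetic inputs are hypothetical: one tensor standing in for the embedding output, followed by four layer outputs, each filled with its own index so the selected entries are easy to identify. With `layer_ids=[0, 2]`, the `offset = 1` shift selects `hidden_states[1]` and `hidden_states[3]` and concatenates them along the last dimension. Treating index 0 as the embedding output is an assumption inferred from the offset; the source does not state it.

```python
from typing import Optional

import torch


def extract_context_feature(
    hidden_states: list[torch.Tensor],
    layer_ids: Optional[list[int]],
) -> torch.Tensor:
    # Same body as the listed function; index 0 is assumed to be the embedding output.
    offset = 1
    selected_states = [hidden_states[layer_id + offset] for layer_id in layer_ids]
    return torch.cat(selected_states, dim=-1)


# Hypothetical stand-in for per-layer outputs: 1 embedding output plus 4 layer
# outputs, each of shape (batch=2, seq_len=3, hidden=8), filled with its index.
num_layers, batch, seq_len, hidden = 4, 2, 3, 8
hidden_states = [
    torch.full((batch, seq_len, hidden), float(i)) for i in range(num_layers + 1)
]

feature = extract_context_feature(hidden_states, layer_ids=[0, 2])

print(feature.shape)  # torch.Size([2, 3, 16])
# The first 8 features come from hidden_states[1], the next 8 from hidden_states[3].
print(feature[0, 0, 0].item(), feature[0, 0, hidden].item())  # 1.0 3.0
```

The output width is `len(layer_ids) * hidden`, so downstream code that consumes this feature needs an input dimension that matches the number of selected layers.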