| 259 | self.dense = torch.nn.Linear(self.hidden_size, self.hidden_size) |
| 260 | |
| 261 | def forward( |
| 262 | self, |
| 263 | hidden_states, |
| 264 | query_hidden_state, |
| 265 | attention_mask, |
| 266 | layer_past=None, |
| 267 | get_key_value=False, |
| 268 | prompt_length=None, |
| 269 | context_length=None, |
| 270 | ): |
| 271 | |
| 272 | # hidden_states: [sq, b, h] |
| 273 | query_layer = self.query(query_hidden_state) |
| 274 | key_layer = self.key(hidden_states) |
| 275 | value_layer = self.value(hidden_states) |
| 276 | |
| 277 | new_query_layer_shape = query_layer.size()[:-1] + \ |
| 278 | (self.num_attention_heads, |
| 279 | self.hidden_size_per_attention_head) |
| 280 | query_layer = query_layer.view(*new_query_layer_shape) |
| 281 | |
| 282 | new_query_layer_shape = key_layer.size()[:-1] + \ |
| 283 | (self.num_attention_heads, |
| 284 | self.hidden_size_per_attention_head) |
| 285 | key_layer = key_layer.view(*new_query_layer_shape) |
| 286 | |
| 287 | new_query_layer_shape = value_layer.size()[:-1] + \ |
| 288 | (self.num_attention_heads, |
| 289 | self.hidden_size_per_attention_head) |
| 290 | value_layer = value_layer.view(*new_query_layer_shape) |
| 291 | |
| 292 | # ================================== |
| 293 | # Adjust key and value for inference |
| 294 | # ================================== |
| 295 | |
| 296 | if layer_past is not None: |
| 297 | past_key, past_value = layer_past |
| 298 | key_layer = torch.cat((past_key.type_as(key_layer), |
| 299 | key_layer), dim=0) |
| 300 | value_layer = torch.cat((past_value.type_as(value_layer), |
| 301 | value_layer), dim=0) |
| 302 | if get_key_value: |
| 303 | present = (key_layer, value_layer) |
| 304 | |
| 305 | # =================================== |
| 306 | # Raw attention scores. [b, np, sq, sk] |
| 307 | # =================================== |
| 308 | |
| 309 | # [b, np, sq, sk] |
| 310 | output_size = (query_layer.size(1), |
| 311 | query_layer.size(2), |
| 312 | query_layer.size(0), |
| 313 | key_layer.size(0)) |
| 314 | |
| 315 | # [s, b, np, hn] -> [s, b * np, hn] |
| 316 | query_layer = query_layer.contiguous().view(output_size[2], output_size[0] * output_size[1], -1) |
| 317 | key_layer = key_layer.contiguous().view(output_size[3], output_size[0] * output_size[1], -1) |
| 318 | |