Returned embeddings are not masked.
(
self,
tokenize_strategy: TokenizeStrategy,
models: List[Any],
tokens: List[torch.Tensor],
apply_lg_attn_mask: Optional[bool] = False,
apply_t5_attn_mask: Optional[bool] = False,
enable_dropout: bool = True,
)
| 70 | self.t5_dropout_rate = t5_dropout_rate |
| 71 | |
| 72 | def encode_tokens( |
| 73 | self, |
| 74 | tokenize_strategy: TokenizeStrategy, |
| 75 | models: List[Any], |
| 76 | tokens: List[torch.Tensor], |
| 77 | apply_lg_attn_mask: Optional[bool] = False, |
| 78 | apply_t5_attn_mask: Optional[bool] = False, |
| 79 | enable_dropout: bool = True, |
| 80 | ) -> List[torch.Tensor]: |
| 81 | """ |
| 82 | returned embeddings are not masked |
| 83 | """ |
| 84 | clip_l, clip_g, t5xxl = models |
| 85 | clip_l: Optional[CLIPTextModel] |
| 86 | clip_g: Optional[CLIPTextModelWithProjection] |
| 87 | t5xxl: Optional[T5EncoderModel] |
| 88 | |
| 89 | if apply_lg_attn_mask is None: |
| 90 | apply_lg_attn_mask = self.apply_lg_attn_mask |
| 91 | if apply_t5_attn_mask is None: |
| 92 | apply_t5_attn_mask = self.apply_t5_attn_mask |
| 93 | |
| 94 | l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens |
| 95 | |
| 96 | # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings |
| 97 | |
| 98 | if l_tokens is None or clip_l is None: |
| 99 | assert g_tokens is None, "g_tokens must be None if l_tokens is None" |
| 100 | lg_out = None |
| 101 | lg_pooled = None |
| 102 | l_attn_mask = None |
| 103 | g_attn_mask = None |
| 104 | else: |
| 105 | assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None" |
| 106 | |
| 107 | # drop some members of the batch: we do not call clip_l and clip_g for dropped members |
| 108 | batch_size, l_seq_len = l_tokens.shape |
| 109 | g_seq_len = g_tokens.shape[1] |
| 110 | |
| 111 | non_drop_l_indices = [] |
| 112 | non_drop_g_indices = [] |
| 113 | for i in range(l_tokens.shape[0]): |
| 114 | drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate) |
| 115 | drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate) |
| 116 | if not drop_l: |
| 117 | non_drop_l_indices.append(i) |
| 118 | if not drop_g: |
| 119 | non_drop_g_indices.append(i) |
| 120 | |
| 121 | # filter out dropped members |
| 122 | if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size: |
| 123 | l_tokens = l_tokens[non_drop_l_indices] |
| 124 | l_attn_mask = l_attn_mask[non_drop_l_indices] |
| 125 | if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size: |
| 126 | g_tokens = g_tokens[non_drop_g_indices] |
| 127 | g_attn_mask = g_attn_mask[non_drop_g_indices] |
| 128 | |
| 129 | # call clip_l for non-dropped members |
No test coverage detected.