MCPcopy
hub / github.com/kohya-ss/sd-scripts / encode_tokens

Method encode_tokens

library/strategy_sd3.py:72–212  ·  view source on GitHub ↗

Returned embeddings are not masked.

(
        self,
        tokenize_strategy: TokenizeStrategy,
        models: List[Any],
        tokens: List[torch.Tensor],
        apply_lg_attn_mask: Optional[bool] = False,
        apply_t5_attn_mask: Optional[bool] = False,
        enable_dropout: bool = True,
    )

Source from the content-addressed store, hash-verified

70 self.t5_dropout_rate = t5_dropout_rate
71
72 def encode_tokens(
73 self,
74 tokenize_strategy: TokenizeStrategy,
75 models: List[Any],
76 tokens: List[torch.Tensor],
77 apply_lg_attn_mask: Optional[bool] = False,
78 apply_t5_attn_mask: Optional[bool] = False,
79 enable_dropout: bool = True,
80 ) -> List[torch.Tensor]:
81 """
82 returned embeddings are not masked
83 """
84 clip_l, clip_g, t5xxl = models
85 clip_l: Optional[CLIPTextModel]
86 clip_g: Optional[CLIPTextModelWithProjection]
87 t5xxl: Optional[T5EncoderModel]
88
89 if apply_lg_attn_mask is None:
90 apply_lg_attn_mask = self.apply_lg_attn_mask
91 if apply_t5_attn_mask is None:
92 apply_t5_attn_mask = self.apply_t5_attn_mask
93
94 l_tokens, g_tokens, t5_tokens, l_attn_mask, g_attn_mask, t5_attn_mask = tokens
95
96 # dropout: if enable_dropout is False, dropout is not applied. dropout means zeroing out embeddings
97
98 if l_tokens is None or clip_l is None:
99 assert g_tokens is None, "g_tokens must be None if l_tokens is None"
100 lg_out = None
101 lg_pooled = None
102 l_attn_mask = None
103 g_attn_mask = None
104 else:
105 assert g_tokens is not None, "g_tokens must not be None if l_tokens is not None"
106
107 # drop some members of the batch: we do not call clip_l and clip_g for dropped members
108 batch_size, l_seq_len = l_tokens.shape
109 g_seq_len = g_tokens.shape[1]
110
111 non_drop_l_indices = []
112 non_drop_g_indices = []
113 for i in range(l_tokens.shape[0]):
114 drop_l = enable_dropout and (self.l_dropout_rate > 0.0 and random.random() < self.l_dropout_rate)
115 drop_g = enable_dropout and (self.g_dropout_rate > 0.0 and random.random() < self.g_dropout_rate)
116 if not drop_l:
117 non_drop_l_indices.append(i)
118 if not drop_g:
119 non_drop_g_indices.append(i)
120
121 # filter out dropped members
122 if len(non_drop_l_indices) > 0 and len(non_drop_l_indices) < batch_size:
123 l_tokens = l_tokens[non_drop_l_indices]
124 l_attn_mask = l_attn_mask[non_drop_l_indices]
125 if len(non_drop_g_indices) > 0 and len(non_drop_g_indices) < batch_size:
126 g_tokens = g_tokens[non_drop_g_indices]
127 g_attn_mask = g_attn_mask[non_drop_g_indices]
128
129 # call clip_l for non-dropped members

Callers (2)

train (Function) · 0.95
cache_batch_outputs (Method) · 0.45

Calls (1)

to (Method) · 0.80

Tested by

no test coverage detected