Returns the tokenized representation of the given input string(s).

Parameters
----------
texts : Union[str, List[str]]
    An input string or a list of input strings to tokenize.
context_length : int
    The context length to use; all CLIP models use 77 as the context length.
(
texts: Union[str, List[str]], context_length: int = 77
)
def tokenize(
    texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True
) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length
    truncate : bool
        What to do with inputs whose encoding (including the start/end tokens)
        is longer than ``context_length``. When True (the default, matching the
        previous behaviour of this function) the sequence is cut to
        ``context_length`` tokens and its last position is overwritten with the
        end-of-text token, so every row stays terminated. When False, a
        RuntimeError is raised instead.

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    Positions after each sequence's end-of-text token are zero-filled.

    Raises
    ------
    RuntimeError
        If ``truncate`` is False and an input does not fit in ``context_length``.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    # Every position not written below stays 0 (padding).
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(
                    f"Input {texts[i]!r} is too long for context length {context_length}"
                )
            # A plain slice would silently drop the trailing end-of-text token;
            # overwrite the last kept position so the sequence stays terminated.
            tokens = tokens[:context_length]
            if tokens:  # context_length == 0 leaves nothing to terminate
                tokens[-1] = eot_token
        result[i, : len(tokens)] = torch.tensor(tokens)

    return result
no test coverage detected