(prpt: str)
| 280 | t5xxl = t5xxl.to(device) |
| 281 | |
def encode(prpt: str):
    """Encode a single prompt with the CLIP-L and T5-XXL text encoders.

    Reads ``tokenize_strategy``, ``encoding_strategy``, ``clip_l``, ``t5xxl``,
    ``clip_l_dtype``, ``t5xxl_dtype``, ``accelerator``, ``device`` and ``args``
    from the enclosing scope.

    Args:
        prpt: Prompt text handed to ``tokenize_strategy.tokenize``.

    Returns:
        The tuple ``(l_pooled, t5_out, txt_ids, t5_attn_mask)``. ``l_pooled``
        is ``None`` when ``clip_l`` is ``None``.
    """

    def autocast_for(encoder_dtype):
        # When is_fp8() flags the dtype, defer to the accelerator's own
        # autocast; otherwise autocast straight to the encoder's dtype.
        if is_fp8(encoder_dtype):
            return accelerator.autocast()
        return torch.autocast(device_type=device.type, dtype=encoder_dtype)

    tokens_and_masks = tokenize_strategy.tokenize(prpt)

    with torch.no_grad():
        # CLIP-L pass: only the first element of the result is kept.
        # Skipped entirely when no CLIP-L model is available.
        l_pooled = None
        if clip_l is not None:
            with autocast_for(clip_l_dtype):
                clip_only = [clip_l, None]
                l_pooled, _, _, _ = encoding_strategy.encode_tokens(tokenize_strategy, clip_only, tokens_and_masks)

        # T5-XXL pass: the first element of the result is discarded here and
        # the remaining three are returned to the caller.
        with autocast_for(t5xxl_dtype):
            both_encoders = [clip_l, t5xxl]
            _, t5_out, txt_ids, t5_attn_mask = encoding_strategy.encode_tokens(
                tokenize_strategy, both_encoders, tokens_and_masks, args.apply_t5_attn_mask
            )

    return l_pooled, t5_out, txt_ids, t5_attn_mask
| 306 | |
| 307 | l_pooled, t5_out, txt_ids, t5_attn_mask = encode(prompt) |
| 308 | if negative_prompt: |
no test coverage detected