| 499 | |
| 500 | |
class CLIPModel:
    """Inference-only wrapper around the XLM-RoBERTa / ViT-H/14 CLIP model.

    The constructor builds the network without pretrained weights, switches
    it to eval mode with gradients disabled, loads the weights from
    ``checkpoint_path`` and creates the matching tokenizer.
    """

    def __init__(self, dtype, device, checkpoint_path, tokenizer_path):
        """Build the model, load its checkpoint and create the tokenizer.

        Args:
            dtype: Passed to the model factory and later used as the
                autocast dtype in :meth:`visual`.
            device: Passed to the model factory.
            checkpoint_path: File handed to ``torch.load``; its content is
                given to ``load_state_dict``.
            tokenizer_path: Passed as ``name`` to ``HuggingfaceTokenizer``.
        """
        self.dtype = dtype
        self.device = device
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
            dtype=dtype,
            device=device)
        self.model = self.model.eval().requires_grad_(False)
        logging.info(f'loading {checkpoint_path}')
        # NOTE(review): torch.load unpickles the file, so only trusted
        # checkpoints should be loaded here. Consider weights_only=True once
        # the minimum supported torch version and checkpoint format allow it.
        self.model.load_state_dict(
            torch.load(checkpoint_path, map_location='cpu'))

        # init tokenizer
        # The sequence length is two shorter than the model's max_text_len --
        # presumably room for begin/end tokens; confirm against the tokenizer.
        self.tokenizer = HuggingfaceTokenizer(
            name=tokenizer_path,
            seq_len=self.model.max_text_len - 2,
            clean='whitespace')

    def visual(self, videos):
        """Run the visual encoder on a collection of videos.

        Args:
            videos: Non-empty iterable of 4-D tensors. Dims 0 and 1 of each
                tensor are swapped before resizing, so the layout is
                presumably (C, T, H, W) -- confirm with callers. Values are
                rescaled with ``x * 0.5 + 0.5``, which suggests inputs in
                [-1, 1] -- confirm with callers.

        Returns:
            Whatever ``self.model.visual(frames, use_31_block=True)`` returns
            for the frames of all videos concatenated along dim 0.
        """
        # preprocess
        size = (self.model.image_size,) * 2
        frames = torch.cat([
            F.interpolate(
                u.transpose(0, 1),
                size=size,
                mode='bicubic',
                align_corners=False) for u in videos
        ])
        # `frames` is a fresh tensor produced by torch.cat, so the in-place
        # rescale never touches the caller's tensors. Only the last transform
        # of the pipeline is applied (presumably the normalisation step).
        frames = self.transforms.transforms[-1](frames.mul_(0.5).add_(0.5))

        # forward
        # torch.autocast('cuda', ...) is the supported spelling of the
        # deprecated torch.cuda.amp.autocast(...); behaviour is identical.
        with torch.autocast('cuda', dtype=self.dtype):
            out = self.model.visual(frames, use_31_block=True)
        return out