(self, dataloader, testloader, optimizer, nepochs=10, save_path=None)
| 99 | |
| 100 | |
def finetune(self, dataloader, testloader, optimizer, nepochs=10, save_path=None):
    """Contrastively fine-tune ``self.model`` and keep the best checkpoint.

    Each epoch makes one pass over ``dataloader``, minimising the symmetric
    cross-entropy between the image->text and text->image logit matrices
    (averaged), with the i-th image paired to the i-th text. After each
    epoch the model is scored with ``self.evaluate(testloader)``; whenever
    the accuracy strictly exceeds the best seen so far (initially 0), the
    model's ``state_dict`` is written to ``save_path`` (if given).

    Args:
        dataloader: iterable of ``(image, text, _)`` batches; the third
            element is ignored. Presumably ``text`` is already tokenized —
            TODO confirm against the dataset.
        testloader: passed unchanged to ``self.evaluate``.
        optimizer: optimizer over ``self.model``'s parameters.
        nepochs: number of passes over ``dataloader``.
        save_path: checkpoint file path, or ``None`` to skip saving.

    Returns:
        The best accuracy reported by ``self.evaluate`` (0 if ``nepochs``
        is 0 or no epoch scored above 0).
    """
    loss_img = nn.CrossEntropyLoss()
    loss_txt = nn.CrossEntropyLoss()
    # Normalise so both strings ("cpu", "cuda:0") and torch.device objects
    # work: a torch.device never compares equal to the plain string "cpu".
    on_cpu = torch.device(self.device).type == "cpu"
    best_acc = 0
    for epoch in range(nepochs):
        # self.evaluate() at the end of the previous epoch may have left the
        # model in eval mode (presumably; not visible here), so re-enable
        # training-mode behaviour (e.g. BatchNorm updates) every epoch.
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in tqdm(dataloader):
            optimizer.zero_grad()
            image, text, _ = batch
            image = image.to(self.device)
            text = text.to(self.device)
            logits_per_image, logits_per_text = self.model(image, text)

            # Row i of both logit matrices should select column i: the i-th
            # image and the i-th text form the matching pair.
            ground_truth = torch.arange(
                len(image), dtype=torch.long, device=self.device)

            loss = (loss_img(logits_per_image, ground_truth) +
                    loss_txt(logits_per_text, ground_truth)) / 2
            loss.backward()
            total_loss += loss.item()
            num_batches += 1
            if on_cpu:
                optimizer.step()
            else:
                # Off-CPU: convert to fp32 for the optimizer step, then back
                # via clip.model.convert_weights (presumably fp16 weights).
                convert_models_to_fp32(self.model)
                optimizer.step()
                clip.model.convert_weights(self.model)

        eval_acc, _ = self.evaluate(testloader)
        if eval_acc > best_acc:
            best_acc = eval_acc
            if save_path is not None:
                torch.save(self.model.state_dict(), save_path)
        # Average over the batches actually seen: no ZeroDivisionError on an
        # empty loader, and no reliance on the loader defining __len__.
        avg_loss = total_loss / num_batches if num_batches else 0.0
        self.logger.info("Epoch {} : Loss {}, Acc {:.4f}".format(
            epoch, avg_loss, eval_acc))
    return best_acc
| 136 | |
| 137 | |
| 138 | def evaluate(self, dataloader, modelpath=None): |
no test coverage detected