(dataset)
| 50 | |
| 51 | |
| 52 | def load_data(dataset): |
| 53 | transform = transforms.Compose([ |
| 54 | transforms.Resize(256), |
| 55 | transforms.CenterCrop(224), |
| 56 | transforms.ToTensor(), |
| 57 | transforms.Normalize( |
| 58 | mean=[0.485, 0.456, 0.406], |
| 59 | std=[0.229, 0.224, 0.225] |
| 60 | ) |
| 61 | ]) |
| 62 | if dataset == 'imagenet-r': |
| 63 | imagenet_r = datasets.ImageFolder('imagenet-r', transform=transform) |
| 64 | imagenetr_labels = open('imagenetr_labels.txt').read().splitlines() |
| 65 | imagenetr_labels = [x.split(',')[1].strip() for x in imagenetr_labels] |
| 66 | return imagenet_r, imagenetr_labels |
| 67 | else: |
| 68 | officehome = datasets.ImageFolder(dataset, transform=transform) |
| 69 | officehome_labels = officehome.classes |
| 70 | return officehome, officehome_labels |
| 71 | |
| 72 | |
| 73 | def load_model(modelname): |
no outgoing calls
no test coverage detected