| 45 | m.bias.data.zero_() |
| 46 | |
class MyDataset(Dataset):
    """Image-classification dataset backed by a plain-text index file.

    Each non-blank line of ``txt_path`` holds an image path followed by an
    integer class label, separated by whitespace (e.g. ``img/cat_01.jpg 3``).
    Only the first two whitespace-separated fields are used, so image paths
    must not contain whitespace; blank lines are ignored.

    Args:
        txt_path: Path to the index file.
        transform: Optional callable applied to the loaded RGB PIL image
            (for example a transform pipeline that converts it to a tensor).
        target_transform: Optional callable applied to the integer label.

    Raises:
        IndexError: if a non-blank line has fewer than two fields.
        ValueError: if the label field is not an integer.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        imgs = []
        # `with` guarantees the index file is closed once parsing finishes.
        with open(txt_path, 'r') as fh:
            for line in fh:
                words = line.split()
                if not words:
                    # Skip blank / whitespace-only lines (e.g. a trailing
                    # empty line) instead of failing with IndexError.
                    continue
                imgs.append((words[0], int(words[1])))

        # The core of the dataset: a list of (image_path, label) pairs.
        # DataLoader supplies indices; __getitem__ then loads images lazily.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``."""
        fn, label = self.imgs[index]
        # Pixel values are 0-255 here; transforms.ToTensor (if part of
        # `transform`) divides by 255 to rescale them to [0, 1].
        # The context manager closes the underlying file; convert() returns
        # an independent, fully loaded copy.
        with Image.open(fn) as im:
            img = im.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)  # e.g. augmentation / tensor conversion
        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.imgs)
| 71 | |
| 72 | |
| 73 | def validate(net, data_loader, set_name, classes_name): |
no outgoing calls
no test coverage detected