| 13 | #读取文件夹mnist下的42000张图片,图片为灰度图,所以为1通道,图像大小28*28 |
| 14 | #如果是将彩色图作为输入,则将1替换为3,并且data[i,:,:,:] = arr改为data[i,:,:,:] = [arr[:,:,0],arr[:,:,1],arr[:,:,2]] |
| 15 | def load_data(): |
| 16 | data = np.empty((42000,1,28,28),dtype="float32") |
| 17 | label = np.empty((42000,),dtype="uint8") |
| 18 | |
| 19 | imgs = os.listdir("./mnist") |
| 20 | num = len(imgs) |
| 21 | for i in range(num): |
| 22 | img = Image.open("./mnist/"+imgs[i]) |
| 23 | arr = np.asarray(img,dtype="float32") |
| 24 | data[i,:,:,:] = arr |
| 25 | label[i] = int(imgs[i].split('.')[0]) |
| 26 | data /= np.max(data) |
| 27 | data -= np.mean(data) |
| 28 | return data,label |
| 29 | |
| 30 | |
| 31 | |