MCPcopy Index your code
hub / github.com/hzwer/ECCV2022-RIFE / VimeoDataset

Class VimeoDataset

dataset.py:11–109  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

# Restrict OpenCV to a single internal thread.
# NOTE(review): presumably to avoid oversubscription when several loader
# workers decode images in parallel — confirm against the training setup.
cv2.setNumThreads(1)

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
11class VimeoDataset(Dataset):
12 def __init__(self, dataset_name, batch_size=32):
13 self.batch_size = batch_size
14 self.dataset_name = dataset_name
15 self.h = 256
16 self.w = 448
17 self.data_root = 'vimeo_triplet'
18 self.image_root = os.path.join(self.data_root, 'sequences')
19 train_fn = os.path.join(self.data_root, 'tri_trainlist.txt')
20 test_fn = os.path.join(self.data_root, 'tri_testlist.txt')
21 with open(train_fn, 'r') as f:
22 self.trainlist = f.read().splitlines()
23 with open(test_fn, 'r') as f:
24 self.testlist = f.read().splitlines()
25 self.load_data()
26
27 def __len__(self):
28 return len(self.meta_data)
29
30 def load_data(self):
31 cnt = int(len(self.trainlist) * 0.95)
32 if self.dataset_name == 'train':
33 self.meta_data = self.trainlist[:cnt]
34 elif self.dataset_name == 'test':
35 self.meta_data = self.testlist
36 else:
37 self.meta_data = self.trainlist[cnt:]
38
def crop(self, img0, gt, img1, h, w):
    """Cut the same randomly placed ``h`` x ``w`` window out of all three frames.

    The window position is drawn from ``np.random`` using the shape of
    ``img0`` only; ``gt`` and ``img1`` are sliced with the same offsets.
    All three inputs are indexed as (rows, cols, channels) arrays.

    Returns:
        The cropped ``(img0, gt, img1)`` tuple, in that order.
    """
    src_h, src_w, _ = img0.shape
    # The row offset is drawn before the column offset; keep this order so
    # that seeded runs produce the same crops.
    top = np.random.randint(0, src_h - h + 1)
    left = np.random.randint(0, src_w - w + 1)
    rows = slice(top, top + h)
    cols = slice(left, left + w)
    return img0[rows, cols, :], gt[rows, cols, :], img1[rows, cols, :]
47
def getimg(self, index):
    """Load the three frames of the triplet at ``index``.

    Reads ``im1.png``, ``im2.png`` and ``im3.png`` from the sequence
    directory named by ``self.meta_data[index]`` under ``self.image_root``.

    Args:
        index: position of the sample in ``self.meta_data``.

    Returns:
        ``(img0, gt, img1, timestep)``: the first, middle and last frame as
        returned by ``cv2.imread``, and a timestep fixed at 0.5.

    Raises:
        OSError: if any of the three frames cannot be read.
    """
    seq_dir = os.path.join(self.image_root, self.meta_data[index])
    imgpaths = [os.path.join(seq_dir, name)
                for name in ('im1.png', 'im2.png', 'im3.png')]

    # Load images
    frames = []
    for path in imgpaths:
        frame = cv2.imread(path)
        # cv2.imread reports failure by returning None instead of raising;
        # fail here with the offending path rather than later on a None.
        if frame is None:
            raise OSError('could not read image: {}'.format(path))
        frames.append(frame)
    img0, gt, img1 = frames
    timestep = 0.5
    return img0, gt, img1, timestep
58
59 # RIFEm with Vimeo-Septuplet
60 # imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png', imgpath + '/im4.png', imgpath + '/im5.png', imgpath + '/im6.png', imgpath + '/im7.png']
61 # ind = [0, 1, 2, 3, 4, 5, 6]
62 # random.shuffle(ind)
63 # ind = ind[:3]
64 # ind.sort()
65 # img0 = cv2.imread(imgpaths[ind[0]])
66 # gt = cv2.imread(imgpaths[ind[1]])
67 # img1 = cv2.imread(imgpaths[ind[2]])
68 # timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)

Callers 1

trainFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected