Get data loaders: train, validate, test
(dataset_name, collate_name, conf)
| 49 | |
| 50 | |
def get_data_loader(dataset_name, collate_name, conf):
    """Build the train, validate and test data loaders.

    The dataset and collate classes are looked up by name among this
    module's globals, so both names must refer to objects defined or
    imported in this module.

    Only the train dataset is built with ``generate_dict=True``. The size
    of its ``label_map`` is passed to the collate object, and that single
    collate object is shared by all three loaders.

    Args:
        dataset_name: Name of the dataset class. It is called as
            ``cls(conf, json_files)``, plus ``generate_dict=True`` for train.
        collate_name: Name of the collate class. It is called as
            ``cls(conf, len(train_dataset.label_map))``.
        conf: Config object. Reads ``conf.data.train_json_files``,
            ``conf.data.validate_json_files``, ``conf.data.test_json_files``,
            ``conf.data.num_worker``, ``conf.train.batch_size`` and
            ``conf.eval.batch_size``.

    Returns:
        Tuple ``(train_data_loader, validate_data_loader, test_data_loader)``.
        Only the train loader shuffles. Validate and test use the eval
        batch size.

    Raises:
        KeyError: If ``dataset_name`` or ``collate_name`` is not defined in
            this module.
    """
    dataset_cls = _lookup_module_global(dataset_name, "dataset")
    collate_cls = _lookup_module_global(collate_name, "collate")

    # The label dictionary is generated from the training split only; its
    # size configures the collate function used for every split.
    train_dataset = dataset_cls(
        conf, conf.data.train_json_files, generate_dict=True)
    collate_fn = collate_cls(conf, len(train_dataset.label_map))
    train_data_loader = _make_data_loader(
        train_dataset, conf.train.batch_size, True, collate_fn, conf)

    validate_dataset = dataset_cls(conf, conf.data.validate_json_files)
    validate_data_loader = _make_data_loader(
        validate_dataset, conf.eval.batch_size, False, collate_fn, conf)

    test_dataset = dataset_cls(conf, conf.data.test_json_files)
    test_data_loader = _make_data_loader(
        test_dataset, conf.eval.batch_size, False, collate_fn, conf)

    return train_data_loader, validate_data_loader, test_data_loader


def _lookup_module_global(name, kind):
    """Return the object bound to ``name`` in this module's globals.

    Raises:
        KeyError: With a message naming the missing ``kind`` if ``name``
            is not defined in this module.
    """
    try:
        return globals()[name]
    except KeyError:
        raise KeyError(
            "Unknown %s name %r: it is not defined in this module"
            % (kind, name)) from None


def _make_data_loader(dataset, batch_size, shuffle, collate_fn, conf):
    """Wrap ``dataset`` in a DataLoader with the shared worker/pinning settings."""
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)
| 77 | |
| 78 | |
| 79 | def get_classification_model(model_name, dataset, conf): |