github.com/Tencent/NeuralNLP-NeuralClassifier

Function get_data_loader

train.py:51–76

Get data loaders: Train, Validate, Test (returned as a tuple in that order).

(dataset_name, collate_name, conf)
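get_data_loader reads six fields from conf. A minimal sketch, assuming an attribute-style config object (types.SimpleNamespace stands in for the project's real config class, and the file paths and sizes are hypothetical), that lists the per-split settings the function passes to DataLoader:

```python
from types import SimpleNamespace

# Hypothetical config: only the attribute paths mirror what get_data_loader reads.
conf = SimpleNamespace(
    data=SimpleNamespace(
        train_json_files=["data/train.json"],
        validate_json_files=["data/validate.json"],
        test_json_files=["data/test.json"],
        num_worker=4,
    ),
    train=SimpleNamespace(batch_size=64),
    eval=SimpleNamespace(batch_size=128),
)


def loader_settings(conf):
    """Return (json_files, batch_size, shuffle, num_workers) per split,
    matching the arguments get_data_loader passes to each DataLoader."""
    return {
        "train": (conf.data.train_json_files, conf.train.batch_size, True,
                  conf.data.num_worker),
        "validate": (conf.data.validate_json_files, conf.eval.batch_size, False,
                     conf.data.num_worker),
        "test": (conf.data.test_json_files, conf.eval.batch_size, False,
                 conf.data.num_worker),
    }


for split, settings in loader_settings(conf).items():
    print(split, settings)
```

Validation and test share conf.eval.batch_size, and only the training split is shuffled.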


from torch.utils.data import DataLoader


def get_data_loader(dataset_name, collate_name, conf):
    """Get data loader: Train, Validate, Test
    """
    # dataset_name and collate_name are class names resolved through this
    # module's globals(), so both classes must be imported at module level.
    # Only the training set is built with generate_dict=True, so it builds
    # the dictionaries (including label_map) that the other splits reuse.
    train_dataset = globals()[dataset_name](
        conf, conf.data.train_json_files, generate_dict=True)
    # A single collate function, sized by the number of training labels,
    # is shared by all three loaders.
    collate_fn = globals()[collate_name](conf, len(train_dataset.label_map))

    # Training batches are shuffled; evaluation batches keep dataset order.
    # pin_memory=True puts batches in page-locked memory for faster GPU copies.
    train_data_loader = DataLoader(
        train_dataset, batch_size=conf.train.batch_size, shuffle=True,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    validate_dataset = globals()[dataset_name](
        conf, conf.data.validate_json_files)
    validate_data_loader = DataLoader(
        validate_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    test_dataset = globals()[dataset_name](conf, conf.data.test_json_files)
    test_data_loader = DataLoader(
        test_dataset, batch_size=conf.eval.batch_size, shuffle=False,
        num_workers=conf.data.num_worker, collate_fn=collate_fn,
        pin_memory=True)

    return train_data_loader, validate_data_loader, test_data_loader
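The training dataset is the only one built with generate_dict=True, and the collate function is constructed once from len(train_dataset.label_map) and reused for every split. The sketch below illustrates that idea in plain Python; build_label_map and MultiHotCollate are hypothetical names, and the multi-hot target encoding is an assumption for illustration, not necessarily what the repository's collator produces.

```python
def build_label_map(train_labels):
    """Assign label ids from the training split only (sketch of generate_dict=True)."""
    return {label: idx for idx, label in enumerate(sorted(set(train_labels)))}


class MultiHotCollate:
    """Hypothetical collate_fn sized by len(label_map) and shared by every split."""

    def __init__(self, label_map):
        self.label_map = label_map
        self.label_size = len(label_map)

    def __call__(self, batch):
        # batch is a list of (text, [labels]) pairs, as a DataLoader would pass it.
        texts, targets = [], []
        for text, labels in batch:
            row = [0] * self.label_size
            for label in labels:
                # Raises KeyError for a label never seen in training.
                row[self.label_map[label]] = 1
            texts.append(text)
            targets.append(row)
        return texts, targets


label_map = build_label_map(["sports", "tech", "sports"])
collate_fn = MultiHotCollate(label_map)
print(collate_fn([("match report", ["sports"]), ("chip news", ["tech", "sports"])]))
# (['match report', 'chip news'], [[1, 0], [1, 1]])
```

Because validation and test batches pass through the same collator, their targets have the same width as the training targets.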

Callers (1): train (Function) · 0.85

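Calls: no outgoing calls detected. The dataset and collate classes are resolved at runtime through globals(), and DataLoader comes from torch.utils.data.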

Tested by: no test coverage detected
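Because dataset_name and collate_name are resolved with globals()[...], any module-level name in train.py can be instantiated, and a misspelled name surfaces as a bare KeyError. A minimal sketch contrasting that lookup with an explicit registry; StubDataset, make_from_globals, make_from_registry, and DATASET_REGISTRY are hypothetical and not part of NeuralClassifier.

```python
class StubDataset:
    """Hypothetical stand-in for a dataset class importable from train.py."""

    def __init__(self, conf, json_files, generate_dict=False):
        self.conf = conf
        self.json_files = json_files
        self.generate_dict = generate_dict


def make_from_globals(name, *args, **kwargs):
    # Same mechanism as get_data_loader: look the class up among module globals.
    return globals()[name](*args, **kwargs)


# Explicit alternative: only registered classes can be chosen by name.
DATASET_REGISTRY = {"StubDataset": StubDataset}


def make_from_registry(name, *args, **kwargs):
    try:
        cls = DATASET_REGISTRY[name]
    except KeyError:
        raise ValueError(
            f"unknown dataset {name!r}; expected one of {sorted(DATASET_REGISTRY)}"
        ) from None
    return cls(*args, **kwargs)


train = make_from_globals("StubDataset", None, ["train.json"], generate_dict=True)
dev = make_from_registry("StubDataset", None, ["dev.json"])
print(train.generate_dict, dev.generate_dict)  # True False
```

The registry gives up some of the flexibility of globals() in exchange for a closed, documented set of choices and a clearer error message.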