github.com/Duankaiwen/CenterNet

Function train

train.py:68–173
(training_dbs, validation_db, start_iter=0)


def train(training_dbs, validation_db, start_iter=0):
    learning_rate = system_configs.learning_rate
    max_iteration = system_configs.max_iter
    pretrained_model = system_configs.pretrain
    snapshot = system_configs.snapshot
    val_iter = system_configs.val_iter
    display = system_configs.display
    decay_rate = system_configs.decay_rate
    stepsize = system_configs.stepsize

    # getting the size of each database
    training_size = len(training_dbs[0].db_inds)
    validation_size = len(validation_db.db_inds)

    # queues storing data for training
    training_queue = Queue(system_configs.prefetch_size)
    validation_queue = Queue(5)

    # queues storing pinned data for training
    pinned_training_queue = queue.Queue(system_configs.prefetch_size)
    pinned_validation_queue = queue.Queue(5)

    # load data sampling function
    data_file = "sample.{}".format(training_dbs[0].data)
    sample_data = importlib.import_module(data_file).sample_data

    # allocating resources for parallel reading
    training_tasks = init_parallel_jobs(training_dbs, training_queue, sample_data, True)
    if val_iter:
        validation_tasks = init_parallel_jobs([validation_db], validation_queue, sample_data, False)

    training_pin_semaphore = threading.Semaphore()
    validation_pin_semaphore = threading.Semaphore()
    training_pin_semaphore.acquire()
    validation_pin_semaphore.acquire()

    training_pin_args = (training_queue, pinned_training_queue, training_pin_semaphore)
    training_pin_thread = threading.Thread(target=pin_memory, args=training_pin_args)
    training_pin_thread.daemon = True
    training_pin_thread.start()

    validation_pin_args = (validation_queue, pinned_validation_queue, validation_pin_semaphore)
    validation_pin_thread = threading.Thread(target=pin_memory, args=validation_pin_args)
    validation_pin_thread.daemon = True
    validation_pin_thread.start()

    print("building model...")
    nnet = NetworkFactory(training_dbs[0])

    if pretrained_model is not None:
        if not os.path.exists(pretrained_model):
            raise ValueError("pretrained model does not exist")
        print("loading from pretrained model")
        nnet.load_pretrained_params(pretrained_model)

    if start_iter:
        learning_rate /= (decay_rate ** (start_iter // stepsize))
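When training resumes at `start_iter`, the last lines divide the base learning rate by `decay_rate` once for every full `stepsize` interval already completed, that is lr = learning_rate / decay_rate^(start_iter // stepsize). This presumably matches a step decay applied during training (`set_lr` appears among the calls listed for this function), but that part of the loop is not shown in this excerpt. The following is a minimal sketch of the resume arithmetic only. `resumed_learning_rate` is a hypothetical helper, not part of CenterNet.

```python
def resumed_learning_rate(base_lr: float, decay_rate: float,
                          stepsize: int, start_iter: int) -> float:
    """Rescale base_lr the way train() does when resuming at start_iter.

    Hypothetical helper: divides base_lr by decay_rate once for every
    full stepsize interval that was completed before start_iter.
    """
    if not start_iter:                         # mirrors `if start_iter:` (0 = fresh run)
        return base_lr
    completed_steps = start_iter // stepsize   # integer division, as in train()
    return base_lr / (decay_rate ** completed_steps)
```

For example, with `decay_rate=10` and `stepsize=100`, resuming at iteration 250 completes two intervals, so the base rate is divided by 100. Resuming at iteration 99 completes none and leaves the rate unchanged.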

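`train` does not hard-code a sampler. It builds the module name `sample.<data>` from the first training database's `data` attribute, imports that module with `importlib.import_module`, and takes its `sample_data` attribute, so each dataset type can ship its own sampling code. The sketch below shows the same run-time lookup pattern. Because the `sample` package is specific to CenterNet, it resolves standard-library modules instead, and `load_attr` is a hypothetical helper name.

```python
import importlib
from typing import Any


def load_attr(package: str, module: str, attr: str) -> Any:
    """Import `<package>.<module>` at run time and return its attribute `attr`.

    Same pattern as train(), which effectively does
    importlib.import_module("sample.{}".format(db.data)).sample_data
    """
    module_path = "{}.{}".format(package, module) if package else module
    return getattr(importlib.import_module(module_path), attr)


# Standard-library stand-ins for CenterNet's `sample` package:
decoder_cls = load_attr("json", "decoder", "JSONDecoder")
sqrt = load_attr("", "math", "sqrt")
```

A missing module raises `ModuleNotFoundError`, and a missing attribute raises `AttributeError`. In `train` the lookup happens before the model is built, so a misspelled `data` value fails at startup rather than partway through training.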
Callers (1)

train.py · File · 0.85

Calls (12)

load_params · Method · 0.95
set_lr · Method · 0.95
cuda · Method · 0.95
train_mode · Method · 0.95
train · Method · 0.95
eval_mode · Method · 0.95
validate · Method · 0.95
save_params · Method · 0.95
NetworkFactory · Class · 0.90
stdout_to_tqdm · Function · 0.90
init_parallel_jobs · Function · 0.85
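`train` calls `init_parallel_jobs` with a list of databases, a bounded queue, the `sample_data` function and an augmentation flag (`True` for training, `False` for validation), and the call returns a list of tasks. Its body is not part of this excerpt. The sketch below assumes one common layout: one long-running producer per database that repeatedly calls the sampler and pushes each batch into the shared queue, with the queue bound (`Queue(system_configs.prefetch_size)` in `train`) limiting how far producers can run ahead. The unqualified `Queue`, unlike the `queue.Queue` used for the pinned queues, is presumably a multiprocessing queue fed by worker processes. The sketch uses threads instead so it runs without extra setup. `producer`, `start_producers`, `sample_pairs`, and the sampler signature are all hypothetical.

```python
import queue
import threading
from typing import Any, Callable, List, Sequence, Tuple

# Hypothetical sampler contract: (db, cursor, augment) -> (batch, next_cursor)
SampleFn = Callable[[Sequence[Any], int, bool], Tuple[List[Any], int]]


def producer(db: Sequence[Any], out: queue.Queue,
             sample_fn: SampleFn, augment: bool) -> None:
    """Sample batches from one database forever; put() blocks while `out` is full."""
    cursor = 0
    while True:
        batch, cursor = sample_fn(db, cursor, augment)
        out.put(batch)


def start_producers(dbs: List[Sequence[Any]], out: queue.Queue,
                    sample_fn: SampleFn, augment: bool) -> List[threading.Thread]:
    """Start one daemon producer per database, all feeding the same bounded queue."""
    tasks = [threading.Thread(target=producer, args=(db, out, sample_fn, augment),
                              daemon=True)
             for db in dbs]
    for task in tasks:
        task.start()
    return tasks


def sample_pairs(db: Sequence[Any], cursor: int, augment: bool) -> Tuple[List[Any], int]:
    """Toy sampler: two consecutive items, wrapping around the database."""
    batch = [db[cursor % len(db)], db[(cursor + 1) % len(db)]]
    if augment:
        batch = [-x for x in batch]   # placeholder for real data augmentation
    return batch, cursor + 2


prefetch: queue.Queue = queue.Queue(maxsize=3)   # plays the role of Queue(prefetch_size)
workers = start_producers([[1, 2, 3], [10, 20]], prefetch, sample_pairs, augment=False)
batches = [prefetch.get(timeout=5) for _ in range(4)]
```

Because `put()` blocks on a full queue, the shared queue never holds more than `maxsize` finished batches, however fast the producers sample. The daemon flag lets the process exit while producers are still blocked, which matches the `daemon = True` setting `train` uses for its pinning threads.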

Tested by

no test coverage detected
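In `train`, each pinning thread gets its own `threading.Semaphore()`. The semaphore is acquired once right after creation, so its counter starts at zero. The thread then moves batches from the raw queue to a `queue.Queue` of pinned batches. The body of `pin_memory` is not shown here. One protocol consistent with this setup treats the semaphore as a stop flag: after forwarding each batch, the thread tries a non-blocking `acquire()`, which succeeds only once the main thread has called `release()`. The sketch below assumes that protocol. `pin_worker` is a hypothetical name, and a caller-supplied `transform` stands in for CUDA memory pinning so the example runs without PyTorch.

```python
import queue
import threading
from typing import Any, Callable


def pin_worker(src: queue.Queue, dst: queue.Queue,
               stop: threading.Semaphore,
               transform: Callable[[Any], Any]) -> None:
    """Move items from src to dst until the stop semaphore is released.

    Assumes the creator acquired `stop` once up front, so the non-blocking
    acquire() below fails until another thread calls stop.release().
    """
    while True:
        item = src.get()                 # wait for the next raw batch
        dst.put(transform(item))         # stand-in for tensor.pin_memory()
        if stop.acquire(blocking=False):
            return                       # stop requested: exit the thread


raw_q: queue.Queue = queue.Queue(maxsize=4)
pinned_q: queue.Queue = queue.Queue(maxsize=4)
stop_sema = threading.Semaphore()
stop_sema.acquire()                      # counter 1 -> 0, as train() does

worker = threading.Thread(target=pin_worker,
                          args=(raw_q, pinned_q, stop_sema, lambda x: x * 2),
                          daemon=True)
worker.start()

raw_q.put(1)
first = pinned_q.get(timeout=5)          # the transformed batch
stop_sema.release()                      # request shutdown
raw_q.put(3)                             # wakes the worker if it is blocked in get()
worker.join(timeout=5)
```

The flag is checked only after an item is forwarded. A worker blocked in `get()` therefore notices the release only when one more item arrives, which is why the example puts a final item before joining. If producers keep the raw queue supplied, no extra item is needed.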