(training_dbs, validation_db, start_iter=0)
| 66 | return tasks |
| 67 | |
| 68 | def train(training_dbs, validation_db, start_iter=0): |
| 69 | learning_rate = system_configs.learning_rate |
| 70 | max_iteration = system_configs.max_iter |
| 71 | pretrained_model = system_configs.pretrain |
| 72 | snapshot = system_configs.snapshot |
| 73 | val_iter = system_configs.val_iter |
| 74 | display = system_configs.display |
| 75 | decay_rate = system_configs.decay_rate |
| 76 | stepsize = system_configs.stepsize |
| 77 | |
| 78 | # getting the size of each database |
| 79 | training_size = len(training_dbs[0].db_inds) |
| 80 | validation_size = len(validation_db.db_inds) |
| 81 | |
| 82 | # queues storing data for training |
| 83 | training_queue = Queue(system_configs.prefetch_size) |
| 84 | validation_queue = Queue(5) |
| 85 | |
| 86 | # queues storing pinned data for training |
| 87 | pinned_training_queue = queue.Queue(system_configs.prefetch_size) |
| 88 | pinned_validation_queue = queue.Queue(5) |
| 89 | |
| 90 | # load data sampling function |
| 91 | data_file = "sample.{}".format(training_dbs[0].data) |
| 92 | sample_data = importlib.import_module(data_file).sample_data |
| 93 | |
| 94 | # allocating resources for parallel reading |
| 95 | training_tasks = init_parallel_jobs(training_dbs, training_queue, sample_data, True) |
| 96 | if val_iter: |
| 97 | validation_tasks = init_parallel_jobs([validation_db], validation_queue, sample_data, False) |
| 98 | |
| 99 | training_pin_semaphore = threading.Semaphore() |
| 100 | validation_pin_semaphore = threading.Semaphore() |
| 101 | training_pin_semaphore.acquire() |
| 102 | validation_pin_semaphore.acquire() |
| 103 | |
| 104 | training_pin_args = (training_queue, pinned_training_queue, training_pin_semaphore) |
| 105 | training_pin_thread = threading.Thread(target=pin_memory, args=training_pin_args) |
| 106 | training_pin_thread.daemon = True |
| 107 | training_pin_thread.start() |
| 108 | |
| 109 | validation_pin_args = (validation_queue, pinned_validation_queue, validation_pin_semaphore) |
| 110 | validation_pin_thread = threading.Thread(target=pin_memory, args=validation_pin_args) |
| 111 | validation_pin_thread.daemon = True |
| 112 | validation_pin_thread.start() |
| 113 | |
| 114 | print("building model...") |
| 115 | nnet = NetworkFactory(training_dbs[0]) |
| 116 | |
| 117 | if pretrained_model is not None: |
| 118 | if not os.path.exists(pretrained_model): |
| 119 | raise ValueError("pretrained model does not exist") |
| 120 | print("loading from pretrained model") |
| 121 | nnet.load_pretrained_params(pretrained_model) |
| 122 | |
| 123 | if start_iter: |
| 124 | learning_rate /= (decay_rate ** (start_iter // stepsize)) |
| 125 |
no test coverage detected