Train the model. train_dataset, val_dataset: Training and validation Dataset objects. learning_rate: The learning rate to train with. epochs: Number of training epochs. Note that previous training epochs are considered to be done already, so this actually determ
(self, train_dataset, val_dataset, learning_rate, epochs, layers)
| 2169 | "*epoch*", "{epoch:04d}") |
| 2170 | |
| 2171 | def train(self, train_dataset, val_dataset, learning_rate, epochs, layers): |
| 2172 | """Train the model. |
| 2173 | train_dataset, val_dataset: Training and validation Dataset objects. |
| 2174 | learning_rate: The learning rate to train with |
| 2175 | epochs: Number of training epochs. Note that previous training epochs |
| 2176 | are considered to be done already, so this actually determines |
| 2177 | the epochs to train in total rather than in this particular |
| 2178 | call. |
| 2179 | layers: Allows selecting which layers to train. It can be: |
| 2180 | - A regular expression to match layer names to train |
| 2181 | - One of these predefined values: |
| 2182 | heads: The RPN, classifier and mask heads of the network |
| 2183 | all: All the layers |
| 2184 | 3+: Train Resnet stage 3 and up |
| 2185 | 4+: Train Resnet stage 4 and up |
| 2186 | 5+: Train Resnet stage 5 and up |
| 2187 | """ |
| 2188 | assert self.mode == "training", "Create model in training mode." |
| 2189 | |
| 2190 | # Pre-defined layer regular expressions |
| 2191 | layer_regex = { |
| 2192 | # all layers but the backbone |
| 2193 | "heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)", |
| 2194 | # From a specific Resnet stage and up |
| 2195 | "3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)", |
| 2196 | "4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)", |
| 2197 | "5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)", |
| 2198 | # All layers |
| 2199 | "all": ".*", |
| 2200 | } |
| 2201 | if layers in layer_regex.keys(): |
| 2202 | layers = layer_regex[layers] |
| 2203 | |
| 2204 | # Data generators |
| 2205 | train_generator = data_generator(train_dataset, self.config, shuffle=True, |
| 2206 | batch_size=self.config.BATCH_SIZE) |
| 2207 | val_generator = data_generator(val_dataset, self.config, shuffle=True, |
| 2208 | batch_size=self.config.BATCH_SIZE, |
| 2209 | augment=False) |
| 2210 | |
| 2211 | # Callbacks |
| 2212 | callbacks = [ |
| 2213 | keras.callbacks.TensorBoard(log_dir=self.log_dir, |
| 2214 | histogram_freq=0, write_graph=True, write_images=False), |
| 2215 | keras.callbacks.ModelCheckpoint(self.checkpoint_path, |
| 2216 | verbose=0, save_weights_only=True), |
| 2217 | ] |
| 2218 | |
| 2219 | # Train |
| 2220 | log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate)) |
| 2221 | log("Checkpoint Path: {}".format(self.checkpoint_path)) |
| 2222 | self.set_trainable(layers) |
| 2223 | self.compile(learning_rate, self.config.LEARNING_MOMENTUM) |
| 2224 | |
| 2225 | # Work-around for Windows: Keras fails on Windows when using |
| 2226 | # multiprocessing workers. See discussion here: |
| 2227 | # https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009 |
| 2228 | if os.name is 'nt': |
no test coverage detected