MCPcopy
hub / github.com/matterport/Mask_RCNN / train

Method train

model.py:2171–2245  ·  view source on GitHub ↗

Train the model. train_dataset, val_dataset: Training and validation Dataset objects. learning_rate: The learning rate to train with. epochs: Number of training epochs. Note that previous training epochs are considered to be done already, so this actually determines the epochs to train in total rather than in this particular call.

(self, train_dataset, val_dataset, learning_rate, epochs, layers)

Source from the content-addressed store, hash-verified

2169 "*epoch*", "{epoch:04d}")
2170
2171 def train(self, train_dataset, val_dataset, learning_rate, epochs, layers):
2172 """Train the model.
2173 train_dataset, val_dataset: Training and validation Dataset objects.
2174 learning_rate: The learning rate to train with
2175 epochs: Number of training epochs. Note that previous training epochs
2176 are considered to be done alreay, so this actually determines
2177 the epochs to train in total rather than in this particaular
2178 call.
2179 layers: Allows selecting wich layers to train. It can be:
2180 - A regular expression to match layer names to train
2181 - One of these predefined values:
2182 heaads: The RPN, classifier and mask heads of the network
2183 all: All the layers
2184 3+: Train Resnet stage 3 and up
2185 4+: Train Resnet stage 4 and up
2186 5+: Train Resnet stage 5 and up
2187 """
2188 assert self.mode == "training", "Create model in training mode."
2189
2190 # Pre-defined layer regular expressions
2191 layer_regex = {
2192 # all layers but the backbone
2193 "heads": r"(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2194 # From a specific Resnet stage and up
2195 "3+": r"(res3.*)|(bn3.*)|(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2196 "4+": r"(res4.*)|(bn4.*)|(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2197 "5+": r"(res5.*)|(bn5.*)|(mrcnn\_.*)|(rpn\_.*)|(fpn\_.*)",
2198 # All layers
2199 "all": ".*",
2200 }
2201 if layers in layer_regex.keys():
2202 layers = layer_regex[layers]
2203
2204 # Data generators
2205 train_generator = data_generator(train_dataset, self.config, shuffle=True,
2206 batch_size=self.config.BATCH_SIZE)
2207 val_generator = data_generator(val_dataset, self.config, shuffle=True,
2208 batch_size=self.config.BATCH_SIZE,
2209 augment=False)
2210
2211 # Callbacks
2212 callbacks = [
2213 keras.callbacks.TensorBoard(log_dir=self.log_dir,
2214 histogram_freq=0, write_graph=True, write_images=False),
2215 keras.callbacks.ModelCheckpoint(self.checkpoint_path,
2216 verbose=0, save_weights_only=True),
2217 ]
2218
2219 # Train
2220 log("\nStarting at epoch {}. LR={}\n".format(self.epoch, learning_rate))
2221 log("Checkpoint Path: {}".format(self.checkpoint_path))
2222 self.set_trainable(layers)
2223 self.compile(learning_rate, self.config.LEARNING_MOMENTUM)
2224
2225 # Work-around for Windows: Keras fails on Windows when using
2226 # multiprocessing workers. See discussion here:
2227 # https://github.com/matterport/Mask_RCNN/issues/13#issuecomment-353124009
2228 if os.name is 'nt':

Callers 2

coco.py (File) · 0.80
train (Function) · 0.80

Calls 4

set_trainable (Method) · 0.95
compile (Method) · 0.95
data_generator (Function) · 0.85
log (Function) · 0.85

Tested by

no test coverage detected