Args: dataflow (DataFlow): data (InputSource): model (ModelDesc): callbacks (list[Callback]): a list of :class:`Callback` to use during training. extra_callbacks (list[Callback]): This argument is only used to provide the defaults in addition to ``callbacks``.
(self,
dataflow=None, data=None,
model=None,
callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999)
| 57 | """ |
| 58 | |
| 59 | def __init__(self, |
| 60 | dataflow=None, data=None, |
| 61 | model=None, |
| 62 | callbacks=None, extra_callbacks=None, monitors=None, |
| 63 | session_creator=None, session_config=None, session_init=None, |
| 64 | starting_epoch=1, steps_per_epoch=None, max_epoch=99999): |
| 65 | """ |
| 66 | Args: |
| 67 | dataflow (DataFlow): |
| 68 | data (InputSource): |
| 69 | model (ModelDesc): |
| 70 | |
| 71 | callbacks (list[Callback]): a list of :class:`Callback` to use during training. |
| 72 | extra_callbacks (list[Callback]): This argument |
| 73 | is only used to provide the defaults in addition to ``callbacks``. |
| 74 | The list of callbacks that will be used in the end is simply ``callbacks + extra_callbacks``. |
| 75 | |
| 76 | It is usually left as None, and the default value for this argument is :func:`DEFAULT_CALLBACKS()`. |
| 77 | You can override it when you don't like any of the default callbacks. |
| 78 | For example, if you'd like to let the progress bar print tensors, you can use |
| 79 | |
| 80 | .. code-block:: none |
| 81 | |
| 82 | extra_callbacks=[ProgressBar(names=['name']), |
| 83 | MovingAverageSummary(), |
| 84 | MergeAllSummaries(), |
| 85 | RunUpdateOps()] |
| 86 | |
| 87 | monitors (list[MonitorBase]): Defaults to :func:`DEFAULT_MONITORS()`. |
| 88 | |
| 89 | session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()` |
| 90 | with the config returned by :func:`tfutils.get_default_sess_config()`. |
| 91 | session_config (tf.ConfigProto): when session_creator is None, use this to create the session. |
| 92 | session_init (SessionInit): how to initialize variables of a session. Defaults to doing nothing. |
| 93 | |
| 94 | starting_epoch (int): The index of the first epoch. |
| 95 | steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch. |
| 96 | Defaults to the input data size. You may want to divide it by the number of GPUs in multi-GPU training. |
| 97 | |
| 98 | Number of steps per epoch only affects the schedule of callbacks. |
| 99 | It does not affect the sequence of input data seen by the model. |
| 100 | max_epoch (int): maximum number of epochs to run training. |
| 101 | """ |
| 102 | |
| 103 | # TODO type checker decorator |
| 104 | def assert_type(v, tp, name): |
| 105 | assert isinstance(v, tp), \ |
| 106 | "{} has to be type '{}', but an object of type '{}' found.".format( |
| 107 | name, tp.__name__, v.__class__.__name__) |
| 108 | |
| 109 | # process data & model |
| 110 | assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!" |
| 111 | if dataflow is not None: |
| 112 | assert_type(dataflow, DataFlow, 'dataflow') |
| 113 | if data is not None: |
| 114 | assert_type(data, InputSource, 'data') |
| 115 | self.dataflow = dataflow |
| 116 | self.data = data |
no test coverage detected