| 15 | |
| 16 | |
| 17 | class PredictConfig(object): |
| 18 | def __init__(self, |
| 19 | model=None, |
| 20 | tower_func=None, |
| 21 | input_signature=None, |
| 22 | |
| 23 | input_names=None, |
| 24 | output_names=None, |
| 25 | |
| 26 | session_creator=None, |
| 27 | session_init=None, |
| 28 | return_input=False, |
| 29 | create_graph=True, |
| 30 | ): |
| 31 | """ |
| 32 | Users need to provide enough arguments to create a tower function, |
| 33 | which will be used to construct the graph. |
| 34 | This can be provided in the following ways: |
| 35 | |
| 36 | 1. `model`: a :class:`ModelDesc` instance. It will contain a tower function by itself. |
| 37 | 2. `tower_func`: a :class:`tfutils.TowerFunc` instance. |
| 38 | Provide a tower function instance directly. |
| 39 | 3. `tower_func`: a symbolic function and `input_signature`: the signature of the function. |
| 40 | Provide both a function and its signature. |
| 41 | |
| 42 | Example: |
| 43 | |
| 44 | .. code-block:: python |
| 45 | |
| 46 | config = PredictConfig(model=my_model, |
| 47 | input_names=['image'], |
| 48 | output_names=['linear/output', 'prediction']) |
| 49 | |
| 50 | Args: |
| 51 | model (ModelDescBase): to be used to construct a tower function. |
| 52 | tower_func: a callable which takes input tensors (by positional args) and constructs a tower, |
| 53 | or a :class:`tfutils.TowerFunc` instance. |
| 54 | input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFunc), |
| 55 | this describes the list of inputs it takes. |
| 56 | |
| 57 | input_names (list): a list of input tensor names. Defaults to match input_signature. |
| 58 | The name can be either the name of a tensor, or the name of one input of the tower. |
| 59 | output_names (list): a list of names of the output tensors to predict, the |
| 60 | tensors can be any tensor in the graph that's computable from the tensors corresponding to `input_names`. |
| 61 | |
| 62 | session_creator (tf.train.SessionCreator): how to create the |
| 63 | session. Defaults to :class:`NewSessionCreator()`. |
| 64 | session_init (SessionInit): how to initialize variables of the session. |
| 65 | Defaults to do nothing. |
| 66 | |
| 67 | return_input (bool): same as in :attr:`PredictorBase.return_input`. |
| 68 | create_graph (bool): create a new graph, or use the default graph |
| 69 | when predictor is first initialized. |
| 70 | """ |
| 71 | def assert_type(v, tp, name): |
| 72 | assert isinstance(v, tp), \ |
| 73 | "Argument '{}' has to be type '{}', but an object of type '{}' found.".format( |
| 74 | name, tp.__name__, v.__class__.__name__) |
no outgoing calls
no test coverage detected