github.com/lululxvi/deepxde

Method predict

Defined in `deepxde/model.py`, lines 932–1044.

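Generates predictions for the input samples `x`, a NumPy array or a tuple of NumPy arrays. If `operator` is `None`, returns the network output; otherwise, returns the output of `operator`, a function of the network inputs and outputs (and, optionally, of the auxiliary variables defined in `dde.data.PDE`). Passing the PDE used to define `dde.data.PDE` as `operator` yields the PDE residual at `x`. Optional `callbacks`, a list of `dde.callbacks.Callback` instances, are applied during prediction.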
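The branching on `operator` can be sketched in plain Python. This is a minimal, framework-free sketch, not DeepXDE code: `net` stands in for the trained network, `aux_fn` for the auxiliary-variable function of `dde.data.PDE`, and `num_args` is a hypothetical stand-in for `utils.get_num_args`, built here on `inspect.signature`. The real method evaluates the operator on backend tensors (through a `tf.compat.v1` session or a `tf.function`), not on Python lists; the sketch only shows that auxiliary variables are computed when the operator takes three arguments.

```python
import inspect
from typing import Callable, Optional, Sequence


def num_args(fn: Callable) -> int:
    # Hypothetical stand-in for deepxde's utils.get_num_args:
    # count the positional parameters of `fn`.
    params = inspect.signature(fn).parameters.values()
    positional = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    return sum(p.kind in positional for p in params)


def predict_sketch(
    net: Callable[[Sequence[float]], list],
    x: Sequence[float],
    operator: Optional[Callable] = None,
    aux_fn: Optional[Callable[[Sequence[float]], list]] = None,
):
    y = net(x)
    if operator is None:
        # No operator: return the raw network output.
        return y
    n = num_args(operator)
    if n == 2:
        return operator(x, y)
    if n == 3:
        # Auxiliary variables are evaluated only for 3-argument operators.
        return operator(x, y, aux_fn(x))
    # This sketch rejects other arities explicitly.
    raise ValueError(f"operator must take 2 or 3 arguments, got {n}")


def square(xs):
    # Toy "network": u(x) = x**2, applied pointwise.
    return [v * v for v in xs]


print(predict_sketch(square, [1.0, 2.0]))  # [1.0, 4.0]
print(predict_sketch(square, [1.0, 2.0], lambda x, y: [b - a for a, b in zip(x, y)]))  # [0.0, 2.0]
print(
    predict_sketch(
        square,
        [1.0, 2.0],
        lambda x, y, aux: [b - c for b, c in zip(y, aux)],
        aux_fn=lambda xs: [1.0] * len(xs),
    )
)  # [0.0, 3.0]
```

Dispatching on arity keeps a single `operator` parameter while still letting PDEs that need extra per-point data receive it.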

Signature: `predict(self, x, operator=None, callbacks=None)`

Source:

```python
def predict(self, x, operator=None, callbacks=None):
    """Generates predictions for the input samples. If `operator` is ``None``,
    returns the network output, otherwise returns the output of the `operator`.

    Args:
        x: The network inputs. A Numpy array or a tuple of Numpy arrays.
        operator: A function takes arguments (`inputs`, `outputs`) or (`inputs`,
            `outputs`, `auxiliary_variables`) and outputs a tensor. `inputs` and
            `outputs` are the network input and output tensors, respectively.
            `auxiliary_variables` is the output of `auxiliary_var_function(x)`
            in `dde.data.PDE`. `operator` is typically chosen as the PDE (used to
            define `dde.data.PDE`) to predict the PDE residual.
        callbacks: List of ``dde.callbacks.Callback`` instances. List of callbacks
            to apply during prediction.
    """
    if isinstance(x, tuple):
        x = tuple(np.asarray(xi, dtype=config.real(np)) for xi in x)
    else:
        x = np.asarray(x, dtype=config.real(np))
    callbacks = CallbackList(callbacks=callbacks)
    callbacks.set_model(self)
    callbacks.on_predict_begin()

    if operator is None:
        y = self._outputs(False, x)
        callbacks.on_predict_end()
        return y

    # operator is not None
    if utils.get_num_args(operator) == 3:
        aux_vars = self.data.auxiliary_var_fn(x).astype(config.real(np))
    if backend_name == "tensorflow.compat.v1":
        if utils.get_num_args(operator) == 2:
            op = operator(self.net.inputs, self.net.outputs)
            feed_dict = self.net.feed_dict(False, x)
        elif utils.get_num_args(operator) == 3:
            op = operator(
                self.net.inputs, self.net.outputs, self.net.auxiliary_vars
            )
            feed_dict = self.net.feed_dict(False, x, auxiliary_vars=aux_vars)
        y = self.sess.run(op, feed_dict=feed_dict)
    elif backend_name == "tensorflow":
        if utils.get_num_args(operator) == 2:

            @tf.function
            def op(inputs):
                y = self.net(inputs)
                if config.autodiff == "forward":
                    y = (y, self.net)
                return operator(inputs, y)

        elif utils.get_num_args(operator) == 3:

            @tf.function
            def op(inputs):
                y = self.net(inputs)
                return operator(inputs, y, aux_vars)

    # ... (lines 990–1044 not shown)
```
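Before any forward pass, `predict` casts `x` to the configured real dtype, converting each array of a tuple separately so the tuple structure survives. The NumPy sketch below mirrors just that step; `as_real` is a hypothetical name, and the `float32` default is an assumption standing in for `config.real(np)`, whose value in DeepXDE depends on the configured precision.

```python
import numpy as np


def as_real(x, dtype=np.float32):
    # Mirror predict's input handling: a tuple of arrays is converted
    # element-wise and stays a tuple; anything else becomes one array.
    if isinstance(x, tuple):
        return tuple(np.asarray(xi, dtype=dtype) for xi in x)
    return np.asarray(x, dtype=dtype)


# A nested list of Python floats becomes a single float32 array.
points = as_real([[0.0, 0.5], [1.0, 1.5]])
print(points.dtype, points.shape)  # float32 (2, 2)

# A tuple of inputs (e.g. for a network with several inputs) keeps its structure.
first, second = as_real(([1, 2, 3], [[0.1], [0.2]]))
print(type(first).__name__, first.dtype, second.shape)  # ndarray float32 (2, 1)
```

Casting up front means list inputs, integer arrays, and `float64` arrays all reach the network in one consistent dtype.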

Calls (7):

- `set_model` (method, 0.95)
- `on_predict_begin` (method, 0.95)
- `_outputs` (method, 0.95)
- `on_predict_end` (method, 0.95)
- `CallbackList` (class, 0.85)
- `feed_dict` (method, 0.80)
- `clear` (method, 0.45)
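Several of these calls (`CallbackList`, `set_model`, `on_predict_begin`, `on_predict_end`) form the prediction-time callback protocol. `predict` wraps the user's callbacks in a `CallbackList`, binds the model, fires `on_predict_begin` before the forward pass, and, in the `operator is None` branch of the source, fires `on_predict_end` just before returning. The classes below are simplified, hypothetical stand-ins modelled on those hook names, written only to show the call order; they are not DeepXDE's `dde.callbacks` implementation.

```python
class Callback:
    # Hypothetical base class with no-op hooks named after those predict calls.
    def set_model(self, model):
        self.model = model

    def on_predict_begin(self):
        pass

    def on_predict_end(self):
        pass


class CallbackList:
    # Fans each hook call out to every registered callback, in order.
    def __init__(self, callbacks=None):
        self.callbacks = list(callbacks or [])

    def set_model(self, model):
        for cb in self.callbacks:
            cb.set_model(model)

    def on_predict_begin(self):
        for cb in self.callbacks:
            cb.on_predict_begin()

    def on_predict_end(self):
        for cb in self.callbacks:
            cb.on_predict_end()


class EventLog(Callback):
    # Example user callback that records which hooks fired.
    def __init__(self):
        self.events = []

    def on_predict_begin(self):
        self.events.append("begin")

    def on_predict_end(self):
        self.events.append("end")


def predict_with_callbacks(model_fn, x, callbacks=None):
    # Same hook order as predict's operator-is-None path.
    callbacks = CallbackList(callbacks=callbacks)
    callbacks.set_model(model_fn)
    callbacks.on_predict_begin()
    y = model_fn(x)
    callbacks.on_predict_end()
    return y


log = EventLog()
print(predict_with_callbacks(lambda v: 2 * v, 3, callbacks=[log]))  # 6
print(log.events)  # ['begin', 'end']
```

Wrapping the list in a single dispatcher lets `predict` call each hook once, whether the user passes no callbacks or many.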

Tested by: no test coverage detected.