Generates predictions for the input samples. If `operator` is ``None``, returns the network output, otherwise returns the output of the `operator`. Args: x: The network inputs. A Numpy array or a tuple of Numpy arrays. operator: A function that takes arguments (`inputs`, `outputs`) or (`inputs`, `outputs`, `auxiliary_variables`) and outputs a tensor.
(self, x, operator=None, callbacks=None)
| 930 | display.training_display(self.train_state) |
| 931 | |
| 932 | def predict(self, x, operator=None, callbacks=None): |
| 933 | """Generates predictions for the input samples. If `operator` is ``None``, |
| 934 | returns the network output, otherwise returns the output of the `operator`. |
| 935 | |
| 936 | Args: |
| 937 | x: The network inputs. A Numpy array or a tuple of Numpy arrays. |
| 938 | operator: A function takes arguments (`inputs`, `outputs`) or (`inputs`, |
| 939 | `outputs`, `auxiliary_variables`) and outputs a tensor. `inputs` and |
| 940 | `outputs` are the network input and output tensors, respectively. |
| 941 | `auxiliary_variables` is the output of `auxiliary_var_function(x)` |
| 942 | in `dde.data.PDE`. `operator` is typically chosen as the PDE (used to |
| 943 | define `dde.data.PDE`) to predict the PDE residual. |
| 944 | callbacks: List of ``dde.callbacks.Callback`` instances. List of callbacks |
| 945 | to apply during prediction. |
| 946 | """ |
| 947 | if isinstance(x, tuple): |
| 948 | x = tuple(np.asarray(xi, dtype=config.real(np)) for xi in x) |
| 949 | else: |
| 950 | x = np.asarray(x, dtype=config.real(np)) |
| 951 | callbacks = CallbackList(callbacks=callbacks) |
| 952 | callbacks.set_model(self) |
| 953 | callbacks.on_predict_begin() |
| 954 | |
| 955 | if operator is None: |
| 956 | y = self._outputs(False, x) |
| 957 | callbacks.on_predict_end() |
| 958 | return y |
| 959 | |
| 960 | # operator is not None |
| 961 | if utils.get_num_args(operator) == 3: |
| 962 | aux_vars = self.data.auxiliary_var_fn(x).astype(config.real(np)) |
| 963 | if backend_name == "tensorflow.compat.v1": |
| 964 | if utils.get_num_args(operator) == 2: |
| 965 | op = operator(self.net.inputs, self.net.outputs) |
| 966 | feed_dict = self.net.feed_dict(False, x) |
| 967 | elif utils.get_num_args(operator) == 3: |
| 968 | op = operator( |
| 969 | self.net.inputs, self.net.outputs, self.net.auxiliary_vars |
| 970 | ) |
| 971 | feed_dict = self.net.feed_dict(False, x, auxiliary_vars=aux_vars) |
| 972 | y = self.sess.run(op, feed_dict=feed_dict) |
| 973 | elif backend_name == "tensorflow": |
| 974 | if utils.get_num_args(operator) == 2: |
| 975 | |
| 976 | @tf.function |
| 977 | def op(inputs): |
| 978 | y = self.net(inputs) |
| 979 | if config.autodiff == "forward": |
| 980 | y = (y, self.net) |
| 981 | return operator(inputs, y) |
| 982 | |
| 983 | elif utils.get_num_args(operator) == 3: |
| 984 | |
| 985 | @tf.function |
| 986 | def op(inputs): |
| 987 | y = self.net(inputs) |
| 988 | return operator(inputs, y, aux_vars) |
| 989 |
no test coverage detected