See `tf.where`. Only the three-argument form, `where(condition, x, y)`, is supported here.
(cls, condition, tensor_wrapper_1, tensor_wrapper_2)
| 71 | |
@classmethod
def where(cls, condition, tensor_wrapper_1, tensor_wrapper_2):
    """Select elementwise between two tensor wrappers based on `condition`.

    Mirrors the three-argument form of `tf.where` (delegating to
    `tf.compat.v2.where`); no other form of `tf.where` is supported here.

    Args:
      condition: Passed unchanged as the first argument of
        `tf.compat.v2.where` for the underlying tensors.
      tensor_wrapper_1: Wrapper whose tensor supplies values where
        `condition` holds.
      tensor_wrapper_2: Wrapper whose tensor supplies values elsewhere.

    Returns:
      Whatever `cls._apply_sequence_to_tensor_op` produces from the selected
      tensor. Presumably this is a wrapper of the same kind; confirm against
      that helper.
    """
    operands = [tensor_wrapper_1, tensor_wrapper_2]
    # Both operands must pass the class's type check before the op is applied.
    cls._validate_tensor_types(operands, "where")

    def _select(tensors):
        # `tensors` holds the unwrapped tensors in the same order as `operands`.
        on_true = tensors[0]
        on_false = tensors[1]
        return tf.compat.v2.where(condition, on_true, on_false)

    return cls._apply_sequence_to_tensor_op(_select, operands)
| 79 | |
| 80 | @classmethod |
| 81 | def _validate_tensor_types(cls, tensor_wrappers, function_name): |