Class to store results from the matcher. This class is used to store the results from the matcher. It provides convenient methods to query the matching results.
| 37 | |
| 38 | |
| 39 | class Match(object): |
| 40 | """Class to store results from the matcher. |
| 41 | |
| 42 | This class is used to store the results from the matcher. It provides |
| 43 | convenient methods to query the matching results. |
| 44 | """ |
| 45 | |
def __init__(self, match_results):
    """Constructs a Match object.

    Args:
      match_results: int32 tensor of shape [N] where, for each column i,
        (1) match_results[i] >= 0 means column i is matched with row
        match_results[i], (2) match_results[i] == -1 means column i is not
        matched, and (3) match_results[i] == -2 means column i is ignored.

    Raises:
      ValueError: if match_results does not have rank 1 or if its dtype is
        not tf.int32.
    """
    if match_results.shape.ndims != 1:
        raise ValueError('match_results should have rank 1')
    # Only int32 passes this check, so the message must not advertise int64
    # support or describe the rank-1 input as a scalar.
    if match_results.dtype != tf.int32:
        raise ValueError('match_results should be an int32 tensor')
    self._match_results = match_results
| 65 | |
@property
def match_results(self):
    """Read-only accessor for the raw match encoding.

    Returns:
      The match results object that was stored by the constructor,
      returned unchanged.
    """
    stored_results = self._match_results
    return stored_results
| 74 | |
def matched_column_indices(self):
    """Finds the columns that were matched to some row.

    A column counts as matched when its entry in the match results is
    greater than -1. Indices come back in increasing order.

    Returns:
      column_indices: int32 tensor of shape [K] holding the matched column
        indices, produced by passing the tf.where output through
        self._reshape_and_cast.
    """
    is_matched = tf.greater(self._match_results, -1)
    matched_positions = tf.where(is_matched)
    return self._reshape_and_cast(matched_positions)
| 84 | |
def matched_column_indicator(self):
    """Returns a per-column indicator of whether the column is matched.

    Returns:
      column_indicator: boolean tensor with the same shape as the match
        results; entry i is True exactly when match_results[i] >= 0.
    """
    column_indicator = tf.greater_equal(self._match_results, 0)
    return column_indicator
| 92 | |
def num_matched_columns(self):
    """Counts the columns that are matched to some row.

    Returns:
      Scalar int32 tensor equal to the number of entries returned by
      matched_column_indices().
    """
    matched_indices = self.matched_column_indices()
    return tf.size(input=matched_indices)
| 96 |