Logical operator for join.
| 43 | |
| 44 | @dataclass(frozen=True, repr=False, eq=False) |
| 45 | class Join(NAry, LogicalOperatorSupportsPredicatePassThrough): |
| 46 | """Logical operator for join.""" |
| 47 | |
| 48 | left_input_op: InitVar[LogicalOperator] |
| 49 | right_input_op: InitVar[LogicalOperator] |
| 50 | join_type: Union[JoinType, str] |
| 51 | left_key_columns: Tuple[str] |
| 52 | right_key_columns: Tuple[str] |
| 53 | num_partitions: InitVar[int] |
| 54 | left_columns_suffix: Optional[str] = None |
| 55 | right_columns_suffix: Optional[str] = None |
| 56 | partition_size_hint: Optional[int] = None |
| 57 | aggregator_ray_remote_args: Optional[Dict[str, Any]] = None |
| 58 | _input_dependencies: list[LogicalOperator] = field(init=False, repr=False) |
| 59 | _num_outputs: Optional[int] = field(init=False, repr=False) |
| 60 | |
def __post_init__(
    self,
    left_input_op: LogicalOperator,
    right_input_op: LogicalOperator,
    num_partitions: int,
):
    """Validate ``join_type`` and populate the derived, non-init fields.

    Called by the dataclass-generated ``__init__`` with the values of the
    ``InitVar`` pseudo-fields. Those values are not stored as attributes
    automatically, so they are recorded here.

    Args:
        left_input_op: Operator producing the left side of the join.
        right_input_op: Operator producing the right side of the join.
        num_partitions: Stored on the instance as ``_num_outputs``.

    Raises:
        ValueError: If ``join_type`` cannot be converted to a ``JoinType``.
    """
    try:
        join_type_enum = JoinType(self.join_type)
    except ValueError:
        # Replace the enum's generic lookup error with a message that lists
        # the supported values. ``from None`` suppresses the implicit
        # "During handling of the above exception..." context, because the
        # new message already reports the offending value.
        raise ValueError(
            f"Invalid join type: '{self.join_type}'. "
            f"Supported join types are: {', '.join(jt.value for jt in JoinType)}."
        ) from None

    # The dataclass is declared frozen=True, so ordinary assignment would
    # raise FrozenInstanceError. object.__setattr__ bypasses the frozen guard
    # during construction.
    # Normalize join_type to the enum, whether a str or a JoinType was passed.
    object.__setattr__(self, "join_type", join_type_enum)
    object.__setattr__(
        self,
        "_input_dependencies",
        [left_input_op, right_input_op],
    )
    object.__setattr__(self, "_num_outputs", num_partitions)
| 82 | |
def _with_new_input_dependencies(
    self, input_dependencies: List[LogicalOperator]
) -> LogicalOperator:
    """Return a copy of this join wired to a new pair of input operators.

    Only the first two entries of ``input_dependencies`` are read. Index 0
    becomes the left input and index 1 the right input.

    ``dataclasses.replace`` builds the copy by calling the class constructor
    again, so ``__post_init__`` reruns on the new instance. InitVar fields
    without defaults are never copied by ``replace`` and must be supplied
    explicitly; that is why all three InitVars appear below.

    Args:
        input_dependencies: Replacement inputs, ordered ``[left, right]``.

    Returns:
        A new instance of the same class; ``self`` is not modified.
    """
    new_left = input_dependencies[0]
    new_right = input_dependencies[1]
    # NOTE(review): ``num_outputs`` is defined outside this view (presumably
    # on a base class, exposing ``_num_outputs``). It is fed back through the
    # ``num_partitions`` InitVar so the copy keeps the same partition count.
    # Confirm against the base class.
    partition_count = self.num_outputs
    return replace(
        self,
        left_input_op=new_left,
        right_input_op=new_right,
        num_partitions=partition_count,
    )
| 92 | |
| 93 | @staticmethod |
| 94 | def _validate_schemas( |
| 95 | left_op_schema: "Schema", |
| 96 | right_op_schema: "Schema", |
| 97 | left_key_column_names: Tuple[str], |
| 98 | right_key_column_names: Tuple[str], |
| 99 | ): |
| 100 | def _col_names_as_str(keys: Sequence[str]): |
| 101 | keys_joined = ", ".join(map(lambda k: f"'{k}'", keys)) |
| 102 | return f"[{keys_joined}]" |
no outgoing calls
no test coverage detected
searching dependent graphs…