(functional, input_type, *, allow_passthrough=False)
| 97 | |
| 98 | |
def _get_kernel(functional, input_type, *, allow_passthrough=False):
    """Resolve the kernel registered for ``functional`` that can handle ``input_type``.

    The per-functional registry is fetched from ``_KERNEL_REGISTRY`` and the
    method resolution order of ``input_type`` is walked from most to least
    specific. The first class that has a registered kernel wins.

    Args:
        functional: The functional whose kernels are looked up; its
            ``__name__`` is used in error messages.
        input_type: The type of the input to dispatch on.
        allow_passthrough: If true, an identity callable is returned instead of
            raising when no kernel matches.

    Returns:
        The registered kernel. If nothing matches and ``allow_passthrough`` is
        set, returns a callable that returns its first positional argument
        unchanged and ignores everything else.

    Raises:
        ValueError: If ``functional`` has no (or an empty) registry entry.
        TypeError: If no kernel matches ``input_type`` and passthrough is not
            allowed.
    """
    # NOTE(review): presumably a mapping from type -> kernel callable; confirm
    # against the registration helpers elsewhere in the file.
    kernels = _KERNEL_REGISTRY.get(functional)
    if not kernels:
        raise ValueError(f"No kernel registered for functional {functional.__name__}.")

    for klass in input_type.__mro__:
        if klass in kernels:
            return kernels[klass]
        # Stop before reaching torch.Tensor so that user-defined tv_tensors never
        # fall through to the plain-Tensor kernels. Stopping at TVTensor itself is
        # safe because kernels are not allowed to be registered for it.
        if klass is tv_tensors.TVTensor:
            break

    if not allow_passthrough:
        raise TypeError(
            f"Functional F.{functional.__name__} supports inputs of type {kernels.keys()}, "
            f"but got {input_type} instead."
        )

    # Identity fallback: hand the input back untouched, ignoring any extra arguments.
    return lambda inpt, *args, **kwargs: inpt
| 120 | |
| 121 | |
| 122 | # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop |
searching dependent graphs…