```python
from abc import ABCMeta
from typing import Callable

# TargetPlatform and OPERATION_FORWARD_TABLE are imported or defined earlier
# in the module (not shown in this excerpt).


def register_operation_handler(handler: Callable, operation_type: str, platform: TargetPlatform):
    """Register a customized function as an operation handler.

    The function must accept at least 3 parameters and return one or more tensors as its result:
        func(op: Operation, values: List[torch.Tensor], ctx: TorchBackendContext = None, **kwargs) -> torch.Tensor

    If another handler is already registered for the given operation_type on this platform,
    the new handler replaces the old one without warning.

    Args:
        handler (Callable): Callable whose interface follows the restriction above.
        operation_type (str): Operation type.
        platform (TargetPlatform): Platform to register the handler on.
    """
    if platform not in OPERATION_FORWARD_TABLE:
        raise ValueError('Unknown Platform detected, Please check your platform setting.')
    OPERATION_FORWARD_TABLE[platform][operation_type] = handler


class RuntimeHook(metaclass=ABCMeta):
    ...  # class body not included in this excerpt
```
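A handler only has to honor the calling convention given in the docstring: the operation first, then its input values, then an optional context, then any extra keyword arguments. The sketch below shows a conforming handler and a simple lookup-then-call dispatch. `Op`, `scaled_add_forward`, `run_op` and the `alpha` attribute are hypothetical names for illustration, not PPQ APIs, and plain Python numbers stand in for `torch.Tensor` so the sketch runs without torch installed.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Op:
    # Hypothetical stand-in for an Operation: only the fields this sketch reads.
    type: str
    attributes: Dict[str, Any] = field(default_factory=dict)


def scaled_add_forward(op: Op, values: List[Any], ctx: Optional[Any] = None, **kwargs) -> Any:
    """Handler following the documented interface: (op, values, ctx=None, **kwargs)."""
    # 'alpha' is a hypothetical attribute used only for illustration; defaults to 1.
    alpha = op.attributes.get("alpha", 1)
    # With torch tensors this would be an element-wise op; plain numbers work too.
    return values[0] + alpha * values[1]


def run_op(table: Dict[str, Any], op: Op, values: List[Any], ctx: Optional[Any] = None) -> Any:
    # Hypothetical dispatch: look up the handler by operation type, then call it.
    if op.type not in table:
        raise KeyError(f"No handler registered for operation type {op.type!r}")
    return table[op.type](op, values, ctx)


handlers = {"ScaledAdd": scaled_add_forward}
print(run_op(handlers, Op("ScaledAdd", {"alpha": 2}), [1, 3]))  # 7
```

Keeping `ctx` optional and accepting `**kwargs` lets the same handler be called by a dispatcher that passes extra context without breaking handlers that ignore it.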
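The registration logic amounts to a two-level dictionary write keyed first by platform and then by operation type. The following sketch mirrors that control flow with hypothetical stand-ins (`Platform` for `TargetPlatform`, `FORWARD_TABLE` for `OPERATION_FORWARD_TABLE`, `register_handler` for `register_operation_handler`) to show the silent overwrite and the `ValueError` raised for an unrecognized platform.

```python
from enum import Enum
from typing import Callable, Dict


class Platform(Enum):
    # Hypothetical stand-in for TargetPlatform; the real enum has other members.
    FP32 = "fp32"
    INT8 = "int8"


# Hypothetical mirror of OPERATION_FORWARD_TABLE: platform -> {operation_type -> handler}.
FORWARD_TABLE: Dict[Platform, Dict[str, Callable]] = {p: {} for p in Platform}


def register_handler(handler: Callable, operation_type: str, platform: Platform) -> None:
    """Same control flow as register_operation_handler."""
    if platform not in FORWARD_TABLE:
        raise ValueError('Unknown Platform detected, Please check your platform setting.')
    # Plain dict assignment: an existing entry is overwritten with no warning.
    FORWARD_TABLE[platform][operation_type] = handler


def old_add(op, values, ctx=None, **kwargs):
    return values[0] + values[1]


def new_add(op, values, ctx=None, **kwargs):
    return sum(values)


register_handler(old_add, "Add", Platform.FP32)
register_handler(new_add, "Add", Platform.FP32)  # silently replaces old_add
print(FORWARD_TABLE[Platform.FP32]["Add"] is new_add)  # True
```

Because the overwrite targets only `FORWARD_TABLE[platform]`, registering a handler for one platform leaves the entries of every other platform untouched.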