(
f, grad_f=None, python_func_type=None, pass_workspace=False
)
| 423 | |
| 424 | |
def _RegisterPythonImpl(
    f, grad_f=None, python_func_type=None, pass_workspace=False
):
    """Register a Python operator implementation and, optionally, its gradient.

    Args:
        f: The forward implementation. When ``python_func_type`` is not given,
            ``f`` may be a callable or a ``(factory, args, kwargs)`` tuple; in
            the tuple case ``factory(*args, **kwargs)`` is invoked and its
            result is registered instead.
        grad_f: Optional gradient implementation, accepted in the same two
            forms as ``f``. It is ignored when ``python_func_type`` is truthy.
        python_func_type: Optional callable. When truthy, it is called with
            ``f``, and the resulting object's ``forward`` and ``backward``
            attributes become the forward and gradient implementations.
        pass_workspace: Passed through unchanged to ``C.register_python_op``.

    Returns:
        The token returned by ``C.register_python_op``. Any gradient is
        registered against this same token.
    """
    def _build(spec):
        # A tuple is a deferred construction request: (factory, args, kwargs).
        # Anything else is used as-is.
        if not isinstance(spec, tuple):
            return spec
        factory = spec[0]
        return factory(*spec[1], **spec[2])

    if python_func_type:
        impl = python_func_type(f)
        forward = impl.forward
        backward = impl.backward
    else:
        forward = _build(f)
        backward = _build(grad_f)

    # NOTE(review): the meaning of the empty-string third argument is not
    # visible from here; confirm against C.register_python_op.
    token = C.register_python_op(forward, pass_workspace, '')
    # Truthiness (not `is not None`) decides whether a gradient is registered.
    if backward:
        C.register_python_gradient_op(token, backward)
    return token
| 442 | |
| 443 | |
| 444 | def CreatePythonOperator( |
no test coverage detected
searching dependent graphs…