(adj, func_type="kernel", device="cpu")
| 4924 | |
| 4925 | |
| 4926 | def codegen_func_reverse(adj, func_type="kernel", device="cpu"): |
| 4927 | if device == "cpu": |
| 4928 | indent = 4 |
| 4929 | elif device == "cuda": |
| 4930 | if func_type == "kernel": |
| 4931 | indent = 8 |
| 4932 | else: |
| 4933 | indent = 4 |
| 4934 | else: |
| 4935 | raise ValueError(f"Device {device} not supported for codegen") |
| 4936 | |
| 4937 | indent_block = " " * indent |
| 4938 | |
| 4939 | lines = [] |
| 4940 | |
| 4941 | # argument vars |
| 4942 | if device == "cpu" and func_type == "kernel": |
| 4943 | lines += ["//---------\n"] |
| 4944 | lines += ["// argument vars\n"] |
| 4945 | |
| 4946 | for var in adj.args: |
| 4947 | lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"] |
| 4948 | |
| 4949 | for var in adj.args: |
| 4950 | lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"] |
| 4951 | |
| 4952 | # primal vars |
| 4953 | lines += ["//---------\n"] |
| 4954 | lines += ["// primal vars\n"] |
| 4955 | |
| 4956 | for var in adj.variables: |
| 4957 | if is_tile(var.type): |
| 4958 | lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"] |
| 4959 | elif is_tile_stack(var.type): |
| 4960 | lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit()};\n"] |
| 4961 | elif var.constant is None: |
| 4962 | lines += [f"{var.ctype()} {var.emit()};\n"] |
| 4963 | else: |
| 4964 | lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"] |
| 4965 | |
| 4966 | if line_directive := adj.get_line_directive(lines[-1], var.relative_lineno): |
| 4967 | lines.insert(-1, f"{line_directive}\n") |
| 4968 | |
| 4969 | # dual vars |
| 4970 | lines += ["//---------\n"] |
| 4971 | lines += ["// dual vars\n"] |
| 4972 | |
| 4973 | for var in adj.variables: |
| 4974 | name = var.emit_adj() |
| 4975 | ctype = var.ctype(value_type=True) |
| 4976 | |
| 4977 | if is_tile(var.type): |
| 4978 | if var.type.storage == "register": |
| 4979 | lines += [ |
| 4980 | f"{var.type.ctype()} {name}{{}};\n" |
| 4981 | ] # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together |
| 4982 | elif var.type.storage == "shared": |
| 4983 | lines += [ |
no test coverage detected