(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False)
| 239 | """ |
| 240 | |
| 241 | def export_model(model, target:str, *inputs, model_name: Optional[str] = "model", stream_weights=False): |
| 242 | assert Device.DEFAULT in EXPORT_SUPPORTED_DEVICE, f"only {', '.join(EXPORT_SUPPORTED_DEVICE)} are supported" |
| 243 | |
| 244 | # NOTE: CPU_COUNT=1, since export does not support threading |
| 245 | with Context(JIT=2, CPU_COUNT=1): linear, output_bufs = jit_model(model, *inputs) |
| 246 | functions, statements, bufs, bufs_to_save = compile_net(linear, output_bufs) |
| 247 | state = get_state_dict(model) |
| 248 | weight_names = {(id(b), b.offset, b.size, b.dtype): name for name, x in state.items() if (b:=x.uop.base.realized) is not None} |
| 249 | input_names = [f"input{i}" for i in range(len(inputs))] |
| 250 | output_names = [f"output{i}" for i in range(len(output_bufs))] |
| 251 | |
| 252 | # handle symbolic variables; TODO: refactor to fix some of this stuff upstream in tinygrad |
| 253 | symbolic_vars = OrderedDict() |
| 254 | for i, (_, args, global_size, _) in enumerate(statements): |
| 255 | for j, var in enumerate(args): |
| 256 | if getattr(var, "op", None) is Ops.DEFINE_VAR and isinstance(getattr(var, "arg", None), tuple) and isinstance(var.arg[0], str): |
| 257 | if var not in symbolic_vars: |
| 258 | symbolic_vars[var] = var.arg[0] |
| 259 | bufs[symbolic_vars[var]] = (var.dtype.itemsize, var.dtype, symbolic_vars[var]) |
| 260 | statements[i][1][j] = symbolic_vars[var] |
| 261 | |
| 262 | if global_size: |
| 263 | for j, dim in enumerate(global_size): |
| 264 | if getattr(dim, "op", None) is Ops.ADD and len(dim.src) == 2 and {dim.src[0].op, dim.src[1].op} == {Ops.DEFINE_VAR, Ops.CONST}: |
| 265 | name, val = dim.src if dim.src[1].op is Ops.CONST else reversed(dim.src) |
| 266 | global_size[j] = f"_{name.arg[0]}[0] + {val.arg}" |
| 267 | |
| 268 | prg = "" |
| 269 | if target == "clang": |
| 270 | prg = export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names) |
| 271 | elif target == "wasm": |
| 272 | return export_model_clang(functions, statements, bufs, bufs_to_save, input_names, output_names, weight_names, model_name, symbolic_vars, wasm=True) |
| 273 | elif target == "webgpu": |
| 274 | prg = export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars, stream_weights) |
| 275 | else: |
| 276 | prg = json.dumps({ |
| 277 | "backend": Device.DEFAULT, |
| 278 | "inputs": [{ |
| 279 | "size": bufs[name][0], |
| 280 | "dtype": bufs[name][1].name |
| 281 | } for name in input_names], |
| 282 | "outputs": [{ |
| 283 | "size": bufs[name][0], |
| 284 | "dtype": bufs[name][1].name |
| 285 | } for name in output_names], |
| 286 | "functions": functions, |
| 287 | "statements": [{ |
| 288 | "kernel": kernel, |
| 289 | "args": args, |
| 290 | "global_size": global_size, |
| 291 | "local_size": local_size |
| 292 | } for (kernel, args, global_size, local_size) in statements], |
| 293 | "buffers": { |
| 294 | name: { |
| 295 | "size": size, |
| 296 | "dtype": dtype.name, |
| 297 | "id": weight_names[_key] if _key in weight_names else "" |
| 298 | } for name, (size,dtype,_key) in bufs.items() if name not in ["input", "outputs"] |
searching dependent graphs…