(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False)
| 113 | return f"{'Uint' if dtype in dtypes.uints else 'Int' if (dtype in dtypes.sints or dtype == dtypes.bool) else 'Float'}{8*dtype.itemsize}Array" |
| 114 | |
| 115 | def export_model_webgpu(functions, statements, bufs, weight_names, input_names, output_names, model_name, symbolic_vars={}, stream_weights=False) -> Tuple[str,int,int]: |
| 116 | kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()]) |
| 117 | kernel_names = ', '.join([name for (name, _, _, _) in statements]) |
| 118 | input_names += list(symbolic_vars.values()) |
| 119 | input_buffer_types = [dtype_to_js_type(bufs[inp_name][1]) for inp_name in input_names] |
| 120 | output_buffer_types = [dtype_to_js_type(bufs[out_name][1]) for out_name in output_names] |
| 121 | |
| 122 | buf_type = lambda x: "uniform" if x in set(symbolic_vars.values()) else "storage" |
| 123 | create_bind_group_layouts = ",".join([ |
| 124 | "device.createBindGroupLayout({{entries: [{{binding: 0, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: 'uniform' }}}}, {}]}})".format( |
| 125 | ",".join([f"{{binding: {argIdx+1}, visibility: GPUShaderStage.COMPUTE, buffer: {{ type: '{buf_type(argName)}' }} }}" for argIdx, argName in enumerate(args)]) |
| 126 | ) |
| 127 | for _, (_, args, _, _) in enumerate(statements) |
| 128 | ]) |
| 129 | layouts = f"const layouts=[{create_bind_group_layouts}]" |
| 130 | kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], layouts[{i}], infinityBuf, [{', '.join(args)}], [{', '.join(str(x) for x in global_size)}]);" for i, (_name, args, global_size, _local_size) in enumerate(statements) ]) |
| 131 | |
| 132 | buf_type = lambda x: "createUniformBuf" if x in set(uop.arg[0] for uop in symbolic_vars) else "createEmptyBuf" |
| 133 | map_to_external_weight = lambda _key: f"state_dict['{weight_names[_key]}']" if stream_weights else f"getTensorBuffer(safetensor, metadata['{weight_names[_key]}'])" |
| 134 | _bufs = '\n '.join([f"const {name} = " + (f"{buf_type(_key)}(device, {size});" if _key not in weight_names else f"createWeightBuf(device, {size}, {map_to_external_weight(_key)})") + ";" for name,(size,dtype,_key) in bufs.items()]) |
| 135 | gpu_write_bufs = '\n '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:{input_name}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,input_name in enumerate(input_names)]) |
| 136 | input_writers = '\n '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n new {input_buffer_types[i]}(gpuWriteBuffer{i}.getMappedRange()).set(" + f'_{inp_name});' + f"\n gpuWriteBuffer{i}.unmap();\n commandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, {inp_name}, 0, gpuWriteBuffer{i}.size);" for i,inp_name in enumerate(input_names)]) |
| 137 | gpu_read_bufs = '\n '.join([f"const gpuReadBuffer{i} = device.createBuffer({{size:{output_name}.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});" for i,output_name in enumerate(output_names)]) |
| 138 | outbuf_copies = '\n '.join([f"commandEncoder.copyBufferToBuffer({output_name}, 0, gpuReadBuffer{i}, 0, output{i}.size);" for i,output_name in enumerate(output_names)]) |
| 139 | output_readers = '\n '.join([f"await gpuReadBuffer{i}.mapAsync(GPUMapMode.READ);\n const resultBuffer{i} = new {output_buffer_types[i]}(gpuReadBuffer{i}.size/{bufs[output_names[i]][1].itemsize});\n resultBuffer{i}.set(new {output_buffer_types[i]}(gpuReadBuffer{i}.getMappedRange()));\n gpuReadBuffer{i}.unmap();" for i in range(len(output_names))]) |
| 140 | output_return = '[{}]'.format(",".join([f'resultBuffer{i}' for i in range(len(output_names))])) |
| 141 | getTensorMetadata = f"""\nconst getTensorMetadata = (safetensorBuffer) => {{ |
| 142 | const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true)); |
| 143 | const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength))); |
| 144 | return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}])); |
| 145 | }};\n""" if not stream_weights else "" |
| 146 | return f""" |
| 147 | const {model_name} = (() => {{ |
| 148 | const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{ |
| 149 | return safetensorBuffer.subarray(...tensorMetadata.data_offsets); |
| 150 | }}; |
| 151 | {getTensorMetadata} |
| 152 | const createEmptyBuf = (device, size) => {{ |
| 153 | return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }}); |
| 154 | }}; |
| 155 | |
| 156 | const createUniformBuf = (device, size) => {{ |
| 157 | return device.createBuffer({{size, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST}}) |
| 158 | }} |
| 159 | |
| 160 | const createInfinityUniformBuf = (device) => {{ |
| 161 | const size = 4; |
| 162 | const buf = device.createBuffer({{ |
| 163 | mappedAtCreation: true, |
| 164 | size, |
| 165 | usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST |
| 166 | }}); |
| 167 | new Float32Array(buf.getMappedRange())[0] = Infinity; |
| 168 | buf.unmap(); |
| 169 | return buf; |
| 170 | }}; |
| 171 | |
| 172 | const createWeightBuf = (device, size, data) => {{ |
no test coverage detected
searching dependent graphs…