(model)
| 7 | from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe |
| 8 | |
| 9 | def prepare_browser_chunks(model): |
| 10 | # split weights into browser-friendly chunks |
| 11 | state_dict = get_state_dict(model) |
| 12 | del state_dict['output.weight'], state_dict['output.scale'] # same as tok_embeddings; ensures consistency with model export |
| 13 | chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints |
| 14 | metadata = {} |
| 15 | # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be |
| 16 | t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k] |
| 17 | empty_t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k] |
| 18 | |
| 19 | split_t_infos = [] |
| 20 | for size, name, dtype in t_infos: |
| 21 | if size <= chunk_size: |
| 22 | split_t_infos.append((size, name, dtype, ())) |
| 23 | else: # split large weights into multiple parts |
| 24 | for i in range(0, size, chunk_size): |
| 25 | split_t_infos.append((min(chunk_size, size-i), f"{name}_part{math.ceil(i/chunk_size)}", dtype, (i, min(i+chunk_size, size)))) |
| 26 | |
| 27 | files = [] |
| 28 | # pack weights into files with FFD bin packing |
| 29 | split_t_infos = sorted(split_t_infos, reverse=True) |
| 30 | for info in split_t_infos: |
| 31 | placed = False |
| 32 | for file in files: |
| 33 | if sum(i[0] for i in file) + info[0] <= chunk_size: |
| 34 | if info[3] and any(i[3] for i in file): continue # no two split tensors can touch the same file, due to wasm loading constraints |
| 35 | file.append(info) |
| 36 | placed = True |
| 37 | break |
| 38 | if not placed: |
| 39 | files.append([info]) |
| 40 | |
| 41 | tinygrad_dtypes = {dtypes.float32: "float32", dtypes.float16: "float16", dtypes.int8: "int8", dtypes.int32: "int32"} |
| 42 | for i, file in enumerate(files): |
| 43 | cursor = 0 |
| 44 | with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "wb+") as writer: |
| 45 | for size, name, dtype, offsets in file: |
| 46 | name, part_num = (name, 0) if "_part" not in name else (name.split("_part")[0], int(name.split("_part")[1])) |
| 47 | default = {"parts": {}, "dtype": tinygrad_dtypes[dtype]} |
| 48 | weight_metadata = metadata.get(name, default) |
| 49 | weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size} |
| 50 | metadata[name] = weight_metadata |
| 51 | data = bytes(state_dict[name].uop.base.realized.as_memoryview()) |
| 52 | data = data if not offsets else data[offsets[0]:offsets[1]] |
| 53 | writer.write(data) |
| 54 | cursor += size |
| 55 | |
| 56 | metadata.update({name: {"parts": {0: {"empty": True, "size": size}}, "dtype": tinygrad_dtypes[dtype]} for size, name, dtype in empty_t_infos}) |
| 57 | |
| 58 | for k in metadata: |
| 59 | metadata[k]["parts"] = [part for part_num, part in sorted(metadata[k]["parts"].items(), key = lambda x: x[0])] |
| 60 | cursor = 0 |
| 61 | for i, part in enumerate(metadata[k]["parts"]): |
| 62 | metadata[k]["parts"][i]["target_start_pos"] = cursor |
| 63 | cursor += part["size"] |
| 64 | metadata[k]["size"] = cursor |
| 65 | |
| 66 | # compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues |
no test coverage detected
searching dependent graphs…