MCPcopy Index your code
hub / github.com/tinygrad/tinygrad / prepare_browser_chunks

Function prepare_browser_chunks

examples/tinychat/tinychat-browser/compile.py:9–80  ·  view source on GitHub ↗
(model)

Source from the content-addressed store, hash-verified

7from tiktoken.load import load_tiktoken_bpe, dump_tiktoken_bpe
8
9def prepare_browser_chunks(model):
10 # split weights into browser-friendly chunks
11 state_dict = get_state_dict(model)
12 del state_dict['output.weight'], state_dict['output.scale'] # same as tok_embeddings; ensures consistency with model export
13 chunk_size = 16 * 1024 * 1024 # small chunks based on iphone browser constraints
14 metadata = {}
15 # We won't export cache_kv bytes (because we start inference on client at start_pos=0), but we will tell the client how big cache_kv needs to be
16 t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" not in k]
17 empty_t_infos = [(v.uop.base.realized.nbytes, k, v.dtype) for k,v in state_dict.items() if "cache_kv" in k]
18
19 split_t_infos = []
20 for size, name, dtype in t_infos:
21 if size <= chunk_size:
22 split_t_infos.append((size, name, dtype, ()))
23 else: # split large weights into multiple parts
24 for i in range(0, size, chunk_size):
25 split_t_infos.append((min(chunk_size, size-i), f"{name}_part{math.ceil(i/chunk_size)}", dtype, (i, min(i+chunk_size, size))))
26
27 files = []
28 # pack weights into files with FFD bin packing
29 split_t_infos = sorted(split_t_infos, reverse=True)
30 for info in split_t_infos:
31 placed = False
32 for file in files:
33 if sum(i[0] for i in file) + info[0] <= chunk_size:
34 if info[3] and any(i[3] for i in file): continue # no two split tensors can touch the same file, due to wasm loading constraints
35 file.append(info)
36 placed = True
37 break
38 if not placed:
39 files.append([info])
40
41 tinygrad_dtypes = {dtypes.float32: "float32", dtypes.float16: "float16", dtypes.int8: "int8", dtypes.int32: "int32"}
42 for i, file in enumerate(files):
43 cursor = 0
44 with open(os.path.join(os.path.dirname(__file__), f'./net_part{i}.chunk'), "wb+") as writer:
45 for size, name, dtype, offsets in file:
46 name, part_num = (name, 0) if "_part" not in name else (name.split("_part")[0], int(name.split("_part")[1]))
47 default = {"parts": {}, "dtype": tinygrad_dtypes[dtype]}
48 weight_metadata = metadata.get(name, default)
49 weight_metadata["parts"][part_num] = {"file": i, "file_start_pos": cursor, "size": size}
50 metadata[name] = weight_metadata
51 data = bytes(state_dict[name].uop.base.realized.as_memoryview())
52 data = data if not offsets else data[offsets[0]:offsets[1]]
53 writer.write(data)
54 cursor += size
55
56 metadata.update({name: {"parts": {0: {"empty": True, "size": size}}, "dtype": tinygrad_dtypes[dtype]} for size, name, dtype in empty_t_infos})
57
58 for k in metadata:
59 metadata[k]["parts"] = [part for part_num, part in sorted(metadata[k]["parts"].items(), key = lambda x: x[0])]
60 cursor = 0
61 for i, part in enumerate(metadata[k]["parts"]):
62 metadata[k]["parts"][i]["target_start_pos"] = cursor
63 cursor += part["size"]
64 metadata[k]["size"] = cursor
65
66 # compute hashes, which client app will check to determine whether to update with new weights and/or detect integrity issues

Callers 1

compile.py — File · 0.85

Calls 11

get_state_dict — Function · 0.90
append — Method · 0.80
ceil — Method · 0.80
split — Method · 0.80
as_memoryview — Method · 0.80
get — Method · 0.45
write — Method · 0.45
update — Method · 0.45
encode — Method · 0.45
read — Method · 0.45
add — Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…