MCPcopy
hub / github.com/tinygrad/tinygrad / load

Method load

extra/thneed.py:51–143  ·  view source on GitHub ↗
(self, input_fn)

Source from the content-addressed store, hash-verified

49 del self.inputs[k]
50
51 def load(self, input_fn):
52 float32 = not FLOAT16
53
54 mf = cl.mem_flags
55 image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT)
56 image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT)
57
58 with open(input_fn, "rb") as f:
59 json_len = struct.unpack("I", f.read(4))[0]
60 jdat = json.loads(f.read(json_len).decode('latin_1'))
61 weights = f.read()
62
63 # load in the buffers
64 bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None}
65 bufs_loaded = {}
66 ptr = 0
67 for o in jdat['objects']:
68 #print(o)
69 if o['needs_load']:
70 nptr = ptr + o['size']
71 o['data'] = weights[ptr:nptr]
72 ptr = nptr
73
74 if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t":
75 tfmt = image_fmt_32 if 'float32' in o and o['float32'] else image_fmt
76 if o['arg_type'] == "image2d_t":
77 if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]:
78 # hack: use a image1d since we can back that with a buffer
79 buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
80 else:
81 # buffer isn't supported in image2d, copy buffer into image
82 if 'buffer_id' in o and bufs_loaded[o['buffer_id']]:
83 arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16)
84 cl.enqueue_copy(CL.queue, arr, bufs[o['buffer_id']])
85 buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
86 shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr)
87 elif o['needs_load']:
88 buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt,
89 shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data'])
90 else:
91 buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height']))
92 if o['arg_type'] == "image1d_t":
93 assert not o['needs_load']
94 assert not bufs_loaded[o['buffer_id']]
95 buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']])
96 else:
97 if 'data' in o:
98 buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data'])
99 else:
100 # zero out buffers
101 buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size'])
102
103 bufs[o['id']] = buf
104 bufs_loaded[o['id']] = 'data' in o
105 # if it's loaded, it's saved
106 if 'data' in o:
107 self.buffers_to_save.add(buf)
108

Callers 13

load (Function) · 0.45
load (Function) · 0.45
load_file_waveform (Function) · 0.45
vgg7.py (File) · 0.45
load_pickle.py (File) · 0.45
test_vs_onnx (Function) · 0.45
compile3.py (File) · 0.45
load_file (Function) · 0.45
load_unet3d_data (Function) · 0.45
__init__ (Method) · 0.45

Calls 8

CLProgram (Class) · 0.90
zeros (Method) · 0.80
append (Method) · 0.80
keys (Method) · 0.80
read (Method) · 0.45
decode (Method) · 0.45
add (Method) · 0.45
encode (Method) · 0.45

Tested by 1

test_vs_onnx (Function) · 0.36