| 49 | del self.inputs[k] |
| 50 | |
| 51 | def load(self, input_fn): |
| 52 | float32 = not FLOAT16 |
| 53 | |
| 54 | mf = cl.mem_flags |
| 55 | image_fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT if float32 else cl.channel_type.HALF_FLOAT) |
| 56 | image_fmt_32 = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.FLOAT) |
| 57 | |
| 58 | with open(input_fn, "rb") as f: |
| 59 | json_len = struct.unpack("I", f.read(4))[0] |
| 60 | jdat = json.loads(f.read(json_len).decode('latin_1')) |
| 61 | weights = f.read() |
| 62 | |
| 63 | # load in the buffers |
| 64 | bufs = {'\x00\x00\x00\x00\x00\x00\x00\x00': None} |
| 65 | bufs_loaded = {} |
| 66 | ptr = 0 |
| 67 | for o in jdat['objects']: |
| 68 | #print(o) |
| 69 | if o['needs_load']: |
| 70 | nptr = ptr + o['size'] |
| 71 | o['data'] = weights[ptr:nptr] |
| 72 | ptr = nptr |
| 73 | |
| 74 | if o['arg_type'] == "image2d_t" or o['arg_type'] == "image1d_t": |
| 75 | tfmt = image_fmt_32 if 'float32' in o and o['float32'] else image_fmt |
| 76 | if o['arg_type'] == "image2d_t": |
| 77 | if 'buffer_id' in o and o['height'] == 1 and not bufs_loaded[o['buffer_id']]: |
| 78 | # hack: use a image1d since we can back that with a buffer |
| 79 | buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']]) |
| 80 | else: |
| 81 | # buffer isn't supported in image2d, copy buffer into image |
| 82 | if 'buffer_id' in o and bufs_loaded[o['buffer_id']]: |
| 83 | arr = np.zeros(bufs[o['buffer_id']].size // 2, dtype=np.float16) |
| 84 | cl.enqueue_copy(CL.queue, arr, bufs[o['buffer_id']]) |
| 85 | buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt, |
| 86 | shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=arr) |
| 87 | elif o['needs_load']: |
| 88 | buf = cl.Image(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, tfmt, |
| 89 | shape=(o['width'], o['height']), pitches=(o['row_pitch'],), hostbuf=o['data']) |
| 90 | else: |
| 91 | buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'], o['height'])) |
| 92 | if o['arg_type'] == "image1d_t": |
| 93 | assert not o['needs_load'] |
| 94 | assert not bufs_loaded[o['buffer_id']] |
| 95 | buf = cl.Image(CL.ctx, mf.READ_WRITE, tfmt, shape=(o['width'],), buffer=bufs[o['buffer_id']]) |
| 96 | else: |
| 97 | if 'data' in o: |
| 98 | buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=o['data']) |
| 99 | else: |
| 100 | # zero out buffers |
| 101 | buf = cl.Buffer(CL.ctx, mf.READ_WRITE | mf.COPY_HOST_PTR, hostbuf=b'\x00'*o['size']) |
| 102 | |
| 103 | bufs[o['id']] = buf |
| 104 | bufs_loaded[o['id']] = 'data' in o |
| 105 | # if it's loaded, it's saved |
| 106 | if 'data' in o: |
| 107 | self.buffers_to_save.add(buf) |
| 108 | |