| 41 | |
@staticmethod
def to_word_list_format(word_dict, tokenizer):
    """Encode per-entry word lists into a padded (ids, offsets) int32 array.

    Each element of ``word_dict`` is an iterable of strings. Every string is
    tokenized with ``tokenizer.encode(word).ids``. The token ids of all words
    in one element are concatenated into a single row, and the running end
    position of each word is recorded as an offset. Rows are right-padded to
    a common length: ids with 0, offsets with -1.

    NOTE(review): this appears to be the flat "word list" layout used for
    stop/bad-word lists by FasterTransformer-style backends; confirm against
    the consumer of the returned array.

    Args:
        word_dict: iterable with one element per batch entry; each element
            is an iterable of strings.
        tokenizer: object whose ``encode(str)`` returns a value exposing an
            ``ids`` sequence of ints.

    Returns:
        ``np.ndarray`` of dtype int32 and shape ``(batch, 2, pad_to)``.
        ``[:, 0, :]`` holds the concatenated token ids and ``[:, 1, :]`` holds
        the cumulative end offsets. ``pad_to`` is the length of the longest
        id row, but at least 1. An empty ``word_dict`` yields an array of
        shape ``(0, 2, 1)`` instead of raising.
    """
    # Presumably the single-"\n" token id of a GPT-2 style BPE vocabulary.
    # TODO confirm against the tokenizer actually in use.
    newline_id = 198

    flat_ids = []
    offsets = []
    for word_dict_item in word_dict:
        item_ids = []
        item_ends = []
        for word in word_dict_item:
            ids = tokenizer.encode(word).ids
            # A word that tokenizes to nothing would produce an empty entry,
            # so it is skipped (this also skips the "\n\n" hack below).
            if len(ids) == 0:
                continue

            item_ids.extend(ids)
            # Cumulative end offset of this word within the row.
            item_ends.append(len(item_ids))

            # HACK (original note: "can we do this better?"): a "\n\n" word
            # also gets a second entry made of two newline_id tokens.
            # Presumably this catches the case where the model emits the blank
            # line as two separate newline tokens.
            if word == '\n\n':
                item_ids.extend([newline_id, newline_id])
                item_ends.append(len(item_ids))

        flat_ids.append(item_ids)
        offsets.append(item_ends)

    # Keep at least one column so every row has a valid shape, even when all
    # rows (or the whole batch) are empty.
    pad_to = max(1, max((len(ids) for ids in flat_ids), default=0))

    # Allocate pre-filled with the padding values (0 for ids, -1 for offsets)
    # and copy each row in. This also works for an empty batch, where
    # stacking and transposing would fail.
    result = np.zeros((len(flat_ids), 2, pad_to), dtype=np.int32)
    result[:, 1, :] = -1
    for row, (ids, ends) in enumerate(zip(flat_ids, offsets)):
        result[row, 0, :len(ids)] = ids
        result[row, 1, :len(ends)] = ends
    return result
| 74 | |
| 75 | def generate(self, data): |
| 76 | prompt = data['prompt'] |