| 120 | |
| 121 | |
def reshape_data(x, hue, labels):
    """Group the rows of ``x`` (and optional per-row labels) by ``hue`` key.

    Parameters
    ----------
    x : array-like or sequence of array-likes
        Data passed to ``np.vstack``; row ``i`` of the stacked result is
        assigned to the group of ``hue[i]``.
    hue : iterable
        One hashable group key per row. Groups are ordered by the first
        occurrence of each key in ``hue``.
    labels : sequence or None
        Optional per-row labels grouped alongside ``x``. When ``None``, each
        group's label list is filled with ``None`` placeholders.

    Returns
    -------
    tuple(list of numpy.ndarray, list of list)
        One vertically stacked array per group and the matching per-group
        label lists, both in first-occurrence order of ``hue``.
    """
    # Materialize once so one-shot iterables (e.g. generators) work and hue
    # is not re-converted for every use below.
    hue = list(hue)
    # dict preserves insertion order, so this maps each distinct key to its
    # group index in first-occurrence order, with O(1) lookups instead of
    # repeated list.index scans.
    group_of = {key: i for i, key in enumerate(dict.fromkeys(hue))}
    x_stacked = np.vstack(x)
    x_reshaped = [[] for _ in group_of]
    labels_reshaped = [[] for _ in group_of]
    if labels is None:
        labels = [None] * len(hue)
    for idx, (key, label) in enumerate(zip(hue, labels)):
        group = group_of[key]
        # Index x_stacked by position (as before) so a too-short x still
        # raises IndexError rather than silently truncating.
        x_reshaped[group].append(x_stacked[idx])
        labels_reshaped[group].append(label)
    return [np.vstack(rows) for rows in x_reshaped], labels_reshaped
| 133 | |
| 134 | |
| 135 | def patch_lines(x): |