Split the elements of `o_vals` (observation values) into two groups: those for trials whose losses (`l_vals`) were above gamma, and those below gamma. Note that only unique elements are returned, so the total number of returned elements might be lower than `len(o_vals)`.
(o_idxs, o_vals, l_idxs, l_vals, gamma, gamma_cap=DEFAULT_LF)
| 615 | |
@scope.define_info(o_len=2)
def ap_split_trials(o_idxs, o_vals, l_idxs, l_vals, gamma, gamma_cap=DEFAULT_LF):
    """Partition observation values according to the loss rank of their trials.

    Trials are ordered by ascending loss (`l_vals`). The first
    ``min(ceil(gamma * sqrt(len(l_vals))), gamma_cap)`` trials in that order
    form the "below" group, and all remaining trials form the "above" group.
    Each value in `o_vals` is then assigned to a group by looking up its
    trial id (the matching entry of `o_idxs`) among the ids in `l_idxs`.

    Args:
        o_idxs: trial id for each entry of `o_vals`.
        o_vals: observation values, parallel to `o_idxs`.
        l_idxs: trial id for each entry of `l_vals`.
        l_vals: loss values, parallel to `l_idxs`.
        gamma: scale factor applied to ``sqrt(len(l_vals))`` to size the
            "below" group.
        gamma_cap: upper bound on the size of the "below" group.

    Returns:
        A pair of arrays ``(below, above)`` holding the selected observation
        values, each in the order they occur in `o_vals`. An observation whose
        trial id is not present in `l_idxs` lands in neither array, so the two
        outputs together may hold fewer than ``len(o_vals)`` entries.
    """
    o_idxs = np.asarray(o_idxs)
    o_vals = np.asarray(o_vals)
    l_idxs = np.asarray(l_idxs)
    l_vals = np.asarray(l_vals)

    # XXX: if this is working, the sort below is a candidate for an
    # efficiency refactor.

    # Size of the low-loss group, before and after applying the cap.
    uncapped = int(np.ceil(gamma * np.sqrt(len(l_vals))))
    n_below = min(uncapped, gamma_cap)

    # Split on rank position rather than on a loss threshold; this is how
    # duplicate loss values are coped with.
    ranked = np.argsort(l_vals)
    below_ids = set(l_idxs[ranked[:n_below]])
    above_ids = set(l_idxs[ranked[n_below:]])

    below = []
    above = []
    for trial_id, value in zip(o_idxs, o_vals):
        # Two independent checks: an id is tested against each group on its
        # own, exactly as two separate filtering passes would do.
        if trial_id in below_ids:
            below.append(value)
        if trial_id in above_ids:
            above.append(value)

    return np.asarray(below), np.asarray(above)
| 640 | |
| 641 | |
| 642 | @scope.define |