```python
from itertools import chain

import numpy as np
from mne.epochs import BaseEpochs


def equalize_bads(insts, interp_thresh=1.0, copy=True):
    """Interpolate or mark bads consistently for a list of instances.

    Once called on a list of instances, the instances can be concatenated
    as they will have the same list of bad channels.

    Parameters
    ----------
    insts : list
        The list of instances (Evoked, Epochs or Raw) to consider
        for interpolation. Each instance should have its bad channels
        marked in ``inst.info["bads"]``.
    interp_thresh : float
        A float between 0 and 1 (default 1.0) that specifies the minimum
        fraction of time a channel must be good for it to be interpolated
        in the instances where it is marked as bad. For example, if 0.5, a
        channel that is good at least half of the time will be interpolated
        in the instances where it is marked as bad. If 1, channels will never
        be interpolated, and if 0, all bad channels will be systematically
        interpolated.
    copy : bool
        If True, the returned instances will be copies. Otherwise, the
        instances are modified in place.

    Returns
    -------
    insts_bads : list
        The list of instances, with the same channel(s) marked as bad in all of
        them, possibly with some formerly bad channels interpolated.
    """
    if not 0 <= interp_thresh <= 1:
        raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")

    # Union of the bad channels over all instances.
    all_bads = list(set(chain.from_iterable([inst.info["bads"] for inst in insts])))
    # Weight each instance by its number of samples (n_epochs * n_times for
    # Epochs). Only the first instance is type-checked, so all instances are
    # assumed to be of the same type.
    if isinstance(insts[0], BaseEpochs):
        durations = [len(inst) * len(inst.times) for inst in insts]
    else:
        durations = [len(inst.times) for inst in insts]

    # For each bad channel, the fraction of all samples in which it is good.
    good_times = []
    for ch_name in all_bads:
        good_times.append(
            sum(
                durations[k]
                for k, inst in enumerate(insts)
                if ch_name not in inst.info["bads"]
            )
            / np.sum(durations)
        )

    # Channels good less than interp_thresh of the time stay bad everywhere;
    # every other bad channel is interpolated where it is marked bad.
    bads_keep = [ch for k, ch in enumerate(all_bads) if good_times[k] < interp_thresh]
    if copy:
        insts = [inst.copy() for inst in insts]

    for inst in insts:
        # Interpolate only if this instance has bads that are not kept.
        if set(inst.info["bads"]) - set(bads_keep):
            inst.interpolate_bads(exclude=bads_keep)
        # Give each instance its own list so later edits do not leak across.
        inst.info["bads"] = list(bads_keep)

    return insts
```
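The decision of which channels stay bad can be run on its own, without any MNE data. The sketch below is a hypothetical helper, `plan_equalize_bads`, which is not part of MNE. It takes each instance's bad-channel list plus its sample count (the `durations` computed in `equalize_bads`) and applies the same rule: a channel stays bad in every instance if the fraction of samples in which it is good is below `interp_thresh`; otherwise it is interpolated wherever it is marked bad. It is a sketch under those assumptions and does not perform any interpolation.

```python
from itertools import chain


def plan_equalize_bads(bads_per_inst, durations, interp_thresh=1.0):
    """Hypothetical helper mirroring the bad-channel decision of equalize_bads.

    Returns (bads_keep, to_interpolate), where to_interpolate lists, per
    instance, the bad channels that would be interpolated in it.
    """
    # Same validation as equalize_bads.
    if not 0 <= interp_thresh <= 1:
        raise ValueError(f"interp_thresh must be between 0 and 1, got {interp_thresh}")
    total = sum(durations)
    # Union of bad channels; sorted() only makes the output order deterministic.
    all_bads = sorted(set(chain.from_iterable(bads_per_inst)))
    bads_keep = []
    for ch_name in all_bads:
        # Fraction of all samples in which this channel is marked good.
        good_fraction = (
            sum(dur for dur, bads in zip(durations, bads_per_inst) if ch_name not in bads)
            / total
        )
        # Good less often than the threshold: keep it bad in every instance.
        if good_fraction < interp_thresh:
            bads_keep.append(ch_name)
    # Per instance, the bad channels that would be interpolated (empty: no call).
    to_interpolate = [sorted(set(bads) - set(bads_keep)) for bads in bads_per_inst]
    return bads_keep, to_interpolate


bads = [["EEG 001"], ["EEG 001", "EEG 002"], []]
durations = [1000, 1000, 2000]
print(plan_equalize_bads(bads, durations, interp_thresh=0.6))
# (['EEG 001'], [[], ['EEG 002'], []])
```

Here EEG 001 is good in 2000 of 4000 samples (0.5) and EEG 002 in 3000 of 4000 (0.75). With `interp_thresh=0.6`, only EEG 001 stays bad and EEG 002 is interpolated in the second instance. With the default of 1.0 both stay bad and nothing is interpolated, and with 0.5 or lower both are interpolated wherever they were bad.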