Generator function to iterate over trainsets and testsets. Args: data (:obj:`Dataset <surprise.dataset.Dataset>`): The data containing ratings that will be divided into trainsets and testsets. Yields: tuple of (trainset, testset)
(self, data)
| 79 | self.random_state = random_state |
| 80 | |
def split(self, data):
    """Generator function to iterate over trainsets and testsets.

    The ratings are partitioned into ``self.n_splits`` contiguous folds
    (after an optional shuffle of the indices). Each fold is used once as
    the testset while the remaining ratings form the trainset. When the
    number of ratings is not divisible by ``n_splits``, the first
    ``len(ratings) % n_splits`` folds receive one extra rating.

    Args:
        data(:obj:`Dataset<surprise.dataset.Dataset>`): The data containing
            ratings that will be divided into trainsets and testsets.

    Yields:
        tuple of (trainset, testset)

    Raises:
        ValueError: If ``n_splits`` is lower than 2 or greater than the
            number of ratings in ``data.raw_ratings``.
    """

    n_ratings = len(data.raw_ratings)

    if self.n_splits > n_ratings or self.n_splits < 2:
        # Report the offending n_splits value (not the number of ratings)
        # and state the bound exactly as the check enforces it.
        raise ValueError(
            "Incorrect value for n_splits={}. "
            "Must be >=2 and less than or equal to the number "
            "of ratings ({})".format(self.n_splits, n_ratings)
        )

    # We use indices to avoid shuffling the original data.raw_ratings list.
    indices = np.arange(n_ratings)

    if self.shuffle:
        get_rng(self.random_state).shuffle(indices)

    # Every fold holds `fold_size` ratings; the first `n_larger` folds hold
    # one more so that all ratings are used exactly once as test data.
    fold_size, n_larger = divmod(n_ratings, self.n_splits)

    stop = 0
    for fold_i in range(self.n_splits):
        start = stop
        stop = start + fold_size + (1 if fold_i < n_larger else 0)

        raw_trainset = [
            data.raw_ratings[i] for i in chain(indices[:start], indices[stop:])
        ]
        raw_testset = [data.raw_ratings[i] for i in indices[start:stop]]

        trainset = data.construct_trainset(raw_trainset)
        testset = data.construct_testset(raw_testset)

        yield trainset, testset
| 121 | |
| 122 | def get_n_folds(self): |
| 123 |