(
self,
data_loader: DataLoader,
*,
poisson_sampling: bool,
distributed: bool,
batch_first: bool = True,
rand_on_empty: bool = False,
)
| 142 | ) |
| 143 | |
| 144 | def _prepare_data_loader( |
| 145 | self, |
| 146 | data_loader: DataLoader, |
| 147 | *, |
| 148 | poisson_sampling: bool, |
| 149 | distributed: bool, |
| 150 | batch_first: bool = True, |
| 151 | rand_on_empty: bool = False, |
| 152 | ) -> DataLoader: |
| 153 | if self.dataset is None: |
| 154 | self.dataset = data_loader.dataset |
| 155 | elif self.dataset != data_loader.dataset: |
| 156 | warnings.warn( |
| 157 | f"PrivacyEngine detected new dataset object. " |
| 158 | f"Was: {self.dataset}, got: {data_loader.dataset}. " |
| 159 | f"Privacy accounting works per dataset, please initialize " |
| 160 | f"new PrivacyEngine if you're using different dataset. " |
| 161 | f"You can ignore this warning if two datasets above " |
| 162 | f"represent the same logical dataset" |
| 163 | ) |
| 164 | |
| 165 | if poisson_sampling: |
| 166 | return DPDataLoader.from_data_loader( |
| 167 | data_loader, |
| 168 | generator=self.secure_rng, |
| 169 | distributed=distributed, |
| 170 | batch_first=batch_first, |
| 171 | rand_on_empty=rand_on_empty, |
| 172 | ) |
| 173 | elif self.secure_mode: |
| 174 | return switch_generator(data_loader=data_loader, generator=self.secure_rng) |
| 175 | else: |
| 176 | return data_loader |
| 177 | |
| 178 | def _prepare_model( |
| 179 | self, |
no test coverage detected