(data, alpha=0.2)
| 195 | |
| 196 | |
def mixUpCls(data, alpha=0.2):
    """Apply mixup to a batch by blending it with a shuffled copy of itself.

    A mixing coefficient ``lam`` is drawn from Beta(alpha, alpha) using the
    global NumPy RNG, a random permutation of the batch is drawn with
    ``torch.randperm`` on ``data``'s device, and each sample is combined as
    ``lam * data[i] + (1 - lam) * data[index[i]]``.

    Args:
        data: Tensor whose first dimension is the batch dimension. Any rank
            >= 1 is accepted.
        alpha: Beta-distribution concentration. ``alpha <= 0`` disables
            mixing (``lam == 1``).

    Returns:
        Tuple ``(data_mixed, lam, index)``: the mixed batch, the coefficient
        applied to the original samples, and the permutation that selects
        each sample's partner (presumably used by callers to mix labels or
        losses the same way -- confirm against call sites).
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = data.size(0)
    # Drawn even when mixing is disabled so the returned ``index`` and the
    # torch RNG stream behave exactly as in the mixing case.
    index = torch.randperm(batch_size, device=data.device)

    if lam == 1:
        # Skip the blend: ``0 * data[index]`` would turn any inf in a partner
        # sample into NaN and corrupt a batch that should be left unmixed.
        return data.clone(), lam, index

    # ``data[index]`` permutes along dim 0 for any rank; the previous
    # ``data[index, :]`` raised on 1-D input.
    data_mixed = lam * data + (1 - lam) * data[index]
    return data_mixed, lam, index
| 208 | |
| 209 | |
| 210 | if __name__ == '__main__': |