Get dataset splits from a comma- or slash-separated string list.
(args)
| 374 | |
| 375 | |
def get_split(args):
    """
    Get dataset splits from a comma (or slash) separated string list.

    ``args.split`` holds train/valid/test proportions separated by ``,`` or
    ``/`` (a mix of the two is also accepted), e.g. ``"949,50,1"`` or
    ``"8/1/1"``; a single number is allowed too. The proportions need not
    sum to 1:

    * if they sum to less than 1, the remainder ``1 - total`` is appended as
      an extra split (it only survives when fewer than three were given);
    * the list is padded with ``0.`` up to three entries and anything past
      the third is dropped;
    * the valid (index 1) / test (index 2) proportion is forced to ``0.``
      when ``args.valid_data`` / ``args.test_data`` is not None;
    * the three values are then normalized to sum to 1.

    Args:
        args: object exposing ``split`` (str), ``valid_data`` and
            ``test_data`` attributes.

    Returns:
        list of three floats ``[train, valid, test]`` summing to 1.

    Raises:
        ValueError: if an entry of ``args.split`` is not a number, or if the
            proportions left after the adjustments above sum to 0 or less
            (there is nothing to normalize).
    """
    # Treat '/' as an alias for ',' so either separator (or a mix) works;
    # a string with no separator yields a single entry.
    splits = [float(s) for s in args.split.replace('/', ',').split(',')]

    split_total = sum(splits)
    if split_total < 1.:
        splits.append(1 - split_total)

    # Pad to exactly three entries, dropping any extras.
    splits = (splits + [0.] * 3)[:3]

    # Separately supplied validation/test data replaces the carved-out split.
    if args.valid_data is not None:
        splits[1] = 0.
    if args.test_data is not None:
        splits[2] = 0.

    final_sum = sum(splits)
    if final_sum <= 0:
        # Previously surfaced as a bare ZeroDivisionError.
        raise ValueError(
            f"dataset split {args.split!r} leaves no positive weight for the "
            f"train/valid/test splits in use (total={final_sum})"
        )
    return [s / final_sum for s in splits]
| 399 | |
| 400 | |
| 401 | def configure_data(): |