Get dataset splits from a comma-separated string list.
Signature: `get_split(args)`
| 202 | return (train, valid, test), tokenizer |
| 203 | |
def get_split(args):
    """
    Get dataset splits from a comma-separated string list.

    Parse ``args.split`` into normalized (train, valid, test) fractions.

    ``args.split`` may separate weights with commas (``"949,50,1"``) or
    slashes (``"8/1/1"``); a string with neither is parsed as a single
    (train) weight. If the weights sum to less than 1, the remainder
    ``1 - total`` is appended as the next split. The list is then padded
    with zeros / truncated to exactly three entries. If
    ``args.valid_data`` / ``args.test_data`` is not None, the matching
    split is set to 0. Finally the three values are rescaled to sum to 1.

    Args:
        args: object exposing ``split`` (str), ``valid_data`` and
            ``test_data`` attributes.

    Returns:
        list[float]: ``[train, valid, test]`` fractions summing to 1.

    Raises:
        ValueError: if a weight cannot be parsed as a float, or if all
            three splits are zero after zeroing (nothing to normalize).
    """
    spec = args.split
    if ',' in spec:
        splits = [float(s) for s in spec.split(',')]
    elif '/' in spec:
        splits = [float(s) for s in spec.split('/')]
    else:
        splits = [float(spec)]

    split_total = sum(splits)
    if split_total < 1.:
        # Assign the unclaimed remainder to the next split.
        splits.append(1 - split_total)

    # Pad with zeros / truncate to exactly (train, valid, test).
    splits = (splits + [0.] * 3)[:3]

    if args.valid_data is not None:
        splits[1] = 0.
    if args.test_data is not None:
        splits[2] = 0.

    final_sum = sum(splits)
    if final_sum == 0:
        # Without this guard the normalization below fails with an opaque
        # ZeroDivisionError.
        raise ValueError(
            "split %r leaves no data after zeroing externally supplied "
            "valid/test sets; all split weights are zero" % (spec,)
        )
    return [s / final_sum for s in splits]
| 227 | |
| 228 | def configure_data(): |
| 229 |