(args)
| 490 | |
| 491 | |
| 492 | def main(args): |
| 493 | max_seq_length = args.max_seq_length |
| 494 | max_predictions_per_seq = args.max_predictions_per_seq |
| 495 | masked_lm_prob = args.masked_lm_prob |
| 496 | mask_prob = args.mask_prob |
| 497 | dupe_factor = args.dupe_factor |
| 498 | prop_sliding_window = args.prop_sliding_window |
| 499 | pool_size = args.pool_size |
| 500 | |
| 501 | output_dir = args.data_dir |
| 502 | dataset_name = args.dataset_name |
| 503 | |
| 504 | if not os.path.isdir(output_dir): |
| 505 | print(output_dir + ' is not exist') |
| 506 | print(os.getcwd()) |
| 507 | exit(1) |
| 508 | os.mkdir(output_dir + 'train') |
| 509 | os.mkdir(output_dir + 'test') |
| 510 | |
| 511 | dataset = data_partition(output_dir + dataset_name + '.txt') |
| 512 | [user_train, user_valid, user_test, usernum, itemnum] = dataset |
| 513 | cc = 0.0 |
| 514 | max_len = 0 |
| 515 | min_len = 100000 |
| 516 | for u in user_train: |
| 517 | cc += len(user_train[u]) |
| 518 | max_len = max(len(user_train[u]), max_len) |
| 519 | min_len = min(len(user_train[u]), min_len) |
| 520 | |
| 521 | print('average sequence length: %.2f' % (cc / len(user_train))) |
| 522 | print('max:{}, min:{}'.format(max_len, min_len)) |
| 523 | |
| 524 | print('len_train:{}, len_valid:{}, len_test:{}, usernum:{}, itemnum:{}'. |
| 525 | format( |
| 526 | len(user_train), |
| 527 | len(user_valid), len(user_test), usernum, itemnum)) |
| 528 | |
| 529 | for idx, u in enumerate(user_train): |
| 530 | if idx <= 1: |
| 531 | print(user_train[u]) |
| 532 | print(user_valid[u]) |
| 533 | print(user_test[u]) |
| 534 | |
| 535 | # put validate into train |
| 536 | for u in user_train: |
| 537 | if u in user_valid: |
| 538 | user_train[u].extend(user_valid[u]) |
| 539 | |
| 540 | # get the max index of the data |
| 541 | user_train_data = { |
| 542 | 'user_' + str(k): ['item_' + str(item) for item in v] |
| 543 | for k, v in user_train.items() if len(v) > 0 |
| 544 | } |
| 545 | user_test_data = { |
| 546 | 'user_' + str(u): |
| 547 | ['item_' + str(item) for item in (user_train[u] + user_test[u])] |
| 548 | for u in user_train if len(user_train[u]) > 0 and len(user_test[u]) > 0 |
| 549 | } |
no test coverage detected