(args)
| 15 | |
| 16 | |
| 17 | def load_synthetic_data(args): |
| 18 | dataset_name = args.dataset |
| 19 | if dataset_name == "fake": |
| 20 | data_cache_dir = os.path.join(args.data_cache_dir, "fake_numeric_data") |
| 21 | if not os.path.exists(data_cache_dir): |
| 22 | os.makedirs(data_cache_dir, exist_ok=True) |
| 23 | generate_fake_data(data_cache_dir) |
| 24 | logging.info("load_data. dataset_name = %s" % dataset_name) |
| 25 | ( |
| 26 | datasize, |
| 27 | train_data_local_num_dict, |
| 28 | local_data_dict, |
| 29 | ) = load_partition_data_fake(data_dir=data_cache_dir, client_num=int(args.client_num_in_total)) |
| 30 | |
| 31 | dataset = [ |
| 32 | datasize, |
| 33 | train_data_local_num_dict, |
| 34 | local_data_dict, |
| 35 | ] |
| 36 | # print(f"datasize, train_data_local_num_dict, local_data_dict,{dataset}") |
| 37 | elif dataset_name == "twitter": |
| 38 | path = os.path.join(args.data_cache_dir, "twitter_Sentiment140") |
| 39 | download_twitter_Sentiment140(data_cache_dir=path) |
| 40 | if args.fa_task != FA_TASK_HEAVY_HITTER_TRIEHH: |
| 41 | local_datasets = preprocess_twitter_data(path=path) |
| 42 | ( |
| 43 | datasize, |
| 44 | train_data_local_num_dict, |
| 45 | local_data_dict, |
| 46 | ) = load_partition_data_twitter_sentiment140(local_datasets, client_num_in_total=int(args.client_num_in_total)) |
| 47 | |
| 48 | dataset = [ |
| 49 | datasize, |
| 50 | train_data_local_num_dict, |
| 51 | local_data_dict, |
| 52 | ] |
| 53 | else: |
| 54 | local_datasets = preprocess_twitter_data_heavy_hitter(path=path) |
| 55 | ( |
| 56 | datasize, |
| 57 | train_data_local_num_dict, |
| 58 | local_data_dict, |
| 59 | ) = load_partition_data_twitter_sentiment140_heavy_hitter(local_datasets, int(args.client_num_in_total)) |
| 60 | dataset = [ |
| 61 | datasize, |
| 62 | train_data_local_num_dict, |
| 63 | local_data_dict, |
| 64 | ] |
| 65 | elif dataset_name == "self_defined": |
| 66 | data_cache_dir = args.data_cache_dir |
| 67 | if not os.path.exists(data_cache_dir): |
| 68 | os.makedirs(data_cache_dir, exist_ok=True) |
| 69 | if hasattr(args, "data_col_idx") and isinstance(args.data_col_idx, int) and args.data_col_idx >= 0: |
| 70 | if hasattr(args, "seperator"): |
| 71 | separator = args.seperator |
| 72 | else: |
| 73 | separator = "," # default seperator = "," |
| 74 | ( |
no test coverage detected