MCPcopy
hub / github.com/FedML-AI/FedML / load_synthetic_data

Function load_synthetic_data

python/fedml/fa/data/data_loader.py:17–92  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

15
16
17def load_synthetic_data(args):
18 dataset_name = args.dataset
19 if dataset_name == "fake":
20 data_cache_dir = os.path.join(args.data_cache_dir, "fake_numeric_data")
21 if not os.path.exists(data_cache_dir):
22 os.makedirs(data_cache_dir, exist_ok=True)
23 generate_fake_data(data_cache_dir)
24 logging.info("load_data. dataset_name = %s" % dataset_name)
25 (
26 datasize,
27 train_data_local_num_dict,
28 local_data_dict,
29 ) = load_partition_data_fake(data_dir=data_cache_dir, client_num=int(args.client_num_in_total))
30
31 dataset = [
32 datasize,
33 train_data_local_num_dict,
34 local_data_dict,
35 ]
36 # print(f"datasize, train_data_local_num_dict, local_data_dict,{dataset}")
37 elif dataset_name == "twitter":
38 path = os.path.join(args.data_cache_dir, "twitter_Sentiment140")
39 download_twitter_Sentiment140(data_cache_dir=path)
40 if args.fa_task != FA_TASK_HEAVY_HITTER_TRIEHH:
41 local_datasets = preprocess_twitter_data(path=path)
42 (
43 datasize,
44 train_data_local_num_dict,
45 local_data_dict,
46 ) = load_partition_data_twitter_sentiment140(local_datasets, client_num_in_total=int(args.client_num_in_total))
47
48 dataset = [
49 datasize,
50 train_data_local_num_dict,
51 local_data_dict,
52 ]
53 else:
54 local_datasets = preprocess_twitter_data_heavy_hitter(path=path)
55 (
56 datasize,
57 train_data_local_num_dict,
58 local_data_dict,
59 ) = load_partition_data_twitter_sentiment140_heavy_hitter(local_datasets, int(args.client_num_in_total))
60 dataset = [
61 datasize,
62 train_data_local_num_dict,
63 local_data_dict,
64 ]
65 elif dataset_name == "self_defined":
66 data_cache_dir = args.data_cache_dir
67 if not os.path.exists(data_cache_dir):
68 os.makedirs(data_cache_dir, exist_ok=True)
69 if hasattr(args, "data_col_idx") and isinstance(args.data_col_idx, int) and args.data_col_idx >= 0:
70 if hasattr(args, "seperator"):
71 separator = args.seperator
72 else:
73 separator = "," # default seperator = ","
74 (

Callers 2

fa_load_data · Function · 0.70
load_synthetic_data_test · Function · 0.70

Tested by

no test coverage detected