MCPcopy
hub / github.com/FedML-AI/FedML / load_synthetic_data

Function load_synthetic_data

python/app/fednlp/span_extraction/data/data_loader.py:31–155  ·  view source on GitHub ↗
(args)

Source from the content-addressed store, hash-verified

29
30
31def load_synthetic_data(args):
32 dataset_name = args.dataset
33 # check if the centralized training is enabled
34 centralized = (
35 True
36 if (args.client_num_in_total == 1 and args.training_type != "cross_silo")
37 else False
38 )
39
40 # check if the full-batch training is enabled
41 args_batch_size = args.batch_size
42 if args.batch_size <= 0:
43 full_batch = True
44 args.batch_size = 128 # temporary batch size
45 else:
46 full_batch = False
47 logging.info("load_data. dataset_name = %s" % dataset_name)
48 attributes = BaseDataManager.load_attributes(args.data_file_path)
49 # num_labels = len(attributes["label_vocab"])
50 # class_num = num_labels
51 class_num = -1
52 model_args = SpanExtractionArgs()
53 model_args.model_name = args.model
54 model_args.model_type = args.model_type
55 # model_args.load(model_args.model_name)
56 # model_args.num_labels = num_labels
57 model_args.update_from_dict(
58 {
59 "fl_algorithm": args.federated_optimizer,
60 "freeze_layers": args.freeze_layers,
61 "epochs": args.epochs,
62 "learning_rate": args.learning_rate,
63 "gradient_accumulation_steps": args.gradient_accumulation_steps,
64 "do_lower_case": args.do_lower_case,
65 "manual_seed": args.random_seed,
66 # for ignoring the cache features.
67 "reprocess_input_data": args.reprocess_input_data,
68 "overwrite_output_dir": True,
69 "max_seq_length": args.max_seq_length,
70 "train_batch_size": args.batch_size,
71 "eval_batch_size": args.eval_batch_size,
72 "evaluate_during_training": False, # Disabled for FedAvg.
73 "evaluate_during_training_steps": args.evaluate_during_training_steps,
74 "fp16": args.fp16,
75 "data_file_path": args.data_file_path,
76 "partition_file_path": args.partition_file_path,
77 "partition_method": args.partition_method,
78 "dataset": args.dataset,
79 "output_dir": args.output_dir,
80 "is_debug_mode": args.is_debug_mode,
81 "fedprox_mu": args.fedprox_mu,
82 }
83 )
84
85 # model_args.config["num_labels"] = num_labels
86 if args.model_type == "bert":
87 tokenizer_class = BertTokenizer
88 elif args.model_type == "distilbert":

Callers 1

load (Function) · 0.70

Calls 10

SpanExtractionArgs (Class) · 0.85
load_attributes (Method) · 0.80
update_from_dict (Method) · 0.80
load_federated_data (Method) · 0.80
values (Method) · 0.80
keys (Method) · 0.80
combine_batches (Function) · 0.70
info (Method) · 0.45
from_pretrained (Method) · 0.45

Tested by

no test coverage detected