Repository: github.com/FedML-AI/FedML

Function load_synthetic_data

Defined in python/app/fednlp/seq2seq/data/data_loader.py, lines 26–145
Signature: load_synthetic_data(args)


```python
import logging

from transformers import BartTokenizer

# Seq2SeqArgs and BaseDataManager are FedML-internal classes; their import
# paths are not shown in this excerpt.


def load_synthetic_data(args):
    dataset_name = args.dataset
    # check if the centralized training is enabled
    centralized = (
        True
        if (args.client_num_in_total == 1 and args.training_type != "cross_silo")
        else False
    )

    # check if the full-batch training is enabled
    args_batch_size = args.batch_size
    if args.batch_size <= 0:
        full_batch = True
        args.batch_size = 128  # temporary batch size
    else:
        full_batch = False
    logging.info("load_data. dataset_name = %s" % dataset_name)
    attributes = BaseDataManager.load_attributes(args.data_file_path)
    # num_labels = len(attributes["label_vocab"])
    # class_num = num_labels
    class_num = -1
    model_args = Seq2SeqArgs()
    model_args.model_name = args.model
    model_args.model_type = args.model_type
    # model_args.load(model_args.model_name)
    # model_args.num_labels = num_labels
    model_args.update_from_dict(
        {
            "fl_algorithm": args.federated_optimizer,
            "freeze_layers": args.freeze_layers,
            "epochs": args.epochs,
            "learning_rate": args.learning_rate,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
            "do_lower_case": args.do_lower_case,
            "manual_seed": args.random_seed,
            # for ignoring the cache features.
            "reprocess_input_data": args.reprocess_input_data,
            "overwrite_output_dir": True,
            "max_seq_length": args.max_seq_length,
            "train_batch_size": args.batch_size,
            "eval_batch_size": args.eval_batch_size,
            "evaluate_during_training": False,  # Disabled for FedAvg.
            "evaluate_during_training_steps": args.evaluate_during_training_steps,
            "fp16": args.fp16,
            "data_file_path": args.data_file_path,
            "partition_file_path": args.partition_file_path,
            "partition_method": args.partition_method,
            "dataset": args.dataset,
            "output_dir": args.output_dir,
            "is_debug_mode": args.is_debug_mode,
            "fedprox_mu": args.fedprox_mu,
        }
    )

    # model_args.config["num_labels"] = num_labels
    tokenizer_class = BartTokenizer
    tokenizer = [None, None]
    tokenizer[0] = tokenizer_class.from_pretrained(args.model)
    # ... (excerpt ends here; the function continues through line 145)
```
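The opening lines derive two flags from args before any data is loaded: centralized is true for a single client unless the run is cross-silo, and full_batch is true when the batch size is non-positive. In full-batch mode the code overwrites args.batch_size with a temporary 128 but keeps the caller's value in args_batch_size; the excerpt does not show how that saved value is used later. The sketch below isolates this logic in a hypothetical helper, resolve_training_mode, so it runs without FedML. It assumes args can be modelled with types.SimpleNamespace, and the training_type value "simulation" is an illustrative assumption.

```python
from types import SimpleNamespace

TEMP_BATCH_SIZE = 128  # temporary batch size used by the original code in full-batch mode


def resolve_training_mode(args):
    """Hypothetical helper mirroring the flag logic at the top of load_synthetic_data."""
    # Centralized training: exactly one client, unless the run is cross-silo.
    centralized = args.client_num_in_total == 1 and args.training_type != "cross_silo"

    # Full-batch training is signalled by a non-positive batch size.
    original_batch_size = args.batch_size
    full_batch = args.batch_size <= 0
    if full_batch:
        # Mutates args in place, as the original does; the caller's value is kept above.
        args.batch_size = TEMP_BATCH_SIZE
    return centralized, full_batch, original_batch_size


args = SimpleNamespace(client_num_in_total=1, training_type="simulation", batch_size=-1)
print(resolve_training_mode(args), args.batch_size)  # prints: (True, True, -1) 128
```

Writing the condition as a plain boolean expression is equivalent to the original `True if ... else False` form.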

Callers (1)

load (function) · 0.70

Calls (10)

Seq2SeqArgs (class) · 0.85
Seq2SeqDataManager (class) · 0.85
load_attributes (method) · 0.80
update_from_dict (method) · 0.80
load_federated_data (method) · 0.80
values (method) · 0.80
keys (method) · 0.80
combine_batches (function) · 0.70
info (method) · 0.45
from_pretrained (method) · 0.45
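Among the calls, update_from_dict is how the function copies 22 settings, most of them taken from args, onto the Seq2SeqArgs object in one step instead of assigning each attribute by hand. The sketch below shows that pattern with a hypothetical dataclass, MiniSeq2SeqArgs, that declares only a few fields. It assumes update_from_dict simply sets every key as an attribute, as the simpletransformers-style args classes do; FedML's actual Seq2SeqArgs may validate or handle keys differently.

```python
from dataclasses import dataclass


@dataclass
class MiniSeq2SeqArgs:
    """Hypothetical stand-in for Seq2SeqArgs; only a handful of fields are declared."""

    model_name: str = ""
    model_type: str = ""
    epochs: int = 1
    learning_rate: float = 4e-5
    train_batch_size: int = 8
    evaluate_during_training: bool = False

    def update_from_dict(self, new_values):
        # Assumed behaviour: copy every key/value pair onto the instance as an attribute.
        if not isinstance(new_values, dict):
            raise TypeError(f"{new_values!r} is not a Python dict.")
        for key, value in new_values.items():
            setattr(self, key, value)


model_args = MiniSeq2SeqArgs()
model_args.model_name = "facebook/bart-base"
model_args.update_from_dict({"epochs": 3, "learning_rate": 5e-5, "fl_algorithm": "FedAvg"})
print(model_args.epochs, model_args.fl_algorithm)  # prints: 3 FedAvg
```

Because setattr accepts any name, a key the class does not declare (such as fl_algorithm here) is added silently; this keeps the call site compact but means a misspelled key is not caught.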

Tested by

no test coverage detected