(args)
| 24 | |
| 25 | |
| 26 | def load_synthetic_data(args): |
| 27 | dataset_name = args.dataset |
| 28 | # check if the centralized training is enabled |
| 29 | centralized = ( |
| 30 | True |
| 31 | if (args.client_num_in_total == 1 and args.training_type != "cross_silo") |
| 32 | else False |
| 33 | ) |
| 34 | |
| 35 | # check if the full-batch training is enabled |
| 36 | args_batch_size = args.batch_size |
| 37 | if args.batch_size <= 0: |
| 38 | full_batch = True |
| 39 | args.batch_size = 128 # temporary batch size |
| 40 | else: |
| 41 | full_batch = False |
| 42 | logging.info("load_data. dataset_name = %s" % dataset_name) |
| 43 | attributes = BaseDataManager.load_attributes(args.data_file_path) |
| 44 | # num_labels = len(attributes["label_vocab"]) |
| 45 | # class_num = num_labels |
| 46 | class_num = -1 |
| 47 | model_args = Seq2SeqArgs() |
| 48 | model_args.model_name = args.model |
| 49 | model_args.model_type = args.model_type |
| 50 | # model_args.load(model_args.model_name) |
| 51 | # model_args.num_labels = num_labels |
| 52 | model_args.update_from_dict( |
| 53 | { |
| 54 | "fl_algorithm": args.federated_optimizer, |
| 55 | "freeze_layers": args.freeze_layers, |
| 56 | "epochs": args.epochs, |
| 57 | "learning_rate": args.learning_rate, |
| 58 | "gradient_accumulation_steps": args.gradient_accumulation_steps, |
| 59 | "do_lower_case": args.do_lower_case, |
| 60 | "manual_seed": args.random_seed, |
| 61 | # for ignoring the cache features. |
| 62 | "reprocess_input_data": args.reprocess_input_data, |
| 63 | "overwrite_output_dir": True, |
| 64 | "max_seq_length": args.max_seq_length, |
| 65 | "train_batch_size": args.batch_size, |
| 66 | "eval_batch_size": args.eval_batch_size, |
| 67 | "evaluate_during_training": False, # Disabled for FedAvg. |
| 68 | "evaluate_during_training_steps": args.evaluate_during_training_steps, |
| 69 | "fp16": args.fp16, |
| 70 | "data_file_path": args.data_file_path, |
| 71 | "partition_file_path": args.partition_file_path, |
| 72 | "partition_method": args.partition_method, |
| 73 | "dataset": args.dataset, |
| 74 | "output_dir": args.output_dir, |
| 75 | "is_debug_mode": args.is_debug_mode, |
| 76 | "fedprox_mu": args.fedprox_mu, |
| 77 | } |
| 78 | ) |
| 79 | |
| 80 | # model_args.config["num_labels"] = num_labels |
| 81 | tokenizer_class = BartTokenizer |
| 82 | tokenizer = [None, None] |
| 83 | tokenizer[0] = tokenizer_class.from_pretrained(args.model) |
no test coverage detected