(args, model_cfg)
| 925 | |
| 926 | |
def get_data(args, model_cfg):
    """Build the training and validation datasets described by ``args``.

    Side effects on ``args``:
      * ``class_index_dict`` is always (re)loaded from ``class_label_path``.
      * ``datasetinfos`` defaults to
        ``["train", "unbalanced_train", "balanced_train"]`` when unset.
      * When ``dataset_type == "webdataset"``:
          - ``train_data`` and ``val_data`` are resolved via
            ``get_tar_path_from_dataset_name``.
          - ``full_train_dataset`` and ``exclude_eval_dataset`` are
            normalised from ``None`` to ``[]``.
          - ``val_dataset_names`` records which dataset names are evaluated.

    Args:
        args: Run configuration object; read and mutated in place.
        model_cfg: Model configuration, forwarded unchanged to the dataset
            factory.

    Returns:
        dict: Maps ``"train"`` and/or ``"val"`` to whatever the callable
        returned by ``get_dataset_fn`` produces. A key is present only when
        the matching ``args.train_data`` / ``args.val_data`` is truthy.
    """
    args.class_index_dict = load_class_label(args.class_label_path)

    if args.datasetinfos is None:
        args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]

    if args.dataset_type == "webdataset":
        # Location options shared by the train and the eval shard lookups.
        shard_location = dict(
            islocal=not args.remotedata,
            dataset_path=args.datasetpath,
        )
        args.train_data = get_tar_path_from_dataset_name(
            args.datasetnames,
            args.datasetinfos,
            proportion=args.dataset_proportion,
            full_dataset=args.full_train_dataset,
            **shard_location,
        )

        # Datasets trained on in full, or explicitly excluded, are skipped
        # when choosing which datasets to evaluate.
        for option in ("full_train_dataset", "exclude_eval_dataset"):
            if getattr(args, option) is None:
                setattr(args, option, [])
        skip_eval = args.full_train_dataset + args.exclude_eval_dataset

        if skip_eval:
            eval_names = [
                name for name in args.datasetnames if name not in skip_eval
            ]
        else:
            # Nothing to skip: reuse the configured list object as-is.
            eval_names = args.datasetnames
        args.val_dataset_names = eval_names

        # Eval shards come from the valid/test/eval splits with proportion=1
        # and no full_dataset override.
        args.val_data = get_tar_path_from_dataset_name(
            eval_names,
            ["valid", "test", "eval"],
            proportion=1,
            full_dataset=None,
            **shard_location,
        )

    loaders = {}
    # Each source attribute is read only when its turn comes: the factory
    # call receives ``args`` and may update it.
    for key, source_attr, training in (
        ("train", "train_data", True),
        ("val", "val_data", False),
    ):
        source = getattr(args, source_attr)
        if source:
            loaders[key] = get_dataset_fn(source, args.dataset_type)(
                args, model_cfg, is_train=training
            )
    return loaders
no test coverage detected