hub / github.com/hpcaitech/ColossalAI / main

Function main

examples/tutorial/opt/opt/run_clm.py:289–668

def main():
    args = parse_args()
    disable_existing_loggers()
    colossalai.legacy.launch_from_torch()
    logger = get_dist_logger()
    is_main_process = dist.get_rank() == 0

    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()
        logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        logging.set_verbosity_error()

    if args.mem_cap > 0:
        colo_memory_cap(args.mem_cap)

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)
        logger.info(f"Rank {dist.get_rank()}: random seed is set to {args.seed}")

    # Handle the repository creation
    with barrier_context():
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    logger.info("Start preparing dataset", ranks=[0])
    if not args.synthetic:
        if args.dataset_name is not None:
            # Downloading and loading a dataset from the hub.
            raw_datasets = load_dataset(args.dataset_name, args.dataset_config_name)
            if "validation" not in raw_datasets.keys():
                raw_datasets["validation"] = load_dataset(
                    args.dataset_name,
                    args.dataset_config_name,
                    split=f"train[:{args.validation_split_percentage}%]",
                )
                raw_datasets["train"] = load_dataset(
                    args.dataset_name,
                    args.dataset_config_name,
                    split=f"train[{args.validation_split_percentage}%:]",
                )
        else:
            data_files = {}
            dataset_args = {}
            if args.train_file is not None:
                data_files["train"] = args.train_file
            if args.validation_file is not None:
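
When the hub dataset has no "validation" split, the listing above carves one out of "train" by passing percentage-slice strings as the split argument. The following is a minimal sketch of how those two strings relate, not code from run_clm.py: build_split_specs is a hypothetical helper, the value 5 is an illustrative percentage, and the range check is an added assumption that the original function does not perform.

```python
def build_split_specs(validation_split_percentage: int) -> dict:
    """Build the two split strings used when a dataset has no validation split.

    Hypothetical helper: it mirrors the f-strings in main() but is not part
    of run_clm.py.
    """
    # Assumption: reject percentages that would leave one of the splits empty.
    if not 0 < validation_split_percentage < 100:
        raise ValueError("validation_split_percentage must be between 1 and 99")
    return {
        # First N percent of the original train split becomes validation.
        "validation": f"train[:{validation_split_percentage}%]",
        # The remaining (100 - N) percent stays as the train split.
        "train": f"train[{validation_split_percentage}%:]",
    }


specs = build_split_specs(5)
print(specs["validation"])  # train[:5%]
print(specs["train"])       # train[5%:]
```

Each string would then be passed as split= to load_dataset, as the listing does, so the two slices cover the original train split without overlapping.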

Callers 1

run_clm.py · File · 0.70

Calls 15

backward · Method · 0.95
step · Method · 0.95
zero_grad · Method · 0.95
disable_existing_loggers · Function · 0.90
get_dist_logger · Function · 0.90
set_seed · Function · 0.90
barrier_context · Class · 0.90
get_accelerator · Function · 0.90
LazyInitContext · Class · 0.90
ColoInitContext · Class · 0.90
GeminiDDP · Class · 0.90
ProcessGroup · Class · 0.90

Tested by

no test coverage detected
