github.com/huggingface/diffusers / main

Function main(args)

examples/amused/train_amused.py, lines 419–943

```python
# Excerpt: imports and the module-level `logger` are defined earlier in the file.
def main(args):
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )
    # Disable AMP for MPS.
    if torch.backends.mps.is_available():
        accelerator.native_amp = False

    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if accelerator.is_main_process:
        accelerator.init_trackers("amused", config=vars(copy.deepcopy(args)))

    if args.seed is not None:
        set_seed(args.seed)

    # TODO - will have to fix loading if training text encoder
    text_encoder = CLIPTextModelWithProjection.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
    )
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant
    )
    vq_model = VQModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant
    )

    if args.train_text_encoder:
        if args.text_encoder_use_lora:
            lora_config = LoraConfig(
                r=args.text_encoder_lora_r,
                lora_alpha=args.text_encoder_lora_alpha,
                target_modules=args.text_encoder_lora_target_modules,
            )
            text_encoder.add_adapter(lora_config)
        text_encoder.train()
        text_encoder.requires_grad_(True)
    else:
        text_encoder.eval()
        # ... listing truncated here; main() continues through line 943.
```
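The tracker call near the top of `main`, `accelerator.init_trackers("amused", config=vars(copy.deepcopy(args)))`, hands the parsed arguments to the trackers as a plain dict. Because `vars()` on an `argparse.Namespace` returns that object's own `__dict__` rather than a copy, deep-copying first keeps the config dict (and any nested lists in it) detached from `args`. That isolation is a plausible reason for the `copy.deepcopy`, though the source does not state one. The standard-library sketch below assumes `args` is an `argparse.Namespace`; the attribute names come from the excerpt, while every value is hypothetical.

```python
import argparse
import copy

# Hypothetical values; only the attribute names come from the main() excerpt.
args = argparse.Namespace(
    output_dir="out",
    logging_dir="logs",
    seed=0,
    mixed_precision="no",
    train_text_encoder=True,
    text_encoder_use_lora=True,
    text_encoder_lora_r=4,
    text_encoder_lora_alpha=4,
    text_encoder_lora_target_modules=["q_proj", "v_proj"],
)

# vars() returns the Namespace's own __dict__, so edits write through to args.
shared = vars(args)
shared["seed"] = 123
print(args.seed)  # 123

# vars(copy.deepcopy(args)) is fully detached, including nested lists.
detached = vars(copy.deepcopy(args))
detached["seed"] = 999
detached["text_encoder_lora_target_modules"].append("k_proj")
print(args.seed)  # 123
print(args.text_encoder_lora_target_modules)  # ['q_proj', 'v_proj']
```

A shallow `copy.copy(args)` would detach scalar fields like `seed` but would still share the `text_encoder_lora_target_modules` list, which is why the deep copy is the safer choice for a config snapshot.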

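When `args.train_text_encoder` and `args.text_encoder_use_lora` are both set, the excerpt attaches a PEFT `LoraConfig` built from `r`, `lora_alpha` and `target_modules`. In PEFT's default LoRA formulation, a targeted linear layer with frozen weight W computes W x + (lora_alpha / r) · B(A x), where A is r × d_in and B is d_out × r, and B starts at zero so the adapter initially changes nothing. The dependency-free sketch below works that arithmetic through on tiny lists; `matvec` and `lora_forward` are hypothetical helpers, not PEFT functions, and real adapters operate on torch tensors.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (given as a list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(w: Matrix, a: Matrix, b: Matrix, x: Vector, r: int, lora_alpha: float) -> Vector:
    """Hypothetical helper: frozen weight w plus a scaled low-rank update.

    Shapes: w is d_out x d_in (frozen), a is r x d_in, b is d_out x r.
    Returns w @ x + (lora_alpha / r) * (b @ (a @ x)).
    """
    if len(a) != r or any(len(row) != r for row in b):
        raise ValueError("a must have r rows and b must have r columns")
    scaling = lora_alpha / r
    base = matvec(w, x)
    update = matvec(b, matvec(a, x))
    return [y + scaling * u for y, u in zip(base, update)]


if __name__ == "__main__":
    w = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2x2 identity weight
    a = [[1.0, 1.0]]              # r = 1, d_in = 2
    b_zero = [[0.0], [0.0]]       # B at its zero initialization
    b = [[1.0], [2.0]]            # d_out = 2, r = 1, after some training
    x = [3.0, 4.0]

    print(lora_forward(w, a, b_zero, x, r=1, lora_alpha=8))  # [3.0, 4.0]: adapter is a no-op
    print(lora_forward(w, a, b, x, r=1, lora_alpha=8))       # [59.0, 116.0]
```

Because the low-rank term is scaled by lora_alpha / r, raising `text_encoder_lora_r` while holding `text_encoder_lora_alpha` fixed shrinks that scale factor; with r = 1 and lora_alpha = 8 as above, the update is multiplied by 8.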
Callers 1

train_amused.py · File · 0.70

Calls 15

to · Method · 0.95
step · Method · 0.95
store · Method · 0.95
copy_to · Method · 0.95
restore · Method · 0.95
set_seed · Function · 0.90
EMAModel · Class · 0.90
AmusedPipeline · Class · 0.90
HuggingFaceDataset · Class · 0.85
load_dataset · Function · 0.85
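The call list pairs `EMAModel` with `step`, `store`, `copy_to` and `restore`. In diffusers these support a common pattern: keep an exponential moving average of the weights, temporarily swap the averaged weights into the live model for evaluation, then put the training weights back. The class below is a hypothetical, dependency-free simplification (`SimpleEMA` is not diffusers' `EMAModel`): it works on a list of floats with a fixed decay, whereas `EMAModel` works on torch parameters and its effective decay can vary with the optimization step. Treat it only as an illustration of the update rule and of the store/copy_to/restore order.

```python
from typing import List, Optional


class SimpleEMA:
    """Hypothetical, simplified EMA helper (not diffusers' EMAModel)."""

    def __init__(self, params: List[float], decay: float = 0.9) -> None:
        self.decay = decay
        self.shadow = list(params)  # averaged weights, initialized from the model
        self._backup: Optional[List[float]] = None

    def step(self, params: List[float]) -> None:
        # shadow <- decay * shadow + (1 - decay) * param
        self.shadow = [self.decay * s + (1.0 - self.decay) * p for s, p in zip(self.shadow, params)]

    def store(self, params: List[float]) -> None:
        # Save the current training weights so restore() can put them back.
        self._backup = list(params)

    def copy_to(self, params: List[float]) -> None:
        # Overwrite the live weights in place with the averaged ones.
        params[:] = self.shadow

    def restore(self, params: List[float]) -> None:
        if self._backup is None:
            raise RuntimeError("store() must be called before restore()")
        params[:] = self._backup
        self._backup = None


if __name__ == "__main__":
    weights = [0.0]
    ema = SimpleEMA(weights, decay=0.5)  # toy decay so the averaging is visible
    for _ in range(2):
        weights[0] += 1.0  # stand-in for an optimizer update
        ema.step(weights)
    print(ema.shadow)  # [1.25]

    ema.store(weights)    # keep the training weights
    ema.copy_to(weights)  # evaluate with the averaged weights
    print(weights)        # [1.25]
    ema.restore(weights)  # back to the training weights
    print(weights)        # [2.0]
```

The order matters: `store` has to run before `copy_to`, otherwise the training weights are overwritten by the averaged ones with no copy left to restore.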

Tested by

no test coverage detected
