MCPcopy
hub / github.com/InternLM/InternLM / evaluate_on_val_dls

Function evaluate_on_val_dls

internlm/utils/evaluation.py:63–168  ·  view source on GitHub ↗
(
    trainer,
    val_dls,
    writer,
    logger,
    step_count,
    update_panel: bool = False,
    streaming: bool = False,
)

Source from the content-addressed store, hash-verified

61
62
63def evaluate_on_val_dls(
64 trainer,
65 val_dls,
66 writer,
67 logger,
68 step_count,
69 update_panel: bool = False,
70 streaming: bool = False,
71):
72 with switch_sequence_parallel_mode():
73 torch.cuda.empty_cache()
74 trainer.eval()
75 verbose = gpc.is_rank_for_log()
76 data_cfg = gpc.config.data
77
78 for val_name, val_dl in val_dls.items():
79 if len(val_dl) == 0 and verbose and not streaming:
80 logger.info(f"Validation dataset: {val_name} is empty")
81 continue
82
83 val_metric = AccPerplex(
84 device=torch.cuda.current_device(),
85 tp_pg=gpc.get_group(ParallelMode.TENSOR),
86 dp_pg=gpc.get_group(ParallelMode.DATA),
87 )
88 val_sche_metric_hook = SchedulerMetricHook(metric=val_metric)
89
90 val_loss = 0
91 val_idx = -1
92 for val_idx, batch in tqdm(
93 enumerate(val_dl),
94 desc="Val.",
95 total=len(val_dl) if not streaming else None,
96 position=1,
97 disable=not verbose,
98 leave=False,
99 ):
100 with torch.inference_mode():
101 if gpc.is_using_pp():
102 total_val_bsz = len(batch[1])
103 assert total_val_bsz % data_cfg.micro_bsz == 0
104 num_microbatches = total_val_bsz // data_cfg.micro_bsz
105 tensor_shape = torch.Size(
106 [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
107 )
108
109 with switch_evaluation_pipeline_scheduler(
110 trainer=trainer,
111 num_microbatches=num_microbatches,
112 tensor_shape=tensor_shape,
113 metric_hook_list=[val_sche_metric_hook],
114 ):
115 _, _, loss = trainer.execute_schedule(
116 batch, forward_only=True, return_loss=True, return_output_label=False
117 )
118 else:
119 total_val_bsz = len(batch[1])
120 assert total_val_bsz % data_cfg.micro_bsz == 0

Callers 1

main (Function) · 0.90

Calls 13

get_metric (Method) · 0.95
AccPerplex (Class) · 0.90
SchedulerMetricHook (Class) · 0.90
is_rank_for_log (Method) · 0.80
get_group (Method) · 0.80
is_using_pp (Method) · 0.80
execute_schedule (Method) · 0.80
add_scalar (Method) · 0.80
eval (Method) · 0.45

Tested by

no test coverage detected