MCPcopy
hub / github.com/Audio-AGI/AudioSep / evaluate

Function evaluate

models/CLAP/training/lp_train.py:209–301  ·  view source on GitHub ↗
(model, data, epoch, args, tb_writer=None, extra_suffix="")

Source from the content-addressed store, hash-verified

207
208
209def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""):
210 metrics = {}
211 if not args.parallel_eval:
212 if not is_master(args):
213 return metrics
214 device = torch.device(args.device)
215 model.eval()
216
217 # CHANGE
218 # zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
219 # metrics.update(zero_shot_metrics)
220 if is_master(args):
221 print("Evaluating...")
222 metric_names = args.lp_metrics.split(",")
223 eval_tool = LPMetrics(metric_names=metric_names)
224
225 autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress
226 if "val" in data and (
227 args.val_frequency
228 and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)
229 ):
230 if args.parallel_eval:
231 dataloader, sampler = data["val"].dataloader, data["val"].sampler
232 if args.distributed and sampler is not None:
233 sampler.set_epoch(epoch)
234 samples_per_val = dataloader.num_samples
235 else:
236 dataloader = data["val"].dataloader
237 num_samples = 0
238 samples_per_val = dataloader.num_samples
239
240 eval_info = {"pred": [], "target": []}
241 with torch.no_grad():
242 for i, batch in enumerate(dataloader):
243 audio = batch # contains mel_spec, wavform, and longer list
244 class_label = batch["class_label"]
245
246 # audio = audio.to(device=device, non_blocking=True)
247 class_label = class_label.to(device=device, non_blocking=True)
248
249 with autocast():
250 pred = model(audio, device=device)
251 if args.parallel_eval:
252 pred, class_label = lp_gather_features(
253 pred, class_label, args.world_size, args.horovod
254 )
255 eval_info["pred"].append(pred)
256 eval_info["target"].append(class_label)
257
258 num_samples += class_label.shape[0]
259
260 if (i % 100) == 0: # and i != 0:
261 logging.info(
262 f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]"
263 )
264
265 if is_master(args):
266 eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu()

Callers 1

main (Function) · 0.90

Calls 6

evaluate_mertics (Method) · 0.95
LPMetrics (Class) · 0.90
lp_gather_features (Function) · 0.90
is_master (Function) · 0.85
append (Method) · 0.80
update (Method) · 0.45

Tested by

no test coverage detected