MCPcopy Index your code
hub / github.com/modelscope/FunASR / sampler

Method sampler

funasr/models/paraformer/model.py:408–455  ·  view source on GitHub ↗

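Sampler. Args: encoder_out: Encoder output tensor. encoder_out_lens: Encoder output lengths. ys_pad: Padded target token IDs (entries equal to ignore_id are padding). ys_pad_lens: Lengths of ys_pad. pre_acoustic_embeds: Per-token acoustic embeddings used as the first-pass decoder input.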

(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds)

Source from the content-addressed store, hash-verified

406 return loss_att, acc_att, cer_att, wer_att, loss_pre, pre_loss_att
407
def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
    """Mix acoustic embeddings with ground-truth token embeddings for the decoder input.

    A first decoder pass is run on ``pre_acoustic_embeds`` under ``torch.no_grad()``.
    For each row, the number of valid positions whose argmax prediction differs from
    ``ys_pad`` is scaled by ``self.sampling_ratio`` and truncated to an integer. That
    many positions are chosen uniformly at random (``torch.randperm``) among the first
    ``n`` positions, where ``n`` is the row's count of non-``ignore_id`` entries. At
    those positions the acoustic embedding is replaced by the target-token embedding.

    Args:
        encoder_out: Encoder output tensor, forwarded to ``self.decoder``.
        encoder_out_lens: Encoder output lengths, forwarded to ``self.decoder``.
        ys_pad: Padded target token ids. Entries equal to ``self.ignore_id`` are
            treated as padding.
        ys_pad_lens: Number of valid tokens in each row of ``ys_pad``.
        pre_acoustic_embeds: Per-token acoustic embeddings used as the first-pass
            decoder input. NOTE(review): presumably the predictor output; confirm
            against the caller.

    Returns:
        Tuple ``(semantic_embeds, decoder_out)``. ``semantic_embeds`` holds
        ``pre_acoustic_embeds`` at kept positions and target-token embeddings at the
        sampled positions. ``decoder_out`` is the first-pass decoder output, computed
        without gradient tracking. Both are multiplied by a mask that zeroes every
        time step at or beyond ``ys_pad_lens``.
    """
    pad_mask = make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())
    # True for time steps inside ys_pad_lens; the trailing axis broadcasts over features.
    tgt_mask = (~pad_mask[:, :, None]).to(ys_pad.device)

    # Padded slots become id 0 so the embedding lookup below stays in range.
    target_ids = ys_pad * tgt_mask[:, :, 0]
    if not self.share_embedding:
        target_embeds = self.decoder.embed(target_ids)
    else:
        # Rows of the decoder output layer's weight matrix double as the embedding table.
        target_embeds = self.decoder.output_layer.weight[target_ids]

    with torch.no_grad():
        first_pass = self.decoder(encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens)
        decoder_out, _ = first_pass[0], first_pass[1]
        predicted = decoder_out.argmax(-1)

        nonpad_positions = ys_pad.ne(self.ignore_id)
        lengths = nonpad_positions.sum(1)
        correct = (predicted.eq(ys_pad) & nonpad_positions).sum(1)
        # Per-row number of ground-truth tokens to reveal: the error count scaled by
        # sampling_ratio, truncated by .long().
        reveal_counts = ((lengths - correct).float() * self.sampling_ratio).long()

        keep = torch.ones_like(nonpad_positions)
        batch_size, _ = ys_pad.size()
        counts = reveal_counts.tolist()
        valid_lens = lengths.tolist()
        for row in range(batch_size):
            n_reveal = counts[row]
            if n_reveal <= 0:
                continue
            picks = torch.randperm(valid_lens[row])[:n_reveal].to(keep.device)
            keep[row].scatter_(0, picks, 0)

        # Padding positions never keep the acoustic embedding.
        keep = keep & nonpad_positions
        keep_3d = keep[:, :, None].to(pre_acoustic_embeds.device)

    # Kept positions carry the acoustic embedding; sampled ones carry the target embedding.
    acoustic_part = pre_acoustic_embeds.masked_fill(~keep_3d, 0)
    target_part = target_embeds.masked_fill(keep_3d, 0)
    semantic_embeds = acoustic_part + target_part
    return semantic_embeds * tgt_mask, decoder_out * tgt_mask
456
457 def _calc_ctc_loss(
458 self,

Callers 2

_calc_att_loss (Method) · 0.95
_calc_att_loss (Method) · 0.45

Calls 2

make_pad_mask (Function) · 0.90
argmax (Method) · 0.45

Tested by

no test coverage detected