MWER loss implementation

1. Paraformer negative sampling strategy

An implementation of the Minimum Word Error Rate (MWER) training loss, based on the negative sampling strategy from <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>.

https://gist.github.com/TeaPoly/234429e6c2d74d10fcb4987bc541d528

import torch


def create_sampling_mask(log_softmax, n):
    """
    Generate the sampling mask.
    # Ref: <Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition>
    #       https://arxiv.org/abs/2206.08317
    Args:
        log_softmax: log softmax inputs, float32 (batch, maxlen_out, vocab_size)
        n: number of candidate paths, int32
    Return:
        sampling_mask: the sampling mask (nbest, batch, maxlen_out, vocab_size)
    """
    b, s, v = log_softmax.size()

    # Generate a random 0/1 mask for every candidate path
    nbest_random_mask = torch.randint(
        0, 2, (n, b, s, v), device=log_softmax.device
    )

    # Greedy-search decoding for the best path, (batch, maxlen_out)
    top1_score_indices = log_softmax.argmax(dim=-1)

    # Generate a one-hot mask marking the top-1 token at each position
    top1_score_indices_mask = torch.zeros(
        (b, s, v), dtype=torch.int, device=log_softmax.device
    )
    top1_score_indices_mask.scatter_(-1, top1_score_indices.unsqueeze(-1), 1)

    # Generate the sampling mask by applying the random mask to the top-1 tokens
    sampling_mask = nbest_random_mask * top1_score_indices_mask.unsqueeze(0)

    return sampling_mask
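
For context, here is a minimal sketch (ours, not from the gist) of how such a mask could feed the MWER objective: positions flagged by the mask are re-sampled from the model distribution to form n candidate paths, each path is scored by its summed token log-probabilities, and the loss weights each path's relative errors by its renormalized probability. The name mwer_loss_sketch is hypothetical, and a token-level mismatch count stands in for true word errors to keep the sketch self-contained.

def mwer_loss_sketch(log_softmax, sampling_mask, targets):
    """
    Args:
        log_softmax: (batch, maxlen_out, vocab_size) log-probabilities
        sampling_mask: (nbest, batch, maxlen_out, vocab_size), see above
        targets: (batch, maxlen_out) reference token ids
    Return:
        scalar MWER loss
    """
    n, b, s, v = sampling_mask.size()

    # Greedy path, (batch, maxlen_out)
    top1 = log_softmax.argmax(dim=-1)

    # Draw replacement tokens everywhere, then keep them only at the
    # positions the mask flagged (negative sampling around the top-1 path)
    sampled = torch.multinomial(
        log_softmax.exp().reshape(b * s, v), n, replacement=True
    )
    sampled = sampled.permute(1, 0).reshape(n, b, s)
    flagged = sampling_mask.sum(dim=-1) > 0  # (nbest, batch, maxlen_out)
    paths = torch.where(flagged, sampled, top1.unsqueeze(0))

    # Score each candidate path by its summed token log-probabilities
    path_logp = torch.gather(
        log_softmax.unsqueeze(0).expand(n, b, s, v), -1, paths.unsqueeze(-1)
    ).squeeze(-1).sum(dim=-1)  # (nbest, batch)

    # Renormalize over the candidates; gradients flow through this softmax
    path_prob = torch.softmax(path_logp, dim=0)

    # Token-mismatch count as a stand-in for word errors; subtract the
    # mean over candidates as a baseline so better-than-average paths
    # receive negative weight
    errors = (paths != targets.unsqueeze(0)).float().sum(dim=-1)
    rel_errors = errors - errors.mean(dim=0, keepdim=True)

    return (path_prob * rel_errors).sum(dim=0).mean()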

2. CTC decoder MWER loss:

https://github.com/Mddct/losses/blob/main/py/mwer.py

Key point: compute the candidate paths with CTC prefix beam search:

self.ctc_prefix_beam_decoder = CTCDecoder(beam_width=beam_width, top_paths=beam_width)
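
The MWER loss itself is framework-independent once the beam search returns the top-N hypotheses and their scores. Below is a minimal PyTorch sketch, assuming nbest_logp (summed log-probability per hypothesis) and nbest_errors (edit distance of each hypothesis against the reference) have already been computed from the decoder output; both names are ours, not from the linked repo.

import torch


def nbest_mwer_loss(nbest_logp, nbest_errors):
    """
    Args:
        nbest_logp: (batch, nbest) total log-probability of each hypothesis
                    returned by the CTC prefix beam search
        nbest_errors: (batch, nbest) word errors of each hypothesis against
                      the reference transcript
    Return:
        scalar MWER loss
    """
    # Renormalize the probability mass over the N-best list
    probs = torch.softmax(nbest_logp, dim=-1)

    # Subtract the mean error over the list as a baseline: hypotheses
    # better than average get negative weight, pulling probability
    # toward them
    rel_errors = nbest_errors - nbest_errors.mean(dim=-1, keepdim=True)

    return (probs * rel_errors).sum(dim=-1).mean()

As in the sampling-based variant above, subtracting the mean error is the standard variance-reduction baseline for this objective.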

3. Other: WeNet

https://github.com/wenet-e2e/wenet/tree/main
