yuancwang
init
b725c5a
raw
history blame
5.69 kB
from typing import List, Tuple
import torch
from modules.wenet_extractor.utils.common import log_add
class Sequence:
    """One hypothesis kept in the prefix beam.

    Attributes:
        hyp: token ids of the hypothesis, stored as Python ints.
        score: accumulated log-domain score of the hypothesis.
        cache: predictor state associated with this hypothesis.
    """

    __slots__ = ("hyp", "score", "cache")

    def __init__(
        self,
        hyp: List[int],
        score: float,
        cache: List[torch.Tensor],
    ):
        self.hyp = hyp
        self.score = score
        self.cache = cache

    def __repr__(self) -> str:
        # The cache (state tensors) is left out to keep the repr readable.
        return f"{type(self).__name__}(hyp={self.hyp!r}, score={self.score!r})"
class PrefixBeamSearch:
    """Prefix beam search over a transducer model fused with CTC posteriors.

    At every encoder frame the transducer (predictor + joint) distribution and
    the CTC distribution are mixed by a weighted sum in the probability
    domain. Hypotheses that end with the same token sequence are merged with
    a log-domain add.
    """

    def __init__(self, encoder, predictor, joint, ctc, blank):
        self.encoder = encoder
        self.predictor = predictor
        self.joint = joint
        self.ctc = ctc
        self.blank = blank

    def forward_decoder_one_step(
        self, encoder_x: torch.Tensor, pre_t: torch.Tensor, cache: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Run the predictor and joint network for a single decoding step.

        Args:
            encoder_x: encoder output for the current frame (the caller
                passes ``encoder_out[:, i, :].unsqueeze(1)``).
            pre_t: last token id of each hypothesis, one entry per beam.
            cache: batched predictor state for the beams.

        Returns:
            ``(logp, new_cache)``: the joint output after ``log_softmax`` over
            the last dimension (expected ``[beam, 1, 1, vocab]``) and the
            updated batched predictor state.
        """
        # All-zero padding mask, one step per beam (presumably "nothing is
        # padded" -- confirm against the predictor's forward_step contract).
        padding = torch.zeros(pre_t.size(0), 1, device=encoder_x.device)
        pre_t, new_cache = self.predictor.forward_step(
            pre_t.unsqueeze(-1), padding, cache
        )
        x = self.joint(encoder_x, pre_t)  # [beam, 1, 1, vocab]
        x = x.log_softmax(dim=-1)
        return x, new_cache

    def prefix_beam_search(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        decoding_chunk_size: int = -1,
        beam_size: int = 5,
        num_decoding_left_chunks: int = -1,
        simulate_streaming: bool = False,
        ctc_weight: float = 0.3,
        transducer_weight: float = 0.7,
    ):
        """Prefix beam search with transducer/CTC shallow fusion.

        Also see wenet.transducer.transducer.beam_search.

        Args:
            speech: input features; only batch size 1 is supported.
            speech_lengths: lengths matching ``speech``'s batch dimension.
            decoding_chunk_size: chunk size forwarded to the encoder; must
                not be 0.
            beam_size: number of candidates kept per hypothesis (first prune)
                and number of hypotheses kept per frame (second prune).
            num_decoding_left_chunks: forwarded to the encoder.
            simulate_streaming: accepted for interface compatibility; not
                used by this implementation.
            ctc_weight: weight of the CTC probabilities in the fusion.
            transducer_weight: weight of the transducer probabilities.

        Returns:
            ``(beams, encoder_out)``. ``beams`` is a list of ``Sequence``
            objects sorted by descending score, at most ``beam_size`` long.
            If the encoder output has no frames, ``beams`` is the single
            initial blank hypothesis.
        """
        assert speech.shape[0] == speech_lengths.shape[0]
        assert decoding_chunk_size != 0
        device = speech.device
        batch_size = speech.shape[0]
        assert batch_size == 1

        # 1. Encoder
        encoder_out, _ = self.encoder(
            speech, speech_lengths, decoding_chunk_size, num_decoding_left_chunks
        )  # (B, maxlen, encoder_dim)
        maxlen = encoder_out.size(1)
        ctc_probs = self.ctc.log_softmax(encoder_out).squeeze(0)

        # 2. Initial beam: a single blank-only hypothesis with a zero state.
        cache = self.predictor.init_state(1, method="zero", device=device)
        beam_init: List[Sequence] = [
            Sequence(hyp=[self.blank], score=0.0, cache=cache)
        ]

        # 3. Frame-synchronous (breadth-first) decoding.
        # NOTE: a frame emits at most one non-blank unit per hypothesis;
        # experiments reportedly show this restriction has little impact.
        for i in range(maxlen):
            # 3.1 Inputs: the predictor is fed each hypothesis's last token.
            input_hyp = [s.hyp[-1] for s in beam_init]
            input_hyp_tensor = torch.tensor(input_hyp, dtype=torch.int, device=device)
            # Batch the per-hypothesis predictor states.
            cache_batch = self.predictor.cache_to_batch([s.cache for s in beam_init])
            scores = torch.tensor([s.score for s in beam_init]).to(device)

            # 3.2 Forward the decoder.
            logp, new_cache = self.forward_decoder_one_step(
                encoder_out[:, i, :].unsqueeze(1),
                input_hyp_tensor,
                cache_batch,
            )  # logp: (N, 1, 1, vocab_size)
            logp = logp.squeeze(1).squeeze(1)  # logp: (N, vocab_size)
            new_cache = self.predictor.batch_to_cache(new_cache)

            # 3.3 Shallow fusion of transducer and CTC scores (an LM score
            # could also be added here), mixed in the probability domain.
            logp = torch.log(
                torch.add(
                    transducer_weight * torch.exp(logp),
                    ctc_weight * torch.exp(ctc_probs[i].unsqueeze(0)),
                )
            )

            # 3.4 First prune: top-k candidate tokens per hypothesis.
            top_k_logp, top_k_index = logp.topk(beam_size)  # (N, beam_size)
            scores = torch.add(scores.unsqueeze(1), top_k_logp)

            # 3.5 Expand every hypothesis with its candidates (N * beam_size).
            beam_A = self._expand_beams(beam_init, top_k_index, scores, new_cache)

            # 3.6 Prefix fusion: merge hypotheses with identical tokens.
            fusion_A = self._fuse_prefixes(beam_A)

            # 4. Second prune: keep the best beam_size hypotheses.
            fusion_A.sort(key=lambda x: x.score, reverse=True)
            beam_init = fusion_A[:beam_size]
        return beam_init, encoder_out

    def _expand_beams(
        self,
        beams: List[Sequence],
        top_k_index: torch.Tensor,
        scores: torch.Tensor,
        new_cache: List[List[torch.Tensor]],
    ) -> List[Sequence]:
        """Extend each hypothesis in ``beams`` with each of its top-k tokens."""
        # Copy to host once instead of one device sync per element.
        index_rows = top_k_index.tolist()
        score_rows = scores.tolist()
        expanded: List[Sequence] = []
        for j, base_seq in enumerate(beams):
            for token, score in zip(index_rows[j], score_rows[j]):
                if token == self.blank:
                    # Blank: keep the prefix and predictor state; only the
                    # score changes.
                    expanded.append(
                        Sequence(
                            hyp=base_seq.hyp.copy(),
                            score=score,
                            cache=base_seq.cache,
                        )
                    )
                else:
                    # Non-blank: extend the prefix and adopt the state the
                    # predictor produced for this hypothesis.
                    expanded.append(
                        Sequence(
                            hyp=base_seq.hyp + [token],
                            score=score,
                            cache=new_cache[j],
                        )
                    )
        return expanded

    @staticmethod
    def _fuse_prefixes(beams: List[Sequence]) -> List[Sequence]:
        """Merge hypotheses with identical token sequences by log-adding scores.

        The first occurrence of each token sequence is kept, with its cache,
        and later duplicates are folded into it. First-occurrence order is
        preserved, since dicts keep insertion order. The kept objects' scores
        are updated in place.
        """
        merged = {}  # tuple(hyp) -> first Sequence seen with that hyp
        for seq in beams:
            key = tuple(seq.hyp)
            kept = merged.get(key)
            if kept is None:
                merged[key] = seq
            else:
                kept.score = log_add([kept.score, seq.score])
        return list(merged.values())