add sensevoice submodule
This commit is contained in:
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "SenseVoice"]
|
||||||
|
path = SenseVoice
|
||||||
|
url = https://github.com/FunAudioLLM/SenseVoice.git
|
||||||
1
SenseVoice
Submodule
1
SenseVoice
Submodule
Submodule SenseVoice added at 3ecc6f6a8f
12
json/602.md
12
json/602.md
@@ -1,12 +0,0 @@
|
|||||||

|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
Unable to convert audio to text.
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
Unable to convert audio to text.
|
|
||||||
|
|
||||||
polyglot
|
|
||||||
|
|
||||||
929
model.py
929
model.py
@@ -1,929 +0,0 @@
|
|||||||
import time
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
|
||||||
from funasr.metrics.compute_acc import th_accuracy
|
|
||||||
from funasr.models.ctc.ctc import CTC
|
|
||||||
from funasr.register import tables
|
|
||||||
from funasr.train_utils.device_funcs import force_gatherable
|
|
||||||
from funasr.utils.datadir_writer import DatadirWriter
|
|
||||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from utils.ctc_alignment import ctc_forced_align
|
|
||||||
|
|
||||||
|
|
||||||
class SinusoidalPositionEncoder(torch.nn.Module):
|
|
||||||
""" """
|
|
||||||
|
|
||||||
def __int__(self, d_model=80, dropout_rate=0.1):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
|
|
||||||
):
|
|
||||||
batch_size = positions.size(0)
|
|
||||||
positions = positions.type(dtype)
|
|
||||||
device = positions.device
|
|
||||||
log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
|
|
||||||
depth / 2 - 1
|
|
||||||
)
|
|
||||||
inv_timescales = torch.exp(
|
|
||||||
torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
|
|
||||||
)
|
|
||||||
inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
|
|
||||||
scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
|
|
||||||
inv_timescales, [1, 1, -1]
|
|
||||||
)
|
|
||||||
encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
|
|
||||||
return encoding.type(dtype)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
batch_size, timesteps, input_dim = x.size()
|
|
||||||
positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
|
|
||||||
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
|
|
||||||
|
|
||||||
return x + position_encoding
|
|
||||||
|
|
||||||
|
|
||||||
class PositionwiseFeedForward(torch.nn.Module):
|
|
||||||
"""Positionwise feed forward layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
idim (int): Input dimenstion.
|
|
||||||
hidden_units (int): The number of hidden units.
|
|
||||||
dropout_rate (float): Dropout rate.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
|
|
||||||
"""Construct an PositionwiseFeedForward object."""
|
|
||||||
super(PositionwiseFeedForward, self).__init__()
|
|
||||||
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
|
||||||
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
|
||||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
|
||||||
self.activation = activation
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""Forward function."""
|
|
||||||
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadedAttentionSANM(nn.Module):
|
|
||||||
"""Multi-Head Attention layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
n_head (int): The number of heads.
|
|
||||||
n_feat (int): The number of features.
|
|
||||||
dropout_rate (float): Dropout rate.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
n_head,
|
|
||||||
in_feat,
|
|
||||||
n_feat,
|
|
||||||
dropout_rate,
|
|
||||||
kernel_size,
|
|
||||||
sanm_shfit=0,
|
|
||||||
lora_list=None,
|
|
||||||
lora_rank=8,
|
|
||||||
lora_alpha=16,
|
|
||||||
lora_dropout=0.1,
|
|
||||||
):
|
|
||||||
"""Construct an MultiHeadedAttention object."""
|
|
||||||
super().__init__()
|
|
||||||
assert n_feat % n_head == 0
|
|
||||||
# We assume d_v always equals d_k
|
|
||||||
self.d_k = n_feat // n_head
|
|
||||||
self.h = n_head
|
|
||||||
# self.linear_q = nn.Linear(n_feat, n_feat)
|
|
||||||
# self.linear_k = nn.Linear(n_feat, n_feat)
|
|
||||||
# self.linear_v = nn.Linear(n_feat, n_feat)
|
|
||||||
|
|
||||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
|
||||||
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
|
|
||||||
self.attn = None
|
|
||||||
self.dropout = nn.Dropout(p=dropout_rate)
|
|
||||||
|
|
||||||
self.fsmn_block = nn.Conv1d(
|
|
||||||
n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
|
|
||||||
)
|
|
||||||
# padding
|
|
||||||
left_padding = (kernel_size - 1) // 2
|
|
||||||
if sanm_shfit > 0:
|
|
||||||
left_padding = left_padding + sanm_shfit
|
|
||||||
right_padding = kernel_size - 1 - left_padding
|
|
||||||
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
|
|
||||||
|
|
||||||
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
|
|
||||||
b, t, d = inputs.size()
|
|
||||||
if mask is not None:
|
|
||||||
mask = torch.reshape(mask, (b, -1, 1))
|
|
||||||
if mask_shfit_chunk is not None:
|
|
||||||
mask = mask * mask_shfit_chunk
|
|
||||||
inputs = inputs * mask
|
|
||||||
|
|
||||||
x = inputs.transpose(1, 2)
|
|
||||||
x = self.pad_fn(x)
|
|
||||||
x = self.fsmn_block(x)
|
|
||||||
x = x.transpose(1, 2)
|
|
||||||
x += inputs
|
|
||||||
x = self.dropout(x)
|
|
||||||
if mask is not None:
|
|
||||||
x = x * mask
|
|
||||||
return x
|
|
||||||
|
|
||||||
def forward_qkv(self, x):
|
|
||||||
"""Transform query, key and value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
|
||||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
|
||||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
|
|
||||||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
|
|
||||||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
|
|
||||||
|
|
||||||
"""
|
|
||||||
b, t, d = x.size()
|
|
||||||
q_k_v = self.linear_q_k_v(x)
|
|
||||||
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
|
|
||||||
q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
|
|
||||||
1, 2
|
|
||||||
) # (batch, head, time1, d_k)
|
|
||||||
k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
|
|
||||||
1, 2
|
|
||||||
) # (batch, head, time2, d_k)
|
|
||||||
v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
|
|
||||||
1, 2
|
|
||||||
) # (batch, head, time2, d_k)
|
|
||||||
|
|
||||||
return q_h, k_h, v_h, v
|
|
||||||
|
|
||||||
def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
|
|
||||||
"""Compute attention context vector.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
|
|
||||||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
|
|
||||||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Transformed value (#batch, time1, d_model)
|
|
||||||
weighted by the attention score (#batch, time1, time2).
|
|
||||||
|
|
||||||
"""
|
|
||||||
n_batch = value.size(0)
|
|
||||||
if mask is not None:
|
|
||||||
if mask_att_chunk_encoder is not None:
|
|
||||||
mask = mask * mask_att_chunk_encoder
|
|
||||||
|
|
||||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
|
||||||
|
|
||||||
min_value = -float(
|
|
||||||
"inf"
|
|
||||||
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
|
||||||
scores = scores.masked_fill(mask, min_value)
|
|
||||||
attn = torch.softmax(scores, dim=-1).masked_fill(
|
|
||||||
mask, 0.0
|
|
||||||
) # (batch, head, time1, time2)
|
|
||||||
else:
|
|
||||||
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
|
||||||
|
|
||||||
p_attn = self.dropout(attn)
|
|
||||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
|
||||||
x = (
|
|
||||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
|
||||||
) # (batch, time1, d_model)
|
|
||||||
|
|
||||||
return self.linear_out(x) # (batch, time1, d_model)
|
|
||||||
|
|
||||||
def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
|
||||||
"""Compute scaled dot product attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
|
||||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
|
||||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
|
||||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
|
||||||
(#batch, time1, time2).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
|
||||||
|
|
||||||
"""
|
|
||||||
q_h, k_h, v_h, v = self.forward_qkv(x)
|
|
||||||
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
|
|
||||||
q_h = q_h * self.d_k ** (-0.5)
|
|
||||||
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
|
|
||||||
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
|
|
||||||
return att_outs + fsmn_memory
|
|
||||||
|
|
||||||
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
|
|
||||||
"""Compute scaled dot product attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
|
||||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
|
||||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
|
||||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
|
||||||
(#batch, time1, time2).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
|
||||||
|
|
||||||
"""
|
|
||||||
q_h, k_h, v_h, v = self.forward_qkv(x)
|
|
||||||
if chunk_size is not None and look_back > 0 or look_back == -1:
|
|
||||||
if cache is not None:
|
|
||||||
k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
|
|
||||||
v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
|
|
||||||
k_h = torch.cat((cache["k"], k_h), dim=2)
|
|
||||||
v_h = torch.cat((cache["v"], v_h), dim=2)
|
|
||||||
|
|
||||||
cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
|
|
||||||
cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
|
|
||||||
if look_back != -1:
|
|
||||||
cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]):, :]
|
|
||||||
cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]):, :]
|
|
||||||
else:
|
|
||||||
cache_tmp = {
|
|
||||||
"k": k_h[:, :, : -(chunk_size[2]), :],
|
|
||||||
"v": v_h[:, :, : -(chunk_size[2]), :],
|
|
||||||
}
|
|
||||||
cache = cache_tmp
|
|
||||||
fsmn_memory = self.forward_fsmn(v, None)
|
|
||||||
q_h = q_h * self.d_k ** (-0.5)
|
|
||||||
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
|
|
||||||
att_outs = self.forward_attention(v_h, scores, None)
|
|
||||||
return att_outs + fsmn_memory, cache
|
|
||||||
|
|
||||||
|
|
||||||
class LayerNorm(nn.LayerNorm):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, input):
|
|
||||||
output = F.layer_norm(
|
|
||||||
input.float(),
|
|
||||||
self.normalized_shape,
|
|
||||||
self.weight.float() if self.weight is not None else None,
|
|
||||||
self.bias.float() if self.bias is not None else None,
|
|
||||||
self.eps,
|
|
||||||
)
|
|
||||||
return output.type_as(input)
|
|
||||||
|
|
||||||
|
|
||||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
|
|
||||||
if maxlen is None:
|
|
||||||
maxlen = lengths.max()
|
|
||||||
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
|
|
||||||
matrix = torch.unsqueeze(lengths, dim=-1)
|
|
||||||
mask = row_vector < matrix
|
|
||||||
mask = mask.detach()
|
|
||||||
|
|
||||||
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderLayerSANM(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_size,
|
|
||||||
size,
|
|
||||||
self_attn,
|
|
||||||
feed_forward,
|
|
||||||
dropout_rate,
|
|
||||||
normalize_before=True,
|
|
||||||
concat_after=False,
|
|
||||||
stochastic_depth_rate=0.0,
|
|
||||||
):
|
|
||||||
"""Construct an EncoderLayer object."""
|
|
||||||
super(EncoderLayerSANM, self).__init__()
|
|
||||||
self.self_attn = self_attn
|
|
||||||
self.feed_forward = feed_forward
|
|
||||||
self.norm1 = LayerNorm(in_size)
|
|
||||||
self.norm2 = LayerNorm(size)
|
|
||||||
self.dropout = nn.Dropout(dropout_rate)
|
|
||||||
self.in_size = in_size
|
|
||||||
self.size = size
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
self.concat_after = concat_after
|
|
||||||
if self.concat_after:
|
|
||||||
self.concat_linear = nn.Linear(size + size, size)
|
|
||||||
self.stochastic_depth_rate = stochastic_depth_rate
|
|
||||||
self.dropout_rate = dropout_rate
|
|
||||||
|
|
||||||
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
|
||||||
"""Compute encoded features.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
|
||||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
|
||||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Output tensor (#batch, time, size).
|
|
||||||
torch.Tensor: Mask tensor (#batch, time).
|
|
||||||
|
|
||||||
"""
|
|
||||||
skip_layer = False
|
|
||||||
# with stochastic depth, residual connection `x + f(x)` becomes
|
|
||||||
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
|
||||||
stoch_layer_coeff = 1.0
|
|
||||||
if self.training and self.stochastic_depth_rate > 0:
|
|
||||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
|
||||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
|
||||||
|
|
||||||
if skip_layer:
|
|
||||||
if cache is not None:
|
|
||||||
x = torch.cat([cache, x], dim=1)
|
|
||||||
return x, mask
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.norm1(x)
|
|
||||||
|
|
||||||
if self.concat_after:
|
|
||||||
x_concat = torch.cat(
|
|
||||||
(
|
|
||||||
x,
|
|
||||||
self.self_attn(
|
|
||||||
x,
|
|
||||||
mask,
|
|
||||||
mask_shfit_chunk=mask_shfit_chunk,
|
|
||||||
mask_att_chunk_encoder=mask_att_chunk_encoder,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
if self.in_size == self.size:
|
|
||||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
|
||||||
else:
|
|
||||||
x = stoch_layer_coeff * self.concat_linear(x_concat)
|
|
||||||
else:
|
|
||||||
if self.in_size == self.size:
|
|
||||||
x = residual + stoch_layer_coeff * self.dropout(
|
|
||||||
self.self_attn(
|
|
||||||
x,
|
|
||||||
mask,
|
|
||||||
mask_shfit_chunk=mask_shfit_chunk,
|
|
||||||
mask_att_chunk_encoder=mask_att_chunk_encoder,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
x = stoch_layer_coeff * self.dropout(
|
|
||||||
self.self_attn(
|
|
||||||
x,
|
|
||||||
mask,
|
|
||||||
mask_shfit_chunk=mask_shfit_chunk,
|
|
||||||
mask_att_chunk_encoder=mask_att_chunk_encoder,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.norm1(x)
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.norm2(x)
|
|
||||||
|
|
||||||
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
|
|
||||||
|
|
||||||
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
|
|
||||||
"""Compute encoded features.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
|
||||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
|
||||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: Output tensor (#batch, time, size).
|
|
||||||
torch.Tensor: Mask tensor (#batch, time).
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.norm1(x)
|
|
||||||
|
|
||||||
if self.in_size == self.size:
|
|
||||||
attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
|
|
||||||
x = residual + attn
|
|
||||||
else:
|
|
||||||
x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
|
|
||||||
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.norm1(x)
|
|
||||||
|
|
||||||
residual = x
|
|
||||||
if self.normalize_before:
|
|
||||||
x = self.norm2(x)
|
|
||||||
x = residual + self.feed_forward(x)
|
|
||||||
if not self.normalize_before:
|
|
||||||
x = self.norm2(x)
|
|
||||||
|
|
||||||
return x, cache
|
|
||||||
|
|
||||||
|
|
||||||
@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
|
|
||||||
class SenseVoiceEncoderSmall(nn.Module):
|
|
||||||
"""
|
|
||||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
|
||||||
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
|
|
||||||
https://arxiv.org/abs/2006.01713
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
input_size: int,
|
|
||||||
output_size: int = 256,
|
|
||||||
attention_heads: int = 4,
|
|
||||||
linear_units: int = 2048,
|
|
||||||
num_blocks: int = 6,
|
|
||||||
tp_blocks: int = 0,
|
|
||||||
dropout_rate: float = 0.1,
|
|
||||||
positional_dropout_rate: float = 0.1,
|
|
||||||
attention_dropout_rate: float = 0.0,
|
|
||||||
stochastic_depth_rate: float = 0.0,
|
|
||||||
input_layer: Optional[str] = "conv2d",
|
|
||||||
pos_enc_class=SinusoidalPositionEncoder,
|
|
||||||
normalize_before: bool = True,
|
|
||||||
concat_after: bool = False,
|
|
||||||
positionwise_layer_type: str = "linear",
|
|
||||||
positionwise_conv_kernel_size: int = 1,
|
|
||||||
padding_idx: int = -1,
|
|
||||||
kernel_size: int = 11,
|
|
||||||
sanm_shfit: int = 0,
|
|
||||||
selfattention_layer_type: str = "sanm",
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._output_size = output_size
|
|
||||||
|
|
||||||
self.embed = SinusoidalPositionEncoder()
|
|
||||||
|
|
||||||
self.normalize_before = normalize_before
|
|
||||||
|
|
||||||
positionwise_layer = PositionwiseFeedForward
|
|
||||||
positionwise_layer_args = (
|
|
||||||
output_size,
|
|
||||||
linear_units,
|
|
||||||
dropout_rate,
|
|
||||||
)
|
|
||||||
|
|
||||||
encoder_selfattn_layer = MultiHeadedAttentionSANM
|
|
||||||
encoder_selfattn_layer_args0 = (
|
|
||||||
attention_heads,
|
|
||||||
input_size,
|
|
||||||
output_size,
|
|
||||||
attention_dropout_rate,
|
|
||||||
kernel_size,
|
|
||||||
sanm_shfit,
|
|
||||||
)
|
|
||||||
encoder_selfattn_layer_args = (
|
|
||||||
attention_heads,
|
|
||||||
output_size,
|
|
||||||
output_size,
|
|
||||||
attention_dropout_rate,
|
|
||||||
kernel_size,
|
|
||||||
sanm_shfit,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.encoders0 = nn.ModuleList(
|
|
||||||
[
|
|
||||||
EncoderLayerSANM(
|
|
||||||
input_size,
|
|
||||||
output_size,
|
|
||||||
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
|
||||||
positionwise_layer(*positionwise_layer_args),
|
|
||||||
dropout_rate,
|
|
||||||
)
|
|
||||||
for i in range(1)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.encoders = nn.ModuleList(
|
|
||||||
[
|
|
||||||
EncoderLayerSANM(
|
|
||||||
output_size,
|
|
||||||
output_size,
|
|
||||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
|
||||||
positionwise_layer(*positionwise_layer_args),
|
|
||||||
dropout_rate,
|
|
||||||
)
|
|
||||||
for i in range(num_blocks - 1)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.tp_encoders = nn.ModuleList(
|
|
||||||
[
|
|
||||||
EncoderLayerSANM(
|
|
||||||
output_size,
|
|
||||||
output_size,
|
|
||||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
|
||||||
positionwise_layer(*positionwise_layer_args),
|
|
||||||
dropout_rate,
|
|
||||||
)
|
|
||||||
for i in range(tp_blocks)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
self.after_norm = LayerNorm(output_size)
|
|
||||||
|
|
||||||
self.tp_norm = LayerNorm(output_size)
|
|
||||||
|
|
||||||
def output_size(self) -> int:
|
|
||||||
return self._output_size
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
xs_pad: torch.Tensor,
|
|
||||||
ilens: torch.Tensor,
|
|
||||||
):
|
|
||||||
"""Embed positions in tensor."""
|
|
||||||
masks = sequence_mask(ilens, device=ilens.device)[:, None, :]
|
|
||||||
|
|
||||||
xs_pad *= self.output_size() ** 0.5
|
|
||||||
|
|
||||||
xs_pad = self.embed(xs_pad)
|
|
||||||
|
|
||||||
# forward encoder1
|
|
||||||
for layer_idx, encoder_layer in enumerate(self.encoders0):
|
|
||||||
encoder_outs = encoder_layer(xs_pad, masks)
|
|
||||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
|
||||||
|
|
||||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
|
||||||
encoder_outs = encoder_layer(xs_pad, masks)
|
|
||||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
|
||||||
|
|
||||||
xs_pad = self.after_norm(xs_pad)
|
|
||||||
|
|
||||||
# forward encoder2
|
|
||||||
olens = masks.squeeze(1).sum(1).int()
|
|
||||||
|
|
||||||
for layer_idx, encoder_layer in enumerate(self.tp_encoders):
|
|
||||||
encoder_outs = encoder_layer(xs_pad, masks)
|
|
||||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
|
||||||
|
|
||||||
xs_pad = self.tp_norm(xs_pad)
|
|
||||||
return xs_pad, olens
|
|
||||||
|
|
||||||
|
|
||||||
@tables.register("model_classes", "SenseVoiceSmall")
|
|
||||||
class SenseVoiceSmall(nn.Module):
|
|
||||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
specaug: str = None,
|
|
||||||
specaug_conf: dict = None,
|
|
||||||
normalize: str = None,
|
|
||||||
normalize_conf: dict = None,
|
|
||||||
encoder: str = None,
|
|
||||||
encoder_conf: dict = None,
|
|
||||||
ctc_conf: dict = None,
|
|
||||||
input_size: int = 80,
|
|
||||||
vocab_size: int = -1,
|
|
||||||
ignore_id: int = -1,
|
|
||||||
blank_id: int = 0,
|
|
||||||
sos: int = 1,
|
|
||||||
eos: int = 2,
|
|
||||||
length_normalized_loss: bool = False,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
if specaug is not None:
|
|
||||||
specaug_class = tables.specaug_classes.get(specaug)
|
|
||||||
specaug = specaug_class(**specaug_conf)
|
|
||||||
if normalize is not None:
|
|
||||||
normalize_class = tables.normalize_classes.get(normalize)
|
|
||||||
normalize = normalize_class(**normalize_conf)
|
|
||||||
encoder_class = tables.encoder_classes.get(encoder)
|
|
||||||
encoder = encoder_class(input_size=input_size, **encoder_conf)
|
|
||||||
encoder_output_size = encoder.output_size()
|
|
||||||
|
|
||||||
if ctc_conf is None:
|
|
||||||
ctc_conf = {}
|
|
||||||
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
|
|
||||||
|
|
||||||
self.blank_id = blank_id
|
|
||||||
self.sos = sos if sos is not None else vocab_size - 1
|
|
||||||
self.eos = eos if eos is not None else vocab_size - 1
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
self.ignore_id = ignore_id
|
|
||||||
self.specaug = specaug
|
|
||||||
self.normalize = normalize
|
|
||||||
self.encoder = encoder
|
|
||||||
self.error_calculator = None
|
|
||||||
|
|
||||||
self.ctc = ctc
|
|
||||||
|
|
||||||
self.length_normalized_loss = length_normalized_loss
|
|
||||||
self.encoder_output_size = encoder_output_size
|
|
||||||
|
|
||||||
self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
|
|
||||||
self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
|
|
||||||
self.textnorm_dict = {"withitn": 14, "woitn": 15}
|
|
||||||
self.textnorm_int_dict = {25016: 14, 25017: 15}
|
|
||||||
self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size)
|
|
||||||
self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
|
|
||||||
|
|
||||||
self.criterion_att = LabelSmoothingLoss(
|
|
||||||
size=self.vocab_size,
|
|
||||||
padding_idx=self.ignore_id,
|
|
||||||
smoothing=kwargs.get("lsm_weight", 0.0),
|
|
||||||
normalize_length=self.length_normalized_loss,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_pretrained(model: str = None, **kwargs):
|
|
||||||
from funasr import AutoModel
|
|
||||||
model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
|
|
||||||
|
|
||||||
return model, kwargs
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
speech: torch.Tensor,
|
|
||||||
speech_lengths: torch.Tensor,
|
|
||||||
text: torch.Tensor,
|
|
||||||
text_lengths: torch.Tensor,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""Encoder + Decoder + Calc loss
|
|
||||||
Args:
|
|
||||||
speech: (Batch, Length, ...)
|
|
||||||
speech_lengths: (Batch, )
|
|
||||||
text: (Batch, Length)
|
|
||||||
text_lengths: (Batch,)
|
|
||||||
"""
|
|
||||||
# import pdb;
|
|
||||||
# pdb.set_trace()
|
|
||||||
if len(text_lengths.size()) > 1:
|
|
||||||
text_lengths = text_lengths[:, 0]
|
|
||||||
if len(speech_lengths.size()) > 1:
|
|
||||||
speech_lengths = speech_lengths[:, 0]
|
|
||||||
|
|
||||||
batch_size = speech.shape[0]
|
|
||||||
|
|
||||||
# 1. Encoder
|
|
||||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)
|
|
||||||
|
|
||||||
loss_ctc, cer_ctc = None, None
|
|
||||||
loss_rich, acc_rich = None, None
|
|
||||||
stats = dict()
|
|
||||||
|
|
||||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
|
||||||
encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
|
|
||||||
)
|
|
||||||
|
|
||||||
loss_rich, acc_rich = self._calc_rich_ce_loss(
|
|
||||||
encoder_out[:, :4, :], text[:, :4]
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = loss_ctc + loss_rich
|
|
||||||
# Collect total loss stats
|
|
||||||
stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
|
|
||||||
stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
|
|
||||||
stats["loss"] = torch.clone(loss.detach()) if loss is not None else None
|
|
||||||
stats["acc_rich"] = acc_rich
|
|
||||||
|
|
||||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
|
||||||
if self.length_normalized_loss:
|
|
||||||
batch_size = int((text_lengths + 1).sum())
|
|
||||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
|
||||||
return loss, stats, weight
|
|
||||||
|
|
||||||
def encode(
|
|
||||||
self,
|
|
||||||
speech: torch.Tensor,
|
|
||||||
speech_lengths: torch.Tensor,
|
|
||||||
text: torch.Tensor,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
|
||||||
Args:
|
|
||||||
speech: (Batch, Length, ...)
|
|
||||||
speech_lengths: (Batch, )
|
|
||||||
ind: int
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Data augmentation
|
|
||||||
if self.specaug is not None and self.training:
|
|
||||||
speech, speech_lengths = self.specaug(speech, speech_lengths)
|
|
||||||
|
|
||||||
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
|
||||||
if self.normalize is not None:
|
|
||||||
speech, speech_lengths = self.normalize(speech, speech_lengths)
|
|
||||||
|
|
||||||
lids = torch.LongTensor(
|
|
||||||
[[self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0] for lid in
|
|
||||||
text[:, 0]]).to(speech.device)
|
|
||||||
language_query = self.embed(lids)
|
|
||||||
|
|
||||||
styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device)
|
|
||||||
style_query = self.embed(styles)
|
|
||||||
speech = torch.cat((style_query, speech), dim=1)
|
|
||||||
speech_lengths += 1
|
|
||||||
|
|
||||||
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
|
|
||||||
input_query = torch.cat((language_query, event_emo_query), dim=1)
|
|
||||||
speech = torch.cat((input_query, speech), dim=1)
|
|
||||||
speech_lengths += 3
|
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
|
|
||||||
|
|
||||||
return encoder_out, encoder_out_lens
|
|
||||||
|
|
||||||
def _calc_ctc_loss(
|
|
||||||
self,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
encoder_out_lens: torch.Tensor,
|
|
||||||
ys_pad: torch.Tensor,
|
|
||||||
ys_pad_lens: torch.Tensor,
|
|
||||||
):
|
|
||||||
# Calc CTC loss
|
|
||||||
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
|
||||||
|
|
||||||
# Calc CER using CTC
|
|
||||||
cer_ctc = None
|
|
||||||
if not self.training and self.error_calculator is not None:
|
|
||||||
ys_hat = self.ctc.argmax(encoder_out).data
|
|
||||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
|
||||||
return loss_ctc, cer_ctc
|
|
||||||
|
|
||||||
def _calc_rich_ce_loss(
|
|
||||||
self,
|
|
||||||
encoder_out: torch.Tensor,
|
|
||||||
ys_pad: torch.Tensor,
|
|
||||||
):
|
|
||||||
decoder_out = self.ctc.ctc_lo(encoder_out)
|
|
||||||
# 2. Compute attention loss
|
|
||||||
loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
|
|
||||||
acc_rich = th_accuracy(
|
|
||||||
decoder_out.view(-1, self.vocab_size),
|
|
||||||
ys_pad.contiguous(),
|
|
||||||
ignore_label=self.ignore_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss_rich, acc_rich
|
|
||||||
|
|
||||||
def inference(
|
|
||||||
self,
|
|
||||||
data_in,
|
|
||||||
data_lengths=None,
|
|
||||||
key: list = ["wav_file_tmp_name"],
|
|
||||||
tokenizer=None,
|
|
||||||
frontend=None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
|
|
||||||
meta_data = {}
|
|
||||||
if (
|
|
||||||
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
|
|
||||||
): # fbank
|
|
||||||
speech, speech_lengths = data_in, data_lengths
|
|
||||||
if len(speech.shape) < 3:
|
|
||||||
speech = speech[None, :, :]
|
|
||||||
if speech_lengths is None:
|
|
||||||
speech_lengths = speech.shape[1]
|
|
||||||
else:
|
|
||||||
# extract fbank feats
|
|
||||||
time1 = time.perf_counter()
|
|
||||||
audio_sample_list = load_audio_text_image_video(
|
|
||||||
data_in,
|
|
||||||
fs=frontend.fs,
|
|
||||||
audio_fs=kwargs.get("fs", 16000),
|
|
||||||
data_type=kwargs.get("data_type", "sound"),
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
)
|
|
||||||
time2 = time.perf_counter()
|
|
||||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
|
||||||
speech, speech_lengths = extract_fbank(
|
|
||||||
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
|
|
||||||
)
|
|
||||||
time3 = time.perf_counter()
|
|
||||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
|
||||||
meta_data["batch_data_time"] = (
|
|
||||||
speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
|
|
||||||
)
|
|
||||||
|
|
||||||
speech = speech.to(device=kwargs["device"])
|
|
||||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
|
||||||
|
|
||||||
language = kwargs.get("language", "auto")
|
|
||||||
language_query = self.embed(
|
|
||||||
torch.LongTensor(
|
|
||||||
[[self.lid_dict[language] if language in self.lid_dict else 0]]
|
|
||||||
).to(speech.device)
|
|
||||||
).repeat(speech.size(0), 1, 1)
|
|
||||||
|
|
||||||
use_itn = kwargs.get("use_itn", False)
|
|
||||||
output_timestamp = kwargs.get("output_timestamp", False)
|
|
||||||
|
|
||||||
textnorm = kwargs.get("text_norm", None)
|
|
||||||
if textnorm is None:
|
|
||||||
textnorm = "withitn" if use_itn else "woitn"
|
|
||||||
textnorm_query = self.embed(
|
|
||||||
torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
|
|
||||||
).repeat(speech.size(0), 1, 1)
|
|
||||||
speech = torch.cat((textnorm_query, speech), dim=1)
|
|
||||||
speech_lengths += 1
|
|
||||||
|
|
||||||
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
|
|
||||||
speech.size(0), 1, 1
|
|
||||||
)
|
|
||||||
input_query = torch.cat((language_query, event_emo_query), dim=1)
|
|
||||||
speech = torch.cat((input_query, speech), dim=1)
|
|
||||||
speech_lengths += 3
|
|
||||||
|
|
||||||
# Encoder
|
|
||||||
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
|
|
||||||
if isinstance(encoder_out, tuple):
|
|
||||||
encoder_out = encoder_out[0]
|
|
||||||
|
|
||||||
# c. Passed the encoder result and the beam search
|
|
||||||
ctc_logits = self.ctc.log_softmax(encoder_out)
|
|
||||||
if kwargs.get("ban_emo_unk", False):
|
|
||||||
ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
b, n, d = encoder_out.size()
|
|
||||||
if isinstance(key[0], (list, tuple)):
|
|
||||||
key = key[0]
|
|
||||||
if len(key) < b:
|
|
||||||
key = key * b
|
|
||||||
for i in range(b):
|
|
||||||
x = ctc_logits[i, : encoder_out_lens[i].item(), :]
|
|
||||||
yseq = x.argmax(dim=-1)
|
|
||||||
yseq = torch.unique_consecutive(yseq, dim=-1)
|
|
||||||
|
|
||||||
ibest_writer = None
|
|
||||||
if kwargs.get("output_dir") is not None:
|
|
||||||
if not hasattr(self, "writer"):
|
|
||||||
self.writer = DatadirWriter(kwargs.get("output_dir"))
|
|
||||||
ibest_writer = self.writer[f"1best_recog"]
|
|
||||||
|
|
||||||
mask = yseq != self.blank_id
|
|
||||||
token_int = yseq[mask].tolist()
|
|
||||||
|
|
||||||
# Change integer-ids to tokens
|
|
||||||
text = tokenizer.decode(token_int)
|
|
||||||
if ibest_writer is not None:
|
|
||||||
ibest_writer["text"][key[i]] = text
|
|
||||||
|
|
||||||
if output_timestamp:
|
|
||||||
from itertools import groupby
|
|
||||||
timestamp = []
|
|
||||||
tokens = tokenizer.text2tokens(text)[4:]
|
|
||||||
|
|
||||||
logits_speech = self.ctc.softmax(encoder_out)[i, 4:encoder_out_lens[i].item(), :]
|
|
||||||
|
|
||||||
pred = logits_speech.argmax(-1).cpu()
|
|
||||||
logits_speech[pred == self.blank_id, self.blank_id] = 0
|
|
||||||
|
|
||||||
align = ctc_forced_align(
|
|
||||||
logits_speech.unsqueeze(0).float(),
|
|
||||||
torch.Tensor(token_int[4:]).unsqueeze(0).long().to(logits_speech.device),
|
|
||||||
(encoder_out_lens - 4).long(),
|
|
||||||
torch.tensor(len(token_int) - 4).unsqueeze(0).long().to(logits_speech.device),
|
|
||||||
ignore_id=self.ignore_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
pred = groupby(align[0, :encoder_out_lens[0]])
|
|
||||||
_start = 0
|
|
||||||
token_id = 0
|
|
||||||
ts_max = encoder_out_lens[i] - 4
|
|
||||||
for pred_token, pred_frame in pred:
|
|
||||||
_end = _start + len(list(pred_frame))
|
|
||||||
if pred_token != 0:
|
|
||||||
ts_left = max((_start * 60 - 30) / 1000, 0)
|
|
||||||
ts_right = min((_end * 60 - 30) / 1000, (ts_max * 60 - 30) / 1000)
|
|
||||||
timestamp.append([tokens[token_id], ts_left, ts_right])
|
|
||||||
token_id += 1
|
|
||||||
_start = _end
|
|
||||||
|
|
||||||
result_i = {"key": key[i], "text": text, "timestamp": timestamp}
|
|
||||||
results.append(result_i)
|
|
||||||
else:
|
|
||||||
result_i = {"key": key[i], "text": text}
|
|
||||||
results.append(result_i)
|
|
||||||
return results, meta_data
|
|
||||||
|
|
||||||
def export(self, **kwargs):
|
|
||||||
from export_meta import export_rebuild_model
|
|
||||||
|
|
||||||
if "max_seq_len" not in kwargs:
|
|
||||||
kwargs["max_seq_len"] = 512
|
|
||||||
models = export_rebuild_model(model=self, **kwargs)
|
|
||||||
return models
|
|
||||||
@@ -36,7 +36,7 @@ def transcribe_audio_funasr(audio_path, device="cuda:0"):
|
|||||||
model = AutoModel(
|
model = AutoModel(
|
||||||
model="iic/SenseVoiceSmall",
|
model="iic/SenseVoiceSmall",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
remote_code="./model.py", # Make sure this file is accessible
|
remote_code="./SenseVoice/model.py", # Make sure this file is accessible
|
||||||
vad_model="fsmn-vad",
|
vad_model="fsmn-vad",
|
||||||
vad_kwargs={"max_single_segment_time": 30000},
|
vad_kwargs={"max_single_segment_time": 30000},
|
||||||
device=device,
|
device=device,
|
||||||
|
|||||||
@@ -1,76 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
def ctc_forced_align(
|
|
||||||
log_probs: torch.Tensor,
|
|
||||||
targets: torch.Tensor,
|
|
||||||
input_lengths: torch.Tensor,
|
|
||||||
target_lengths: torch.Tensor,
|
|
||||||
blank: int = 0,
|
|
||||||
ignore_id: int = -1,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Align a CTC label sequence to an emission.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
log_probs (Tensor): log probability of CTC emission output.
|
|
||||||
Tensor of shape `(B, T, C)`. where `B` is the batch size, `T` is the input length,
|
|
||||||
`C` is the number of characters in alphabet including blank.
|
|
||||||
targets (Tensor): Target sequence. Tensor of shape `(B, L)`,
|
|
||||||
where `L` is the target length.
|
|
||||||
input_lengths (Tensor):
|
|
||||||
Lengths of the inputs (max value must each be <= `T`). 1-D Tensor of shape `(B,)`.
|
|
||||||
target_lengths (Tensor):
|
|
||||||
Lengths of the targets. 1-D Tensor of shape `(B,)`.
|
|
||||||
blank_id (int, optional): The index of blank symbol in CTC emission. (Default: 0)
|
|
||||||
ignore_id (int, optional): The index of ignore symbol in CTC emission. (Default: -1)
|
|
||||||
"""
|
|
||||||
targets[targets == ignore_id] = blank
|
|
||||||
|
|
||||||
batch_size, input_time_size, _ = log_probs.size()
|
|
||||||
bsz_indices = torch.arange(batch_size, device=input_lengths.device)
|
|
||||||
|
|
||||||
_t_a_r_g_e_t_s_ = torch.cat(
|
|
||||||
(
|
|
||||||
torch.stack((torch.full_like(targets, blank), targets), dim=-1).flatten(start_dim=1),
|
|
||||||
torch.full_like(targets[:, :1], blank),
|
|
||||||
),
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
diff_labels = torch.cat(
|
|
||||||
(
|
|
||||||
torch.as_tensor([[False, False]], device=targets.device).expand(batch_size, -1),
|
|
||||||
_t_a_r_g_e_t_s_[:, 2:] != _t_a_r_g_e_t_s_[:, :-2],
|
|
||||||
),
|
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
neg_inf = torch.tensor(float("-inf"), device=log_probs.device, dtype=log_probs.dtype)
|
|
||||||
padding_num = 2
|
|
||||||
padded_t = padding_num + _t_a_r_g_e_t_s_.size(-1)
|
|
||||||
best_score = torch.full((batch_size, padded_t), neg_inf, device=log_probs.device, dtype=log_probs.dtype)
|
|
||||||
best_score[:, padding_num + 0] = log_probs[:, 0, blank]
|
|
||||||
best_score[:, padding_num + 1] = log_probs[bsz_indices, 0, _t_a_r_g_e_t_s_[:, 1]]
|
|
||||||
|
|
||||||
backpointers = torch.zeros((batch_size, input_time_size, padded_t), device=log_probs.device, dtype=targets.dtype)
|
|
||||||
|
|
||||||
for t in range(1, input_time_size):
|
|
||||||
prev = torch.stack(
|
|
||||||
(best_score[:, 2:], best_score[:, 1:-1], torch.where(diff_labels, best_score[:, :-2], neg_inf))
|
|
||||||
)
|
|
||||||
prev_max_value, prev_max_idx = prev.max(dim=0)
|
|
||||||
best_score[:, padding_num:] = log_probs[:, t].gather(-1, _t_a_r_g_e_t_s_) + prev_max_value
|
|
||||||
backpointers[:, t, padding_num:] = prev_max_idx
|
|
||||||
|
|
||||||
l1l2 = best_score.gather(
|
|
||||||
-1, torch.stack((padding_num + target_lengths * 2 - 1, padding_num + target_lengths * 2), dim=-1)
|
|
||||||
)
|
|
||||||
|
|
||||||
path = torch.zeros((batch_size, input_time_size), device=best_score.device, dtype=torch.long)
|
|
||||||
path[bsz_indices, input_lengths - 1] = padding_num + target_lengths * 2 - 1 + l1l2.argmax(dim=-1)
|
|
||||||
|
|
||||||
for t in range(input_time_size - 1, 0, -1):
|
|
||||||
target_indices = path[:, t]
|
|
||||||
prev_max_idx = backpointers[bsz_indices, t, target_indices]
|
|
||||||
path[:, t - 1] += target_indices - prev_max_idx
|
|
||||||
|
|
||||||
alignments = _t_a_r_g_e_t_s_.gather(dim=-1, index=(path - padding_num).clamp(min=0))
|
|
||||||
return alignments
|
|
||||||
@@ -1,73 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def export(
|
|
||||||
model, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
|
|
||||||
):
|
|
||||||
model_scripts = model.export(**kwargs)
|
|
||||||
export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
|
|
||||||
os.makedirs(export_dir, exist_ok=True)
|
|
||||||
|
|
||||||
if not isinstance(model_scripts, (list, tuple)):
|
|
||||||
model_scripts = (model_scripts,)
|
|
||||||
for m in model_scripts:
|
|
||||||
m.eval()
|
|
||||||
if type == "onnx":
|
|
||||||
_onnx(
|
|
||||||
m,
|
|
||||||
quantize=quantize,
|
|
||||||
opset_version=opset_version,
|
|
||||||
export_dir=export_dir,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
print("output dir: {}".format(export_dir))
|
|
||||||
|
|
||||||
return export_dir
|
|
||||||
|
|
||||||
|
|
||||||
def _onnx(
|
|
||||||
model,
|
|
||||||
quantize: bool = False,
|
|
||||||
opset_version: int = 14,
|
|
||||||
export_dir: str = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
|
|
||||||
dummy_input = model.export_dummy_inputs()
|
|
||||||
|
|
||||||
verbose = kwargs.get("verbose", False)
|
|
||||||
|
|
||||||
export_name = model.export_name()
|
|
||||||
model_path = os.path.join(export_dir, export_name)
|
|
||||||
torch.onnx.export(
|
|
||||||
model,
|
|
||||||
dummy_input,
|
|
||||||
model_path,
|
|
||||||
verbose=verbose,
|
|
||||||
opset_version=opset_version,
|
|
||||||
input_names=model.export_input_names(),
|
|
||||||
output_names=model.export_output_names(),
|
|
||||||
dynamic_axes=model.export_dynamic_axes(),
|
|
||||||
)
|
|
||||||
|
|
||||||
if quantize:
|
|
||||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
|
||||||
import onnx
|
|
||||||
|
|
||||||
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
|
|
||||||
if not os.path.exists(quant_model_path):
|
|
||||||
onnx_model = onnx.load(model_path)
|
|
||||||
nodes = [n.name for n in onnx_model.graph.node]
|
|
||||||
nodes_to_exclude = [
|
|
||||||
m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
|
|
||||||
]
|
|
||||||
quantize_dynamic(
|
|
||||||
model_input=model_path,
|
|
||||||
model_output=quant_model_path,
|
|
||||||
op_types_to_quantize=["MatMul"],
|
|
||||||
per_channel=True,
|
|
||||||
reduce_range=False,
|
|
||||||
weight_type=QuantType.QUInt8,
|
|
||||||
nodes_to_exclude=nodes_to_exclude,
|
|
||||||
)
|
|
||||||
@@ -1,433 +0,0 @@
|
|||||||
# -*- encoding: utf-8 -*-
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
|
|
||||||
import copy
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import kaldi_native_fbank as knf
|
|
||||||
|
|
||||||
root_dir = Path(__file__).resolve().parent
|
|
||||||
|
|
||||||
logger_initialized = {}
|
|
||||||
|
|
||||||
|
|
||||||
class WavFrontend:
|
|
||||||
"""Conventional frontend structure for ASR."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
cmvn_file: str = None,
|
|
||||||
fs: int = 16000,
|
|
||||||
window: str = "hamming",
|
|
||||||
n_mels: int = 80,
|
|
||||||
frame_length: int = 25,
|
|
||||||
frame_shift: int = 10,
|
|
||||||
lfr_m: int = 1,
|
|
||||||
lfr_n: int = 1,
|
|
||||||
dither: float = 1.0,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
|
|
||||||
opts = knf.FbankOptions()
|
|
||||||
opts.frame_opts.samp_freq = fs
|
|
||||||
opts.frame_opts.dither = dither
|
|
||||||
opts.frame_opts.window_type = window
|
|
||||||
opts.frame_opts.frame_shift_ms = float(frame_shift)
|
|
||||||
opts.frame_opts.frame_length_ms = float(frame_length)
|
|
||||||
opts.mel_opts.num_bins = n_mels
|
|
||||||
opts.energy_floor = 0
|
|
||||||
opts.frame_opts.snip_edges = True
|
|
||||||
opts.mel_opts.debug_mel = False
|
|
||||||
self.opts = opts
|
|
||||||
|
|
||||||
self.lfr_m = lfr_m
|
|
||||||
self.lfr_n = lfr_n
|
|
||||||
self.cmvn_file = cmvn_file
|
|
||||||
|
|
||||||
if self.cmvn_file:
|
|
||||||
self.cmvn = self.load_cmvn()
|
|
||||||
self.fbank_fn = None
|
|
||||||
self.fbank_beg_idx = 0
|
|
||||||
self.reset_status()
|
|
||||||
|
|
||||||
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
waveform = waveform * (1 << 15)
|
|
||||||
self.fbank_fn = knf.OnlineFbank(self.opts)
|
|
||||||
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
|
||||||
frames = self.fbank_fn.num_frames_ready
|
|
||||||
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
|
||||||
for i in range(frames):
|
|
||||||
mat[i, :] = self.fbank_fn.get_frame(i)
|
|
||||||
feat = mat.astype(np.float32)
|
|
||||||
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
|
||||||
return feat, feat_len
|
|
||||||
|
|
||||||
def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
waveform = waveform * (1 << 15)
|
|
||||||
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
|
||||||
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
|
||||||
frames = self.fbank_fn.num_frames_ready
|
|
||||||
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
|
||||||
for i in range(self.fbank_beg_idx, frames):
|
|
||||||
mat[i, :] = self.fbank_fn.get_frame(i)
|
|
||||||
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
|
|
||||||
feat = mat.astype(np.float32)
|
|
||||||
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
|
||||||
return feat, feat_len
|
|
||||||
|
|
||||||
def reset_status(self):
|
|
||||||
self.fbank_fn = knf.OnlineFbank(self.opts)
|
|
||||||
self.fbank_beg_idx = 0
|
|
||||||
|
|
||||||
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
if self.lfr_m != 1 or self.lfr_n != 1:
|
|
||||||
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
|
|
||||||
|
|
||||||
if self.cmvn_file:
|
|
||||||
feat = self.apply_cmvn(feat)
|
|
||||||
|
|
||||||
feat_len = np.array(feat.shape[0]).astype(np.int32)
|
|
||||||
return feat, feat_len
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
|
|
||||||
LFR_inputs = []
|
|
||||||
|
|
||||||
T = inputs.shape[0]
|
|
||||||
T_lfr = int(np.ceil(T / lfr_n))
|
|
||||||
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
|
|
||||||
inputs = np.vstack((left_padding, inputs))
|
|
||||||
T = T + (lfr_m - 1) // 2
|
|
||||||
for i in range(T_lfr):
|
|
||||||
if lfr_m <= T - i * lfr_n:
|
|
||||||
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
|
|
||||||
else:
|
|
||||||
# process last LFR frame
|
|
||||||
num_padding = lfr_m - (T - i * lfr_n)
|
|
||||||
frame = inputs[i * lfr_n :].reshape(-1)
|
|
||||||
for _ in range(num_padding):
|
|
||||||
frame = np.hstack((frame, inputs[-1]))
|
|
||||||
|
|
||||||
LFR_inputs.append(frame)
|
|
||||||
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
|
|
||||||
return LFR_outputs
|
|
||||||
|
|
||||||
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Apply CMVN with mvn data
|
|
||||||
"""
|
|
||||||
frame, dim = inputs.shape
|
|
||||||
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
|
|
||||||
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
|
|
||||||
inputs = (inputs + means) * vars
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
def load_cmvn(
|
|
||||||
self,
|
|
||||||
) -> np.ndarray:
|
|
||||||
with open(self.cmvn_file, "r", encoding="utf-8") as f:
|
|
||||||
lines = f.readlines()
|
|
||||||
|
|
||||||
means_list = []
|
|
||||||
vars_list = []
|
|
||||||
for i in range(len(lines)):
|
|
||||||
line_item = lines[i].split()
|
|
||||||
if line_item[0] == "<AddShift>":
|
|
||||||
line_item = lines[i + 1].split()
|
|
||||||
if line_item[0] == "<LearnRateCoef>":
|
|
||||||
add_shift_line = line_item[3 : (len(line_item) - 1)]
|
|
||||||
means_list = list(add_shift_line)
|
|
||||||
continue
|
|
||||||
elif line_item[0] == "<Rescale>":
|
|
||||||
line_item = lines[i + 1].split()
|
|
||||||
if line_item[0] == "<LearnRateCoef>":
|
|
||||||
rescale_line = line_item[3 : (len(line_item) - 1)]
|
|
||||||
vars_list = list(rescale_line)
|
|
||||||
continue
|
|
||||||
|
|
||||||
means = np.array(means_list).astype(np.float64)
|
|
||||||
vars = np.array(vars_list).astype(np.float64)
|
|
||||||
cmvn = np.array([means, vars])
|
|
||||||
return cmvn
|
|
||||||
|
|
||||||
|
|
||||||
class WavFrontendOnline(WavFrontend):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
# self.fbank_fn = knf.OnlineFbank(self.opts)
|
|
||||||
# add variables
|
|
||||||
self.frame_sample_length = int(
|
|
||||||
self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
|
|
||||||
)
|
|
||||||
self.frame_shift_sample_length = int(
|
|
||||||
self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
|
|
||||||
)
|
|
||||||
self.waveform = None
|
|
||||||
self.reserve_waveforms = None
|
|
||||||
self.input_cache = None
|
|
||||||
self.lfr_splice_cache = []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
# inputs has catted the cache
|
|
||||||
def apply_lfr(
|
|
||||||
inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray, int]:
|
|
||||||
"""
|
|
||||||
Apply lfr with data
|
|
||||||
"""
|
|
||||||
|
|
||||||
LFR_inputs = []
|
|
||||||
T = inputs.shape[0] # include the right context
|
|
||||||
T_lfr = int(
|
|
||||||
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
|
|
||||||
) # minus the right context: (lfr_m - 1) // 2
|
|
||||||
splice_idx = T_lfr
|
|
||||||
for i in range(T_lfr):
|
|
||||||
if lfr_m <= T - i * lfr_n:
|
|
||||||
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
|
|
||||||
else: # process last LFR frame
|
|
||||||
if is_final:
|
|
||||||
num_padding = lfr_m - (T - i * lfr_n)
|
|
||||||
frame = (inputs[i * lfr_n :]).reshape(-1)
|
|
||||||
for _ in range(num_padding):
|
|
||||||
frame = np.hstack((frame, inputs[-1]))
|
|
||||||
LFR_inputs.append(frame)
|
|
||||||
else:
|
|
||||||
# update splice_idx and break the circle
|
|
||||||
splice_idx = i
|
|
||||||
break
|
|
||||||
splice_idx = min(T - 1, splice_idx * lfr_n)
|
|
||||||
lfr_splice_cache = inputs[splice_idx:, :]
|
|
||||||
LFR_outputs = np.vstack(LFR_inputs)
|
|
||||||
return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def compute_frame_num(
|
|
||||||
sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
|
|
||||||
) -> int:
|
|
||||||
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
|
|
||||||
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
|
|
||||||
|
|
||||||
def fbank(
|
|
||||||
self, input: np.ndarray, input_lengths: np.ndarray
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
self.fbank_fn = knf.OnlineFbank(self.opts)
|
|
||||||
batch_size = input.shape[0]
|
|
||||||
if self.input_cache is None:
|
|
||||||
self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
|
|
||||||
input = np.concatenate((self.input_cache, input), axis=1)
|
|
||||||
frame_num = self.compute_frame_num(
|
|
||||||
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
|
|
||||||
)
|
|
||||||
# update self.in_cache
|
|
||||||
self.input_cache = input[
|
|
||||||
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
|
|
||||||
]
|
|
||||||
waveforms = np.empty(0, dtype=np.float32)
|
|
||||||
feats_pad = np.empty(0, dtype=np.float32)
|
|
||||||
feats_lens = np.empty(0, dtype=np.int32)
|
|
||||||
if frame_num:
|
|
||||||
waveforms = []
|
|
||||||
feats = []
|
|
||||||
feats_lens = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
waveform = input[i]
|
|
||||||
waveforms.append(
|
|
||||||
waveform[
|
|
||||||
: (
|
|
||||||
(frame_num - 1) * self.frame_shift_sample_length
|
|
||||||
+ self.frame_sample_length
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
waveform = waveform * (1 << 15)
|
|
||||||
|
|
||||||
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
|
||||||
frames = self.fbank_fn.num_frames_ready
|
|
||||||
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
|
||||||
for i in range(frames):
|
|
||||||
mat[i, :] = self.fbank_fn.get_frame(i)
|
|
||||||
feat = mat.astype(np.float32)
|
|
||||||
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
|
||||||
feats.append(feat)
|
|
||||||
feats_lens.append(feat_len)
|
|
||||||
|
|
||||||
waveforms = np.stack(waveforms)
|
|
||||||
feats_lens = np.array(feats_lens)
|
|
||||||
feats_pad = np.array(feats)
|
|
||||||
self.fbanks = feats_pad
|
|
||||||
self.fbanks_lens = copy.deepcopy(feats_lens)
|
|
||||||
return waveforms, feats_pad, feats_lens
|
|
||||||
|
|
||||||
def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
return self.fbanks, self.fbanks_lens
|
|
||||||
|
|
||||||
def lfr_cmvn(
|
|
||||||
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
|
|
||||||
batch_size = input.shape[0]
|
|
||||||
feats = []
|
|
||||||
feats_lens = []
|
|
||||||
lfr_splice_frame_idxs = []
|
|
||||||
for i in range(batch_size):
|
|
||||||
mat = input[i, : input_lengths[i], :]
|
|
||||||
lfr_splice_frame_idx = -1
|
|
||||||
if self.lfr_m != 1 or self.lfr_n != 1:
|
|
||||||
# update self.lfr_splice_cache in self.apply_lfr
|
|
||||||
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
|
|
||||||
mat, self.lfr_m, self.lfr_n, is_final
|
|
||||||
)
|
|
||||||
if self.cmvn_file is not None:
|
|
||||||
mat = self.apply_cmvn(mat)
|
|
||||||
feat_length = mat.shape[0]
|
|
||||||
feats.append(mat)
|
|
||||||
feats_lens.append(feat_length)
|
|
||||||
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
|
|
||||||
|
|
||||||
feats_lens = np.array(feats_lens)
|
|
||||||
feats_pad = np.array(feats)
|
|
||||||
return feats_pad, feats_lens, lfr_splice_frame_idxs
|
|
||||||
|
|
||||||

    def extract_fbank(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        batch_size = input.shape[0]
        assert (
            batch_size == 1
        ), "online feature extraction currently supports only batch size 1"
        waveforms, feats, feats_lengths = self.fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            self.waveforms = (
                waveforms
                if self.reserve_waveforms is None
                else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
            )
            if not self.lfr_splice_cache:
                for i in range(batch_size):
                    self.lfr_splice_cache.append(
                        np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
                    )

            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
                feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
                feats_lengths += lfr_splice_cache_np[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length)
                    / self.frame_shift_sample_length
                    + 1
                )
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
                    feats, feats_lengths, is_final
                )
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    # print('reserve_frame_idx: ' + str(reserve_frame_idx))
                    # print('frame_frame: ' + str(frame_from_waveforms))
                    self.reserve_waveforms = self.waveforms[
                        :,
                        reserve_frame_idx
                        * self.frame_shift_sample_length : frame_from_waveforms
                        * self.frame_shift_sample_length,
                    ]
                    sample_length = (
                        frame_from_waveforms - 1
                    ) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[
                    :, : -(self.frame_sample_length - self.frame_shift_sample_length)
                ]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = np.concatenate(
                        (self.lfr_splice_cache[i], feats[i]), axis=0
                    )
                return np.empty(0, dtype=np.float32), feats_lengths
        else:
            if is_final:
                self.waveforms = (
                    waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                )
                feats = np.stack(self.lfr_splice_cache)
                feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
                feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths
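
    # Added note (descriptive comment, not part of the original source):
    # extract_fbank is the streaming entry point of this frontend. Callers feed
    # successive waveform chunks (batch size 1) and pass is_final=True on the
    # last chunk so the cached splice frames are flushed; afterwards the
    # internal caches are cleared via cache_reset(). When too few frames have
    # accumulated, an empty array is returned and the frames stay in the cache
    # for the next call.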

    def get_waveforms(self):
        return self.waveforms

    def cache_reset(self):
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []


def load_bytes(input):
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
    if middle_data.dtype.kind not in "iu":
        raise TypeError("'middle_data' must be an array of integers")
    dtype = np.dtype("float32")
    if dtype.kind != "f":
        raise TypeError("'dtype' must be a floating point type")

    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)
    offset = i.min + abs_max
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
    return array
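
# Added example (descriptive comment, not part of the original source):
# load_bytes maps 16-bit PCM bytes to float32 samples in [-1.0, 1.0).
# For instance, the int16 value 16384 becomes 16384 / 32768 = 0.5:
#
#     pcm = np.array([16384, -32768, 0], dtype=np.int16).tobytes()
#     load_bytes(pcm)  # -> array([ 0.5, -1. ,  0. ], dtype=float32)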


class SinusoidalPositionEncoderOnline:
    """Streaming Positional encoding."""

    def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
        batch_size = positions.shape[0]
        positions = positions.astype(dtype)
        log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
        inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
        scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
        encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
        return encoding.astype(dtype)

    def forward(self, x, start_idx=0):
        batch_size, timesteps, input_dim = x.shape
        positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype)

        return x + position_encoding[:, start_idx : start_idx + timesteps]
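
# Added example (descriptive comment, not part of the original source):
# forward() adds sinusoidal position encodings to a (batch, time, dim) array.
# In streaming use, start_idx carries the number of frames already consumed so
# the encoding stays continuous across chunks, e.g.:
#
#     pe = SinusoidalPositionEncoderOnline()
#     chunk = np.zeros((1, 10, 560), dtype=np.float32)
#     out1 = pe.forward(chunk, start_idx=0)   # positions 1..10
#     out2 = pe.forward(chunk, start_idx=10)  # positions 11..20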


def test():
    path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    import librosa

    cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
    config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
    from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml

    config = read_yaml(config_file)
    waveform, _ = librosa.load(path, sr=None)
    frontend = WavFrontend(
        cmvn_file=cmvn_file,
        **config["frontend_conf"],
    )
    speech, _ = frontend.fbank_online(waveform)  # 1d, (sample,), numpy
    feat, feat_len = frontend.lfr_cmvn(
        speech
    )  # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)

    frontend.reset_status()  # clear cache
    return feat, feat_len


if __name__ == "__main__":
    test()
@@ -1,395 +0,0 @@
# -*- encoding: utf-8 -*-

import functools
import logging
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union

import re
import numpy as np
import yaml

try:
    from onnxruntime import (
        GraphOptimizationLevel,
        InferenceSession,
        SessionOptions,
        get_available_providers,
        get_device,
    )
except ImportError:
    print("please install onnxruntime: pip3 install onnxruntime")
import jieba
import warnings

root_dir = Path(__file__).resolve().parent

logger_initialized = {}


def pad_list(xs, pad_value, max_len=None):
    n_batch = len(xs)
    if max_len is None:
        max_len = max(x.shape[0] for x in xs)
    # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    # numpy format
    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
    for i in range(n_batch):
        pad[i, : xs[i].shape[0]] = xs[i]

    return pad
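
# Added example (descriptive comment, not part of the original source):
# pad_list right-pads a list of 1-d id sequences to a common length, e.g.
#
#     pad_list([np.array([1, 2, 3]), np.array([4])], pad_value=0, max_len=4)
#     # -> array([[1, 2, 3, 0],
#     #           [4, 0, 0, 0]], dtype=int32)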
"""
|
|
||||||
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
|
|
||||||
if length_dim == 0:
|
|
||||||
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
|
|
||||||
|
|
||||||
if not isinstance(lengths, list):
|
|
||||||
lengths = lengths.tolist()
|
|
||||||
bs = int(len(lengths))
|
|
||||||
if maxlen is None:
|
|
||||||
if xs is None:
|
|
||||||
maxlen = int(max(lengths))
|
|
||||||
else:
|
|
||||||
maxlen = xs.size(length_dim)
|
|
||||||
else:
|
|
||||||
assert xs is None
|
|
||||||
assert maxlen >= int(max(lengths))
|
|
||||||
|
|
||||||
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
|
||||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
|
||||||
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
|
||||||
mask = seq_range_expand >= seq_length_expand
|
|
||||||
|
|
||||||
if xs is not None:
|
|
||||||
assert xs.size(0) == bs, (xs.size(0), bs)
|
|
||||||
|
|
||||||
if length_dim < 0:
|
|
||||||
length_dim = xs.dim() + length_dim
|
|
||||||
# ind = (:, None, ..., None, :, , None, ..., None)
|
|
||||||
ind = tuple(
|
|
||||||
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
|
|
||||||
)
|
|
||||||
mask = mask[ind].expand_as(xs).to(xs.device)
|
|
||||||
return mask
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||


class TokenIDConverter:
    def __init__(
        self,
        token_list: Union[List, str],
    ):
        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
        return [self.token2id.get(i, self.unk_id) for i in tokens]
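
# Added example (descriptive comment, not part of the original source):
# TokenIDConverter maps between tokens and vocabulary indices; unknown tokens
# fall back to the id of the last entry in the token list, e.g.
#
#     conv = TokenIDConverter(["a", "b", "<unk>"])
#     conv.tokens2ids(["b", "zzz"])  # -> [1, 2]  ("zzz" maps to <unk>)
#     conv.ids2tokens([0, 2])        # -> ["a", "<unk>"]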


class CharTokenizer:
    def __init__(
        self,
        symbol_value: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
    ):
        self.space_symbol = space_symbol
        self.non_linguistic_symbols = self.load_symbols(symbol_value)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols

    @staticmethod
    def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
        if value is None:
            return set()

        if not isinstance(value, (Path, str)) and isinstance(value, Iterable):
            return set(value)

        file_path = Path(value)
        if not file_path.exists():
            logging.warning("%s doesn't exist.", file_path)
            return set()

        with file_path.open("r", encoding="utf-8") as f:
            return set(line.rstrip() for line in f)

    def text2tokens(self, line: Union[str, list]) -> List[str]:
        tokens = []
        while len(line) != 0:
            for w in self.non_linguistic_symbols:
                if line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w) :]
                    break
            else:
                t = line[0]
                if t == " ":
                    t = "<space>"
                tokens.append(t)
                line = line[1:]
        return tokens

    def tokens2text(self, tokens: Iterable[str]) -> str:
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}", '
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )
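
# Added example (descriptive comment, not part of the original source):
# CharTokenizer splits text into single characters, mapping spaces to the
# configured space symbol, e.g.
#
#     tok = CharTokenizer()
#     tok.text2tokens("ab 你好")                          # -> ["a", "b", "<space>", "你", "好"]
#     tok.tokens2text(["a", "b", "<space>", "你", "好"])  # -> "ab 你好"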


class Hypothesis(NamedTuple):
    """Hypothesis data type."""

    yseq: np.ndarray
    score: Union[float, np.ndarray] = 0
    scores: Dict[str, Union[float, np.ndarray]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()


class TokenIDConverterError(Exception):
    pass


class ONNXRuntimeError(Exception):
    pass


class OrtInferSession:
    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = intra_op_num_threads
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cuda_ep = "CUDAExecutionProvider"
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }

        EP_list = []
        if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers():
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))

        self._verify_model(model_file)
        self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list)

        if device_id != "-1" and cuda_ep not in self.session.get_providers():
            warnings.warn(
                f"{cuda_ep} is not available in the current environment, so inference automatically falls back to {cpu_ep}.\n"
                "Please ensure the installed onnxruntime-gpu version matches your CUDA and cuDNN versions; "
                "you can check the compatibility matrix on the official site: "
                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
                RuntimeWarning,
            )

    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            raise ONNXRuntimeError("ONNXRuntime inference failed.") from e

    def get_input_names(self):
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(self):
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character"):
        return self.meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        if key in self.meta_dict.keys():
            return True
        return False

    @staticmethod
    def _verify_model(model_path):
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")
        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
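
# Added usage sketch (descriptive comment, not part of the original source;
# the model path below is a placeholder):
#
#     sess = OrtInferSession("model.onnx", device_id=-1, intra_op_num_threads=4)
#     outputs = sess([feats, feats_len])  # inputs follow the ONNX graph's input order
#
# device_id=-1 keeps inference on the CPUExecutionProvider; a non-negative id
# selects the CUDAExecutionProvider when onnxruntime-gpu and a GPU are available.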


def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit : (i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit :])
    return sentences
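
# Added example (descriptive comment, not part of the original source):
#
#     split_to_mini_sentence(["a", "b", "c", "d", "e"], word_limit=2)
#     # -> [["a", "b"], ["c", "d"], ["e"]]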


def code_mix_split_words(text: str):
    words = []
    segs = text.split()
    for seg in segs:
        # There is no space in seg.
        current_word = ""
        for c in seg:
            if len(c.encode()) == 1:
                # This is an ASCII char.
                current_word += c
            else:
                # This is a Chinese char.
                if len(current_word) > 0:
                    words.append(current_word)
                    current_word = ""
                words.append(c)
        if len(current_word) > 0:
            words.append(current_word)
    return words
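
# Added example (descriptive comment, not part of the original source):
# ASCII runs are kept together while each CJK character becomes its own token:
#
#     code_mix_split_words("打开wifi设置")
#     # -> ["打", "开", "wifi", "设", "置"]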


def isEnglish(text: str):
    if re.search("^[a-zA-Z']+$", text):
        return True
    else:
        return False


def join_chinese_and_english(input_list):
    line = ""
    for token in input_list:
        if isEnglish(token):
            line = line + " " + token
        else:
            line = line + token

    line = line.strip()
    return line
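
# Added example (descriptive comment, not part of the original source):
# English tokens get a leading space, Chinese characters are concatenated:
#
#     join_chinese_and_english(["打", "开", "wifi", "设", "置"])
#     # -> "打开 wifi设置"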


def code_mix_split_words_jieba(seg_dict_file: str):
    jieba.load_userdict(seg_dict_file)

    def _fn(text: str):
        input_list = text.split()
        token_list_all = []
        language_list = []
        token_list_tmp = []
        language_flag = None
        for token in input_list:
            if isEnglish(token) and language_flag == "Chinese":
                token_list_all.append(token_list_tmp)
                language_list.append("Chinese")
                token_list_tmp = []
            elif not isEnglish(token) and language_flag == "English":
                token_list_all.append(token_list_tmp)
                language_list.append("English")
                token_list_tmp = []

            token_list_tmp.append(token)

            if isEnglish(token):
                language_flag = "English"
            else:
                language_flag = "Chinese"

        if token_list_tmp:
            token_list_all.append(token_list_tmp)
            language_list.append(language_flag)

        result_list = []
        for token_list_tmp, language_flag in zip(token_list_all, language_list):
            if language_flag == "English":
                result_list.extend(token_list_tmp)
            else:
                seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
                result_list.extend(seg_list)

        return result_list

    return _fn


def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
        raise FileNotFoundError(f"The file {yaml_path} does not exist.")

    with open(str(yaml_path), "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data


@functools.lru_cache()
def get_logger(name="funasr_onnx"):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added.

    Args:
        name (str): Logger name.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
    )

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    logger_initialized[name] = True
    logger.propagate = False
    logging.basicConfig(level=logging.ERROR)
    return logger
@@ -1,145 +0,0 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple

import torch
import librosa
import numpy as np

from utils.infer_utils import (
    CharTokenizer,
    Hypothesis,
    ONNXRuntimeError,
    OrtInferSession,
    TokenIDConverter,
    get_logger,
    read_yaml,
)
from utils.frontend import WavFrontend
from utils.infer_utils import pad_list

logging = get_logger()


class SenseVoiceSmallONNX:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        plot_timestamp_to: str = "",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.onnx")
        else:
            model_file = os.path.join(model_dir, "model.onnx")

        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)

        # self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        config["frontend_conf"]["cmvn_file"] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = batch_size
        self.blank_id = 0

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 language: List,
                 textnorm: List,
                 tokenizer=None,
                 **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            ctc_logits, encoder_out_lens = self.infer(feats,
                                                      feats_len,
                                                      np.array(language, dtype=np.int32),
                                                      np.array(textnorm, dtype=np.int32)
                                                      )
            # back to torch.Tensor
            ctc_logits = torch.from_numpy(ctc_logits).float()
            # currently only batch_size=1 is supported
            x = ctc_logits[0, : encoder_out_lens[0].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)

            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()

            if tokenizer is not None:
                asr_res.append(tokenizer.tokens2text(token_int))
            else:
                asr_res.append(token_int)
        return asr_res

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self,
              feats: np.ndarray,
              feats_len: np.ndarray,
              language: np.ndarray,
              textnorm: np.ndarray,
              ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len, language, textnorm])
        return outputs
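
# Added usage sketch (descriptive comment, not part of the original source).
# Paths and the language/textnorm ids below are placeholders; the real ids are
# defined by the exported SenseVoice ONNX model and its config:
#
#     model = SenseVoiceSmallONNX(model_dir="/path/to/sensevoice_onnx", device_id=-1)
#     res = model("/path/to/audio.wav",
#                 language=[0],    # placeholder language id
#                 textnorm=[15],   # placeholder text-normalization id
#                 tokenizer=None)  # with tokenizer=None the raw token ids are returned
#     print(res)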