You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
PaddleSpeech/paddlespeech/kws/models/loss.py

83 lines
3.0 KiB

# Copyright (c) 2021 Binbin Zhang
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle
def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
    """Build a boolean mask that flags padded positions of a batch.

    Args:
        lengths: Tensor holding one valid length per batch entry
            (expected to be 1-D, since it is unsqueezed on axis 1 to
            broadcast against a ``(batch, max_len)`` index grid).

    Returns:
        Boolean tensor of shape ``(batch, max(lengths))`` in which ``True``
        marks positions whose index is greater than or equal to the
        corresponding entry's length.
    """
    num_rows = lengths.shape[0]
    # The longest entry determines how many columns the mask has.
    num_cols = int(lengths.max().item())
    # Row-wise copy of the position indices 0 .. num_cols - 1.
    positions = paddle.arange(
        num_cols, dtype=paddle.int64).expand((num_rows, num_cols))
    # A position is padding when it lies at or past the entry's length.
    return positions >= lengths.unsqueeze(1)
def fill_mask_elements(condition: paddle.Tensor, value: float,
                       x: paddle.Tensor) -> paddle.Tensor:
    """Return a tensor equal to ``x`` with ``value`` where ``condition`` holds.

    Args:
        condition: Boolean tensor; must have exactly the same shape as ``x``.
        value: Scalar written at every position where ``condition`` is True.
        x: Source tensor whose remaining elements are kept unchanged.

    Returns:
        New tensor produced by ``paddle.where``; ``x`` itself is not modified.

    Raises:
        AssertionError: If ``condition`` and ``x`` differ in shape.
    """
    # No broadcasting is attempted: the caller must supply matching shapes.
    assert condition.shape == x.shape
    return paddle.where(
        condition,
        # Constant tensor shaped and typed like ``x`` holding the fill value.
        paddle.ones_like(x, dtype=x.dtype) * value,
        x)
def max_pooling_loss(logits: paddle.Tensor,
                     target: paddle.Tensor,
                     lengths: paddle.Tensor,
                     min_duration: int=0):
    """Compute the max-pooling keyword-spotting loss and the batch accuracy.

    For every utterance and every keyword, one frame is selected and its
    cross-entropy term is accumulated:

    * for the utterance's target keyword, the frame with the highest
      probability (max-pooling) contributes ``-log(p)``;
    * for every other keyword, the frame with the lowest ``1 - p``
      (min-pooling) contributes ``-log(1 - p)``.

    Padded frames are excluded from the selection.

    Args:
        logits: Tensor indexed as ``[utterance, frame, keyword]``. Despite the
            name, the values are used directly as probabilities (clipped to
            ``[1e-8, 1]`` and compared with 0.5) -- presumably sigmoid
            outputs; confirm against the caller.
        target: Per-utterance keyword index. A negative value marks a filler
            (non-keyword) utterance.
        lengths: Number of valid frames per utterance. The mask built from it
            must match the frame axis of ``logits``, i.e. ``max(lengths)`` is
            assumed to equal ``logits.shape[1]``.
        min_duration: Number of leading frames that are ignored when picking
            the max-probability frame of the target keyword. ``0`` disables
            this.

    Returns:
        Tuple ``(loss, num_correct, acc)``: the accumulated loss divided by
        the number of utterances, the count of correctly classified
        utterances, and ``num_correct / num_utts``.
    """
    mask = padding_mask(lengths)
    num_utts = logits.shape[0]
    num_keywords = logits.shape[2]

    loss = 0.0
    for i in range(num_utts):
        # Read the label and the padding mask once per utterance instead of
        # once per keyword.
        label = target[i].item()
        utt_mask = mask[i]
        for j in range(num_keywords):
            # Add entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
            if label == j:
                # For the keyword, do max-pooling
                prob = logits[i, :, j]
                m = utt_mask
                if min_duration > 0:
                    # Mutate a copy: writing through ``mask[i]`` could also
                    # mask the leading frames for the other keywords and for
                    # the accuracy computation below.
                    m = utt_mask.clone()
                    m[:min_duration] = True
                # Masked frames get probability 0 so they never win the max.
                prob = fill_mask_elements(m, 0.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                max_prob = prob.max()
                loss += -paddle.log(max_prob)
            else:
                # For other keywords or filler, do min-pooling
                prob = 1 - logits[i, :, j]
                # Padded frames get 1.0 so they never win the min.
                prob = fill_mask_elements(utt_mask, 1.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                min_prob = prob.min()
                loss += -paddle.log(min_prob)
    loss = loss / num_utts

    # Compute accuracy of current batch
    # fill_mask_elements requires identical shapes, so the (utt, frame, 1)
    # mask is expanded over the keyword axis; without this the call fails
    # whenever there is more than one keyword.
    mask = mask.unsqueeze(-1).expand(logits.shape)
    logits = fill_mask_elements(mask, 0.0, logits)
    # Highest probability per keyword over the valid frames.
    max_logits = logits.max(1)
    num_correct = 0
    for i in range(num_utts):
        max_p = max_logits[i].max(0).item()
        idx = max_logits[i].argmax(0).item()
        label = target[i].item()
        # Predict correct as the i'th keyword
        if max_p > 0.5 and idx == label:
            num_correct += 1
        # Predict correct as the filler, filler id < 0
        if max_p < 0.5 and label < 0:
            num_correct += 1
    acc = num_correct / num_utts
    return loss, num_correct, acc