# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from wekws(https://github.com/wenet-e2e/wekws)
import paddle


def padding_mask(lengths: paddle.Tensor) -> paddle.Tensor:
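    """Make a bool mask marking padded frames.

    Returns a (batch_size, max_len) tensor that is True at every position
    at or beyond the utterance length, e.g. lengths [2, 3] with max_len 3
    gives [[False, False, True], [False, False, False]].
    """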
    batch_size = lengths.shape[0]
    max_len = int(lengths.max().item())
    seq = paddle.arange(max_len, dtype=paddle.int64)
    seq = seq.expand((batch_size, max_len))
    return seq >= lengths.unsqueeze(1)


def fill_mask_elements(condition: paddle.Tensor, value: float,
x: paddle.Tensor) -> paddle.Tensor:
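    """Return a copy of `x` with the elements where `condition` is True
    replaced by `value`."""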
    assert condition.shape == x.shape
    values = paddle.ones_like(x, dtype=x.dtype) * value
    return paddle.where(condition, values, x)


def max_pooling_loss(logits: paddle.Tensor,
target: paddle.Tensor,
lengths: paddle.Tensor,
min_duration: int=0):
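    """Max-pooling loss for keyword spotting, following wekws.

    For the target keyword, the probability of the most confident frame is
    pushed towards 1; for all other keywords and the filler, the probability
    of the least confident frame is pushed towards 0, both via cross entropy.

    Args:
        logits: (batch_size, max_len, num_keywords) frame-level posteriors.
        target: (batch_size,) keyword index of each utterance; filler id < 0.
        lengths: (batch_size,) valid frame count of each utterance.
        min_duration: number of leading frames excluded from max-pooling.

    Returns:
        Tuple of (loss, num_correct, acc) for the batch.
    """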
mask = padding_mask(lengths)
    num_utts = logits.shape[0]
    num_keywords = logits.shape[2]

    loss = 0.0
    for i in range(num_utts):
        for j in range(num_keywords):
            # Cross entropy loss CE = -(t * log(p) + (1 - t) * log(1 - p))
            if target[i] == j:
                # For the target keyword, do max-pooling over valid frames
                prob = logits[i, :, j]
                # Clone so the shared padding mask is not modified in place
                m = mask[i].clone()
                if min_duration > 0:
                    # Also ignore the first min_duration frames
                    m[:min_duration] = True
                prob = fill_mask_elements(m, 0.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                max_prob = prob.max()
                loss += -paddle.log(max_prob)
            else:
                # For other keywords or the filler, do min-pooling
                prob = 1 - logits[i, :, j]
                prob = fill_mask_elements(mask[i], 1.0, prob)
                prob = paddle.clip(prob, 1e-8, 1.0)
                min_prob = prob.min()
                loss += -paddle.log(min_prob)
    loss = loss / num_utts

    # Compute accuracy of the current batch. Expand the padding mask to the
    # logits shape before masking; a bare unsqueeze would trip the shape
    # assert in fill_mask_elements whenever num_keywords > 1.
    mask = mask.unsqueeze(-1).expand(logits.shape)
    logits = fill_mask_elements(mask, 0.0, logits)
    max_logits = logits.max(1)
    num_correct = 0
    for i in range(num_utts):
        max_p = max_logits[i].max(0).item()
        idx = max_logits[i].argmax(0).item()
        # Predicted correctly as the target keyword
        if max_p > 0.5 and idx == target[i].item():
            num_correct += 1
        # Predicted correctly as the filler, whose id is < 0
        if max_p < 0.5 and target[i].item() < 0:
            num_correct += 1
    acc = num_correct / num_utts
return loss, num_correct, acc
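

# Minimal usage sketch: random sigmoid posteriors for a single keyword
# (index 0) with -1 as the filler id, matching the conventions above.
# The shapes and ids here are illustrative assumptions, not a fixed API.
if __name__ == '__main__':
    batch_size, max_len, num_keywords = 4, 100, 1
    logits = paddle.nn.functional.sigmoid(
        paddle.randn([batch_size, max_len, num_keywords]))
    # Utterances 0 and 1 contain the keyword; 2 and 3 are filler
    target = paddle.to_tensor([0, 0, -1, -1], dtype=paddle.int64)
    lengths = paddle.to_tensor([100, 80, 90, 60], dtype=paddle.int64)
    loss, num_correct, acc = max_pooling_loss(logits, target, lengths)
    print(f'loss={float(loss):.4f} correct={num_correct} acc={acc:.2f}')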