Refactor CTC module, add embedding and fix log (#549)
* add acts, refactor ctc, add pos embed * fix export, dataloader time log * fix egs * fix libri readmepull/550/head
parent
00889bfaf2
commit
1539f3e0a3
@ -0,0 +1,238 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from paddle.nn import initializer as I
|
||||||
|
|
||||||
|
from deepspeech.decoders.swig_wrapper import Scorer
|
||||||
|
from deepspeech.decoders.swig_wrapper import ctc_greedy_decoder
|
||||||
|
from deepspeech.decoders.swig_wrapper import ctc_beam_search_decoder_batch
|
||||||
|
from deepspeech.modules.loss import CTCLoss
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ['CTCDecoder']
|
||||||
|
|
||||||
|
|
||||||
|
class CTCDecoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
enc_n_units,
|
||||||
|
odim,
|
||||||
|
blank_id=0,
|
||||||
|
dropout_rate: float=0.0,
|
||||||
|
reduction: bool=True):
|
||||||
|
"""CTC decoder
|
||||||
|
|
||||||
|
Args:
|
||||||
|
enc_n_units ([int]): encoder output dimention
|
||||||
|
vocab_size ([int]): text vocabulary size
|
||||||
|
dropout_rate (float): dropout rate (0.0 ~ 1.0)
|
||||||
|
reduction (bool): reduce the CTC loss into a scalar
|
||||||
|
"""
|
||||||
|
assert check_argument_types()
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.blank_id = blank_id
|
||||||
|
self.odim = odim
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.ctc_lo = nn.Linear(enc_n_units, self.odim)
|
||||||
|
reduction_type = "sum" if reduction else "none"
|
||||||
|
self.criterion = CTCLoss(blank=self.blank_id, reduction=reduction_type)
|
||||||
|
|
||||||
|
# CTCDecoder LM Score handle
|
||||||
|
self._ext_scorer = None
|
||||||
|
|
||||||
|
def forward(self, hs_pad, hlens, ys_pad, ys_lens):
|
||||||
|
"""Calculate CTC loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs_pad (Tensor): batch of padded hidden state sequences (B, Tmax, D)
|
||||||
|
hlens (Tensor): batch of lengths of hidden state sequences (B)
|
||||||
|
ys_pad (Tenosr): batch of padded character id sequence tensor (B, Lmax)
|
||||||
|
ys_lens (Tensor): batch of lengths of character sequence (B)
|
||||||
|
Returns:
|
||||||
|
loss (Tenosr): scalar.
|
||||||
|
"""
|
||||||
|
logits = self.ctc_lo(F.dropout(hs_pad, p=self.dropout_rate))
|
||||||
|
loss = self.criterion(logits, ys_pad, hlens, ys_lens)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def probs(self, eouts: paddle.Tensor, temperature: float=1.0):
|
||||||
|
"""Get CTC probabilities.
|
||||||
|
Args:
|
||||||
|
eouts (FloatTensor): `[B, T, enc_units]`
|
||||||
|
Returns:
|
||||||
|
probs (FloatTensor): `[B, T, odim]`
|
||||||
|
"""
|
||||||
|
return F.softmax(self.ctc_lo(eouts) / temperature, axis=-1)
|
||||||
|
|
||||||
|
def scores(self, eouts: paddle.Tensor, temperature: float=1.0):
|
||||||
|
"""Get log-scale CTC probabilities.
|
||||||
|
Args:
|
||||||
|
eouts (FloatTensor): `[B, T, enc_units]`
|
||||||
|
Returns:
|
||||||
|
log_probs (FloatTensor): `[B, T, odim]`
|
||||||
|
"""
|
||||||
|
return F.log_softmax(self.ctc_lo(eouts) / temperature, axis=-1)
|
||||||
|
|
||||||
|
def log_softmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""log_softmax of frame activations
|
||||||
|
Args:
|
||||||
|
Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: log softmax applied 3d tensor (B, Tmax, odim)
|
||||||
|
"""
|
||||||
|
return self.scores(hs_pad)
|
||||||
|
|
||||||
|
def argmax(self, hs_pad: paddle.Tensor) -> paddle.Tensor:
|
||||||
|
"""argmax of frame activations
|
||||||
|
Args:
|
||||||
|
paddle.Tensor hs_pad: 3d tensor (B, Tmax, eprojs)
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: argmax applied 2d tensor (B, Tmax)
|
||||||
|
"""
|
||||||
|
return paddle.argmax(self.ctc_lo(hs_pad), dim=2)
|
||||||
|
|
||||||
|
def _decode_batch_greedy(self, probs_split, vocab_list):
|
||||||
|
"""Decode by best path for a batch of probs matrix input.
|
||||||
|
:param probs_split: List of 2-D probability matrix, and each consists
|
||||||
|
of prob vectors for one speech utterancce.
|
||||||
|
:param probs_split: List of matrix
|
||||||
|
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||||
|
:type vocab_list: list
|
||||||
|
:return: List of transcription texts.
|
||||||
|
:rtype: List of str
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for i, probs in enumerate(probs_split):
|
||||||
|
output_transcription = ctc_greedy_decoder(
|
||||||
|
probs_seq=probs, vocabulary=vocab_list)
|
||||||
|
results.append(output_transcription)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _init_ext_scorer(self, beam_alpha, beam_beta, language_model_path,
|
||||||
|
vocab_list):
|
||||||
|
"""Initialize the external scorer.
|
||||||
|
:param beam_alpha: Parameter associated with language model.
|
||||||
|
:type beam_alpha: float
|
||||||
|
:param beam_beta: Parameter associated with word count.
|
||||||
|
:type beam_beta: float
|
||||||
|
:param language_model_path: Filepath for language model. If it is
|
||||||
|
empty, the external scorer will be set to
|
||||||
|
None, and the decoding method will be pure
|
||||||
|
beam search without scorer.
|
||||||
|
:type language_model_path: str|None
|
||||||
|
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||||
|
:type vocab_list: list
|
||||||
|
"""
|
||||||
|
# init once
|
||||||
|
if self._ext_scorer != None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if language_model_path != '':
|
||||||
|
logger.info("begin to initialize the external scorer "
|
||||||
|
"for decoding")
|
||||||
|
self._ext_scorer = Scorer(beam_alpha, beam_beta,
|
||||||
|
language_model_path, vocab_list)
|
||||||
|
lm_char_based = self._ext_scorer.is_character_based()
|
||||||
|
lm_max_order = self._ext_scorer.get_max_order()
|
||||||
|
lm_dict_size = self._ext_scorer.get_dict_size()
|
||||||
|
logger.info("language model: "
|
||||||
|
"is_character_based = %d," % lm_char_based +
|
||||||
|
" max_order = %d," % lm_max_order + " dict_size = %d" %
|
||||||
|
lm_dict_size)
|
||||||
|
logger.info("end initializing scorer")
|
||||||
|
else:
|
||||||
|
self._ext_scorer = None
|
||||||
|
logger.info("no language model provided, "
|
||||||
|
"decoding by pure beam search without scorer.")
|
||||||
|
|
||||||
|
def _decode_batch_beam_search(self, probs_split, beam_alpha, beam_beta,
|
||||||
|
beam_size, cutoff_prob, cutoff_top_n,
|
||||||
|
vocab_list, num_processes):
|
||||||
|
"""Decode by beam search for a batch of probs matrix input.
|
||||||
|
:param probs_split: List of 2-D probability matrix, and each consists
|
||||||
|
of prob vectors for one speech utterancce.
|
||||||
|
:param probs_split: List of matrix
|
||||||
|
:param beam_alpha: Parameter associated with language model.
|
||||||
|
:type beam_alpha: float
|
||||||
|
:param beam_beta: Parameter associated with word count.
|
||||||
|
:type beam_beta: float
|
||||||
|
:param beam_size: Width for Beam search.
|
||||||
|
:type beam_size: int
|
||||||
|
:param cutoff_prob: Cutoff probability in pruning,
|
||||||
|
default 1.0, no pruning.
|
||||||
|
:type cutoff_prob: float
|
||||||
|
:param cutoff_top_n: Cutoff number in pruning, only top cutoff_top_n
|
||||||
|
characters with highest probs in vocabulary will be
|
||||||
|
used in beam search, default 40.
|
||||||
|
:type cutoff_top_n: int
|
||||||
|
:param vocab_list: List of tokens in the vocabulary, for decoding.
|
||||||
|
:type vocab_list: list
|
||||||
|
:param num_processes: Number of processes (CPU) for decoder.
|
||||||
|
:type num_processes: int
|
||||||
|
:return: List of transcription texts.
|
||||||
|
:rtype: List of str
|
||||||
|
"""
|
||||||
|
if self._ext_scorer != None:
|
||||||
|
self._ext_scorer.reset_params(beam_alpha, beam_beta)
|
||||||
|
|
||||||
|
# beam search decode
|
||||||
|
num_processes = min(num_processes, len(probs_split))
|
||||||
|
beam_search_results = ctc_beam_search_decoder_batch(
|
||||||
|
probs_split=probs_split,
|
||||||
|
vocabulary=vocab_list,
|
||||||
|
beam_size=beam_size,
|
||||||
|
num_processes=num_processes,
|
||||||
|
ext_scoring_func=self._ext_scorer,
|
||||||
|
cutoff_prob=cutoff_prob,
|
||||||
|
cutoff_top_n=cutoff_top_n)
|
||||||
|
|
||||||
|
results = [result[0][1] for result in beam_search_results]
|
||||||
|
return results
|
||||||
|
|
||||||
|
def init_decode(self, beam_alpha, beam_beta, lang_model_path, vocab_list,
|
||||||
|
decoding_method):
|
||||||
|
if decoding_method == "ctc_beam_search":
|
||||||
|
self._init_ext_scorer(beam_alpha, beam_beta, lang_model_path,
|
||||||
|
vocab_list)
|
||||||
|
|
||||||
|
def decode_probs(self, probs, logits_lens, vocab_list, decoding_method,
|
||||||
|
lang_model_path, beam_alpha, beam_beta, beam_size,
|
||||||
|
cutoff_prob, cutoff_top_n, num_processes):
|
||||||
|
""" probs: activation after softmax
|
||||||
|
logits_len: audio output lens
|
||||||
|
"""
|
||||||
|
probs_split = [probs[i, :l, :] for i, l in enumerate(logits_lens)]
|
||||||
|
if decoding_method == "ctc_greedy":
|
||||||
|
result_transcripts = self._decode_batch_greedy(
|
||||||
|
probs_split=probs_split, vocab_list=vocab_list)
|
||||||
|
elif decoding_method == "ctc_beam_search":
|
||||||
|
result_transcripts = self._decode_batch_beam_search(
|
||||||
|
probs_split=probs_split,
|
||||||
|
beam_alpha=beam_alpha,
|
||||||
|
beam_beta=beam_beta,
|
||||||
|
beam_size=beam_size,
|
||||||
|
cutoff_prob=cutoff_prob,
|
||||||
|
cutoff_top_n=cutoff_top_n,
|
||||||
|
vocab_list=vocab_list,
|
||||||
|
num_processes=num_processes)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Not support: {decoding_method}")
|
||||||
|
return result_transcripts
|
@ -0,0 +1,132 @@
|
|||||||
|
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Positonal Encoding Module."""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from paddle.nn import initializer as I
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ["PositionalEncoding", "RelPositionalEncoding"]
|
||||||
|
|
||||||
|
# TODO(Hui Zhang): remove this hack
|
||||||
|
paddle.float32 = 'float32'
|
||||||
|
|
||||||
|
|
||||||
|
class PositionalEncoding(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
d_model: int,
|
||||||
|
dropout_rate: float,
|
||||||
|
max_len: int=5000,
|
||||||
|
reverse: bool=False):
|
||||||
|
"""Positional encoding.
|
||||||
|
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
||||||
|
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
||||||
|
Args:
|
||||||
|
d_model (int): embedding dim.
|
||||||
|
dropout_rate (float): dropout rate.
|
||||||
|
max_len (int, optional): maximum input length. Defaults to 5000.
|
||||||
|
reverse (bool, optional): Not used. Defaults to False.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.d_model = d_model
|
||||||
|
self.max_len = max_len
|
||||||
|
self.xscale = paddle.to_tensor(math.sqrt(self.d_model))
|
||||||
|
self.dropout = nn.Dropout(p=dropout_rate)
|
||||||
|
self.pe = paddle.zeros(self.max_len, self.d_model) #[T,D]
|
||||||
|
|
||||||
|
position = paddle.arange(
|
||||||
|
0, self.max_len, dtype=paddle.float32).unsqueeze(1)
|
||||||
|
div_term = paddle.exp(
|
||||||
|
paddle.arange(0, self.d_model, 2, dtype=paddle.float32) *
|
||||||
|
-(math.log(10000.0) / self.d_model))
|
||||||
|
|
||||||
|
self.pe[:, 0::2] = paddle.sin(position * div_term)
|
||||||
|
self.pe[:, 1::2] = paddle.cos(position * div_term)
|
||||||
|
self.pe = self.pe.unsqueeze(0) #[1, T, D]
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor,
|
||||||
|
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Add positional encoding.
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): Input. Its shape is (batch, time, ...)
|
||||||
|
offset (int): position offset
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||||
|
paddle.Tensor: for compatibility to RelPositionalEncoding
|
||||||
|
"""
|
||||||
|
T = paddle.shape(x)[1]
|
||||||
|
assert offset + T < self.max_len
|
||||||
|
#assert offset + x.size(1) < self.max_len
|
||||||
|
#self.pe = self.pe.to(x.device)
|
||||||
|
#pos_emb = self.pe[:, offset:offset + x.size(1)]
|
||||||
|
pos_emb = self.pe[:, offset:offset + T]
|
||||||
|
x = x * self.xscale + pos_emb
|
||||||
|
return self.dropout(x), self.dropout(pos_emb)
|
||||||
|
|
||||||
|
def position_encoding(self, offset: int, size: int) -> paddle.Tensor:
|
||||||
|
""" For getting encoding in a streaming fashion
|
||||||
|
Attention!!!!!
|
||||||
|
we apply dropout only once at the whole utterance level in a none
|
||||||
|
streaming way, but will call this function several times with
|
||||||
|
increasing input size in a streaming scenario, so the dropout will
|
||||||
|
be applied several times.
|
||||||
|
Args:
|
||||||
|
offset (int): start offset
|
||||||
|
size (int): requried size of position encoding
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: Corresponding encoding
|
||||||
|
"""
|
||||||
|
assert offset + size < self.max_len
|
||||||
|
return self.dropout(self.pe[:, offset:offset + size])
|
||||||
|
|
||||||
|
|
||||||
|
class RelPositionalEncoding(PositionalEncoding):
|
||||||
|
"""Relative positional encoding module.
|
||||||
|
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
d_model (int): Embedding dimension.
|
||||||
|
dropout_rate (float): Dropout rate.
|
||||||
|
max_len (int, optional): [Maximum input length.]. Defaults to 5000.
|
||||||
|
"""
|
||||||
|
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
||||||
|
|
||||||
|
def forward(self, x: paddle.Tensor,
|
||||||
|
offset: int=0) -> Tuple[paddle.Tensor, paddle.Tensor]:
|
||||||
|
"""Compute positional encoding.
|
||||||
|
Args:
|
||||||
|
x (paddle.Tensor): Input tensor (batch, time, `*`).
|
||||||
|
Returns:
|
||||||
|
paddle.Tensor: Encoded tensor (batch, time, `*`).
|
||||||
|
paddle.Tensor: Positional embedding tensor (1, time, `*`).
|
||||||
|
"""
|
||||||
|
T = paddle.shape()[1]
|
||||||
|
assert offset + T < self.max_len
|
||||||
|
#assert offset + x.size(1) < self.max_len
|
||||||
|
#self.pe = self.pe.to(x.device)
|
||||||
|
x = x * self.xscale
|
||||||
|
#pos_emb = self.pe[:, offset:offset + x.size(1)]
|
||||||
|
pos_emb = self.pe[:, offset:offset + T]
|
||||||
|
return self.dropout(x), self.dropout(pos_emb)
|
@ -1,31 +0,0 @@
|
|||||||
#! /usr/bin/env bash
|
|
||||||
|
|
||||||
# download language model
|
|
||||||
bash local/download_lm_ch.sh
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# download well-trained model
|
|
||||||
bash local/download_model.sh
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# infer
|
|
||||||
CUDA_VISIBLE_DEVICES=0 \
|
|
||||||
python3 -u ${BIN_DIR}/infer.py \
|
|
||||||
--device 'gpu' \
|
|
||||||
--nproc 1 \
|
|
||||||
--config conf/deepspeech2.yaml \
|
|
||||||
--checkpoint_path data/pretrain/params.pdparams \
|
|
||||||
--opts data.mean_std_filepath data/pretrain/mean_std.npz \
|
|
||||||
--opts data.vocab_filepath data/pretrain/vocab.txt
|
|
||||||
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
echo "Failed in inference!"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
exit 0
|
|
@ -1,31 +0,0 @@
|
|||||||
#! /usr/bin/env bash
|
|
||||||
|
|
||||||
# download language model
|
|
||||||
bash local/download_lm_ch.sh
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# download well-trained model
|
|
||||||
bash local/download_model.sh
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# evaluate model
|
|
||||||
CUDA_VISIBLE_DEVICES=0 \
|
|
||||||
python3 -u ${BIN_DIR}/test.py \
|
|
||||||
--device 'gpu' \
|
|
||||||
--nproc 1 \
|
|
||||||
--config conf/deepspeech2.yaml \
|
|
||||||
--checkpoint_path data/pretrain/params.pdparams \
|
|
||||||
--opts data.mean_std_filepath data/pretrain/mean_std.npz \
|
|
||||||
--opts data.vocab_filepath data/pretrain/vocab.txt
|
|
||||||
|
|
||||||
if [ $? -ne 0 ]; then
|
|
||||||
echo "Failed in evaluation!"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
exit 0
|
|
@ -1,7 +1,7 @@
|
|||||||
# LibriSpeech
|
# LibriSpeech
|
||||||
|
|
||||||
## CTC
|
## CTC
|
||||||
| Model | Config | Test set | CER |
|
| Model | Config | Test set | WER |
|
||||||
| --- | --- | --- | --- |
|
| --- | --- | --- | --- |
|
||||||
| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 |
|
| DeepSpeech2 | conf/deepspeech2.yaml | test-clean | 0.073973 |
|
||||||
| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 |
|
| DeepSpeech2 | release 1.8.5 | test-clean | 0.074939 |
|
||||||
|
@ -0,0 +1,20 @@
|
|||||||
|
#! /usr/bin/env bash
|
||||||
|
|
||||||
|
if [ $# != 2 ];then
|
||||||
|
echo "usage: export ckpt_path jit_model_path"
|
||||||
|
exit -1
|
||||||
|
fi
|
||||||
|
|
||||||
|
python3 -u ${BIN_DIR}/export.py \
|
||||||
|
--config conf/deepspeech2.yaml \
|
||||||
|
--checkpoint_path ${1} \
|
||||||
|
--export_path ${2}
|
||||||
|
|
||||||
|
|
||||||
|
if [ $? -ne 0 ]; then
|
||||||
|
echo "Failed in evaluation!"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
|
||||||
|
exit 0
|
Loading…
Reference in new issue