parent 6ec6921255
commit 92d1d08b9a
@@ -0,0 +1,235 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains the text featurizer class."""
from pprint import pformat
from typing import Union

import sentencepiece as spm

from .utility import BLANK
from .utility import EOS
from .utility import load_dict
from .utility import MASKCTC
from .utility import SOS
from .utility import SPACE
from .utility import UNK
from ..utils.log import Logger

logger = Logger(__name__)

__all__ = ["TextFeaturizer"]

class TextFeaturizer():
    def __init__(self, unit_type, vocab, spm_model_prefix=None, maskctc=False):
        """Text featurizer, for processing or extracting features from text.

        Currently, it supports char/word/sentence-piece level tokenizing and
        conversion into a list of token indices. Note that the token indexing
        order follows the given vocabulary file.

        Args:
            unit_type (str): unit type, e.g. char, word, spm
            vocab (Union[str, list]): Filepath of the vocabulary used for
                token-index conversion, or the vocab list itself.
            spm_model_prefix (str, optional): spm model prefix. Defaults to None.
            maskctc (bool, optional): whether to reserve the <mask> token for
                the non-autoregressive MaskCTC model. Defaults to False.
        """
        assert unit_type in ('char', 'spm', 'word')
        self.unit_type = unit_type
        self.unk = UNK
        self.maskctc = maskctc

        if vocab:
            self.vocab_dict, self._id2token, self.vocab_list, self.unk_id, self.eos_id, self.blank_id = self._load_vocabulary_from_file(
                vocab, maskctc)
            self.vocab_size = len(self.vocab_list)
        else:
            logger.warning(
                "TextFeaturizer: no vocab file or vocab list provided.")

        if unit_type == 'spm':
            spm_model = spm_model_prefix + '.model'
            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(spm_model)

    def tokenize(self, text, replace_space=True):
        if self.unit_type == 'char':
            tokens = self.char_tokenize(text, replace_space)
        elif self.unit_type == 'word':
            tokens = self.word_tokenize(text)
        else:  # spm
            tokens = self.spm_tokenize(text)
        return tokens

    def detokenize(self, tokens):
        if self.unit_type == 'char':
            text = self.char_detokenize(tokens)
        elif self.unit_type == 'word':
            text = self.word_detokenize(tokens)
        else:  # spm
            text = self.spm_detokenize(tokens)
        return text

    def featurize(self, text):
        """Convert text string to a list of token indices.

        Args:
            text (str): Text to process.

        Returns:
            List[int]: List of token indices.
        """
        tokens = self.tokenize(text)
        ids = []
        for token in tokens:
            if token not in self.vocab_dict:
                logger.debug(f"Text Token: {token} -> {self.unk}")
                token = self.unk
            ids.append(self.vocab_dict[token])
        return ids

    def defeaturize(self, idxs):
        """Convert a list of token indices to a text string,
        ignoring indices after eos_id.

        Args:
            idxs (List[int]): List of token indices.

        Returns:
            str: Text.
        """
        tokens = []
        for idx in idxs:
            if idx == self.eos_id:
                break
            tokens.append(self._id2token[idx])
        text = self.detokenize(tokens)
        return text

    def char_tokenize(self, text, replace_space=True):
        """Character tokenizer.

        Args:
            text (str): text string.
            replace_space (bool): whether to replace the space character with
                the <space> token. False is only used by build_vocab.py.

        Returns:
            List[str]: tokens.
        """
        text = text.strip()
        if replace_space:
            text_list = [SPACE if item == " " else item for item in list(text)]
        else:
            text_list = list(text)
        return text_list

    def char_detokenize(self, tokens):
        """Character detokenizer.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text string.
        """
        tokens = [t.replace(SPACE, " ") for t in tokens]
        return "".join(tokens)

    def word_tokenize(self, text):
        """Word tokenizer; words are separated by <space>."""
        return text.strip().split()

    def word_detokenize(self, tokens):
        """Word detokenizer; words are joined by <space>."""
        return " ".join(tokens)

    def spm_tokenize(self, text):
        """spm tokenize.

        Args:
            text (str): text string.

        Returns:
            List[str]: sentence piece tokens.
        """
        stats = {"num_empty": 0, "num_filtered": 0}

        def valid(line):
            return True

        def encode(l):
            return self.sp.EncodeAsPieces(l)

        def encode_line(line):
            line = line.strip()
            if len(line) > 0:
                line = encode(line)
                if valid(line):
                    return line
                else:
                    stats["num_filtered"] += 1
            else:
                stats["num_empty"] += 1
            return None

        enc_line = encode_line(text)
        return enc_line

    def spm_detokenize(self, tokens, input_format='piece'):
        """spm detokenize.

        Args:
            tokens (List[str]): tokens.

        Returns:
            str: text
        """
        if input_format == "piece":

            def decode(l):
                return "".join(self.sp.DecodePieces(l))
        elif input_format == "id":

            def decode(l):
                return "".join(self.sp.DecodeIds(l))

        return decode(tokens)
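    # Illustrative sketch (not part of the original commit), assuming a
    # trained SentencePiece model was loaded in __init__:
    # spm_tokenize("hello world") returns pieces such as
    # ['▁hello', '▁world'] (the exact pieces depend on the model), and
    # spm_detokenize(pieces) reassembles "hello world".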

    def _load_vocabulary_from_file(self, vocab: Union[str, list],
                                   maskctc: bool):
        """Load vocabulary from file or list."""
        if isinstance(vocab, list):
            vocab_list = vocab
        else:
            vocab_list = load_dict(vocab, maskctc)
        assert vocab_list is not None
        logger.debug(f"Vocab: {pformat(vocab_list)}")

        id2token = {idx: token for (idx, token) in enumerate(vocab_list)}
        token2id = {token: idx for (idx, token) in enumerate(vocab_list)}

        blank_id = vocab_list.index(BLANK) if BLANK in vocab_list else -1
        maskctc_id = vocab_list.index(MASKCTC) if MASKCTC in vocab_list else -1
        unk_id = vocab_list.index(UNK) if UNK in vocab_list else -1
        eos_id = vocab_list.index(EOS) if EOS in vocab_list else -1
        sos_id = vocab_list.index(SOS) if SOS in vocab_list else -1
        space_id = vocab_list.index(SPACE) if SPACE in vocab_list else -1

        logger.info(f"BLANK id: {blank_id}")
        logger.info(f"UNK id: {unk_id}")
        logger.info(f"EOS id: {eos_id}")
        logger.info(f"SOS id: {sos_id}")
        logger.info(f"SPACE id: {space_id}")
        logger.info(f"MASKCTC id: {maskctc_id}")
        return token2id, id2token, vocab_list, unk_id, eos_id, blank_id
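
# Usage sketch (illustrative, not part of the original commit). Assuming a
# vocab file "vocab.txt" produced by build_vocab.py, with <blank> first and
# <eos> last:
#
#     featurizer = TextFeaturizer(unit_type='char', vocab='vocab.txt')
#     ids = featurizer.featurize("hello world")   # List[int]
#     text = featurizer.defeaturize(ids)          # "hello world"
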
@@ -0,0 +1,393 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains data helper functions."""
import json
import math
import sys  # used by _load_kaldi_cmvn below
import tarfile
from collections import namedtuple
from typing import List
from typing import Optional
from typing import Text

import jsonlines
import numpy as np

from paddlespeech.s2t.utils.log import Log

logger = Log(__name__).getlog()

__all__ = [
    "load_dict", "load_cmvn", "read_manifest", "rms_to_db", "rms_to_dbfs",
    "max_dbfs", "mean_dbfs", "gain_db_to_ratio", "normalize_audio", "SOS",
    "EOS", "UNK", "BLANK", "MASKCTC", "SPACE", "convert_samples_to_float32",
    "convert_samples_from_float32"
]

IGNORE_ID = -1
# `sos` and `eos` use the same token
SOS = "<eos>"
EOS = SOS
UNK = "<unk>"
BLANK = "<blank>"
MASKCTC = "<mask>"
SPACE = "<space>"


def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
    if dict_path is None:
        return None

    with open(dict_path, "r") as f:
        dictionary = f.readlines()
    # The first token is `<blank>`.
    # Two-column line: `<blank> 0\n`
    # One-column line: `<blank>`
    # The space character is replaced with <space>.
    char_list = [entry[:-1].split(" ")[0] for entry in dictionary]
    if BLANK not in char_list:
        char_list.insert(0, BLANK)
    if EOS not in char_list:
        char_list.append(EOS)
    # for the non-autoregressive MaskCTC model
    if maskctc and MASKCTC not in char_list:
        char_list.append(MASKCTC)
    return char_list
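# Illustrative dict file (a sketch, not part of the original commit) accepted
# by load_dict: one token per line, optionally followed by its index:
#
#     <blank> 0
#     <unk> 1
#     a 2
#     <space> 3
#     <eos> 4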


def read_manifest(
        manifest_path,
        max_input_len=float('inf'),
        min_input_len=0.0,
        max_output_len=float('inf'),
        min_output_len=0.0,
        max_output_input_ratio=float('inf'),
        min_output_input_ratio=0.0, ):
    """Load and parse manifest file.

    Args:
        manifest_path (str): Manifest file to load and parse.
        max_input_len (float, optional): maximum input seq length,
            in seconds for raw wav, in frame numbers for feature data.
            Defaults to float('inf').
        min_input_len (float, optional): minimum input seq length,
            in seconds for raw wav, in frame numbers for feature data.
            Defaults to 0.0.
        max_output_len (float, optional): maximum output seq length,
            in modeling units. Defaults to float('inf').
        min_output_len (float, optional): minimum output seq length,
            in modeling units. Defaults to 0.0.
        max_output_input_ratio (float, optional): maximum ratio of
            output seq length to input seq length. Defaults to float('inf').
        min_output_input_ratio (float, optional): minimum ratio of
            output seq length to input seq length. Defaults to 0.0.

    Raises:
        IOError: If failed to parse the manifest.

    Returns:
        List[dict]: Manifest parsing results.
    """
    manifest = []
    with jsonlines.open(manifest_path, 'r') as reader:
        for json_data in reader:
            feat_len = json_data["input"][0]["shape"][
                0] if "input" in json_data and "shape" in json_data["input"][
                    0] else 1.0
            token_len = json_data["output"][0]["shape"][
                0] if "output" in json_data and "shape" in json_data["output"][
                    0] else 1.0
            conditions = [
                feat_len >= min_input_len,
                feat_len <= max_input_len,
                token_len >= min_output_len,
                token_len <= max_output_len,
                token_len / feat_len >= min_output_input_ratio,
                token_len / feat_len <= max_output_input_ratio,
            ]
            if all(conditions):
                manifest.append(json_data)
    return manifest
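# Illustrative manifest line (a sketch, not part of the original commit):
# one JSON object per line, where the input/output "shape" fields drive the
# length filtering above:
#
#     {"input": [{"shape": [1120, 80]}], "output": [{"shape": [25]}], ...}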


# Tar File read
TarLocalData = namedtuple('TarLocalData', ['tar2info', 'tar2object'])


def parse_tar(file):
    """Parse a tar file into a tarfile object
    and a map from member names to tarinfos.
    """
    result = {}
    f = tarfile.open(file)
    for tarinfo in f.getmembers():
        result[tarinfo.name] = tarinfo
    return f, result


def subfile_from_tar(file, local_data=None):
    """Get a subfile object from a tar archive.

    The file spec is `tar:tarpath#filename`.

    Returns a file object for the requested member and caches the
    parsed tar info for the next read request.
    """
    tarpath, filename = file.split(':', 1)[1].split('#', 1)

    if local_data is None:
        local_data = TarLocalData(tar2info={}, tar2object={})

    assert isinstance(local_data, TarLocalData)

    if 'tar2info' not in local_data.__dict__:
        local_data.tar2info = {}
    if 'tar2object' not in local_data.__dict__:
        local_data.tar2object = {}

    if tarpath not in local_data.tar2info:
        fobj, infos = parse_tar(tarpath)
        local_data.tar2info[tarpath] = infos
        local_data.tar2object[tarpath] = fobj
    else:
        fobj = local_data.tar2object[tarpath]
        infos = local_data.tar2info[tarpath]
    return fobj.extractfile(infos[filename])
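# Usage sketch (illustrative, not part of the original commit; the archive
# path and member name are hypothetical), following the `tar:tarpath#filename`
# spec documented above:
#
#     fobj = subfile_from_tar("tar:/data/feats.tar#utt001.wav")
#     data = fobj.read()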


def rms_to_db(rms: float):
    """Root Mean Square to dB.

    Args:
        rms (float): root mean square

    Returns:
        float: dB
    """
    return 20.0 * math.log10(max(1e-16, rms))


def rms_to_dbfs(rms: float):
    """Root Mean Square to dBFS.
    https://fireattack.wordpress.com/2017/02/06/replaygain-loudness-normalization-and-applications/
    A full-scale (amplitude 1.0) sine wave has an RMS of 0.7071,
    which equals -3.0103 dB, so:

    dB = dBFS + 3.0103
    dBFS = dB - 3.0103
    e.g. 0 dB = -3.0103 dBFS

    Args:
        rms (float): root mean square

    Returns:
        float: dBFS
    """
    return rms_to_db(rms) - 3.0103
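# Worked example (illustrative, not part of the original commit): an rms of
# 1.0 gives rms_to_db(1.0) = 20*log10(1.0) = 0 dB, hence rms_to_dbfs(1.0)
# = 0 - 3.0103 = -3.0103 dBFS, matching the "0 dB = -3.0103 dBFS" note above.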


def max_dbfs(sample_data: np.ndarray):
    """Peak dBFS based on the maximum energy sample.

    Args:
        sample_data (np.ndarray): float array, [-1, 1].

    Returns:
        float: dBFS
    """
    # Using the peak will prevent overdrive if used for normalization.
    return rms_to_dbfs(max(abs(np.min(sample_data)), abs(np.max(sample_data))))


def mean_dbfs(sample_data):
    """dBFS based on the RMS energy.

    Args:
        sample_data (np.ndarray): float array, [-1, 1].

    Returns:
        float: dBFS
    """
    return rms_to_dbfs(
        math.sqrt(np.mean(np.square(sample_data, dtype=np.float64))))


def gain_db_to_ratio(gain_db: float):
    """dB gain to amplitude ratio.

    Args:
        gain_db (float): gain in dB

    Returns:
        float: scale in amp
    """
    return math.pow(10.0, gain_db / 20.0)


def normalize_audio(sample_data: np.ndarray, dbfs: float=-3.0103):
    """Normalize audio to the target dBFS.

    Args:
        sample_data (np.ndarray): input wave samples, [-1, 1].
        dbfs (float, optional): target dBFS. Defaults to -3.0103.

    Returns:
        np.ndarray: normalized wave
    """
    return np.maximum(
        np.minimum(sample_data * gain_db_to_ratio(dbfs - max_dbfs(sample_data)),
                   1.0), -1.0)
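# Usage sketch (illustrative, not part of the original commit): for a clip
# whose peak sits at -10 dBFS, the gain applied toward the default -3.0103
# dBFS target is gain_db_to_ratio(-3.0103 - (-10)) ≈ 2.24, and the result is
# clipped to [-1, 1]:
#
#     samples = np.array([0.1, -0.2, 0.05], dtype=np.float32)
#     normalized = normalize_audio(samples)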


def _load_json_cmvn(json_cmvn_file):
    """Load the json format cmvn stats file and calculate cmvn.

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, istds]
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn
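# Worked example (illustrative, not part of the original commit): with
# frame_num = 4, mean_stat = [8.0] and var_stat = [20.0] for one dimension,
# mean = 8/4 = 2.0, var = 20/4 - 2.0^2 = 1.0, istd = 1/sqrt(1.0) = 1.0,
# so the returned array is [[2.0], [1.0]].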


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """Load the kaldi format cmvn stats file and calculate cmvn.

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, istds]
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary files start with '\0B'
        if fid.read(2) == '\0B':
            logger.error('kaldi cmvn binary file is not supported, please '
                         'recompute it by: compute-cmvn-stats --binary=false '
                         ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file: str, filetype: str):
    """Load cmvn from file.

    Args:
        cmvn_file (str): cmvn path.
        filetype (str): file type, one of [npz, json, kaldi].

    Raises:
        ValueError: if the file type is not supported.

    Returns:
        Tuple[np.ndarray, np.ndarray]: mean, istd
    """
    filetype = filetype.lower()
    assert filetype in ['npz', 'json', 'kaldi'], filetype
    if filetype == "json":
        cmvn = _load_json_cmvn(cmvn_file)
    elif filetype == "kaldi":
        cmvn = _load_kaldi_cmvn(cmvn_file)
    elif filetype == "npz":
        eps = 1e-14
        npzfile = np.load(cmvn_file)
        mean = np.squeeze(npzfile["mean"])
        std = np.squeeze(npzfile["std"])
        istd = 1 / (std + eps)
        cmvn = [mean, istd]
    else:
        raise ValueError(f"cmvn file type not supported: {filetype}")
    return cmvn[0], cmvn[1]
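# Usage sketch (illustrative, not part of the original commit; the stats file
# path is hypothetical):
#
#     mean, istd = load_cmvn("data/mean_std.json", filetype="json")
#     feat = (feat - mean) * istd   # apply cmvn to a [T, D] feature matrix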


def convert_samples_to_float32(samples):
    """Convert sample type to float32.

    Audio sample type is usually integer or floating-point.
    Integers will be scaled to [-1, 1] in float32.

    PCM16 -> PCM32
    """
    float32_samples = samples.astype('float32')
    if samples.dtype in np.sctypes['int']:
        bits = np.iinfo(samples.dtype).bits
        float32_samples *= (1. / 2**(bits - 1))
    elif samples.dtype in np.sctypes['float']:
        pass
    else:
        raise TypeError("Unsupported sample type: %s." % samples.dtype)
    return float32_samples


def convert_samples_from_float32(samples, dtype):
    """Convert sample type from float32 to dtype.

    Audio sample type is usually integer or floating-point. For integer
    type, float32 will be rescaled from [-1, 1] to the maximum range
    supported by the integer type.

    PCM32 -> PCM16
    """
    dtype = np.dtype(dtype)
    output_samples = samples.copy()
    if dtype in np.sctypes['int']:
        bits = np.iinfo(dtype).bits
        output_samples *= (2**(bits - 1) / 1.)
        min_val = np.iinfo(dtype).min
        max_val = np.iinfo(dtype).max
        output_samples[output_samples > max_val] = max_val
        output_samples[output_samples < min_val] = min_val
    elif dtype in np.sctypes['float']:
        min_val = np.finfo(dtype).min
        max_val = np.finfo(dtype).max
        output_samples[output_samples > max_val] = max_val
        output_samples[output_samples < min_val] = min_val
    else:
        raise TypeError("Unsupported sample type: %s." % dtype)
    return output_samples.astype(dtype)
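# Round-trip sketch (illustrative, not part of the original commit):
#
#     pcm16 = np.array([0, 16384, -32768], dtype='int16')
#     f32 = convert_samples_to_float32(pcm16)             # [0.0, 0.5, -1.0]
#     back = convert_samples_from_float32(f32, 'int16')   # [0, 16384, -32768]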