pull/3900/head
drryanhuang 9 months ago
parent 35ee7da5fd
commit af73cc42b8

@ -764,6 +764,19 @@ class AudioSignal(
self.sample_rate = sample_rate
return self
@staticmethod
def move_to_device(data, device):
if device is None or device == "":
return data
elif device == 'cpu':
return paddle.to_tensor(data, place=paddle.CPUPlace())
elif device in ('gpu', 'cuda'):
return paddle.to_tensor(data, place=paddle.CUDAPlace())
else:
device = device.replace("cuda",
"gpu") if "cuda" in device else device
return data.to(device)
# Tensor operations
def to(self, device: str):
"""✅Moves all tensors contained in signal to the specified device.
@ -780,22 +793,11 @@ class AudioSignal(
AudioSignal with all tensors moved to specified device.
"""
if self._loudness is not None:
self._loudness = self._loudness.to(device)
self._loudness = self.move_to_device(self._loudness, device)
if self.stft_data is not None:
self.stft_data = self.stft_data.to(device)
self.stft_data = self.move_to_device(self.stft_data, device)
if self.audio_data is not None:
if device is None or "" == device:
return self
elif 'cpu' == device:
self.audio_data = paddle.to_tensor(
self.audio_data, place=paddle.CPUPlace())
elif 'gpu' == device or 'cuda' == device:
self.audio_data = paddle.to_tensor(
self.audio_data, place=paddle.CUDAPlace())
else:
device = device.replace("cuda",
"gpu") if "cuda" in device else device
self.audio_data = self.audio_data.to(device)
self.audio_data = self.move_to_device(self.audio_data, device)
return self
def float(self):

Loading…
Cancel
Save