diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py index 80a4130d6..acb0bcafe 100644 --- a/audio/audiotools/core/audio_signal.py +++ b/audio/audiotools/core/audio_signal.py @@ -764,6 +764,19 @@ class AudioSignal( self.sample_rate = sample_rate return self + @staticmethod + def move_to_device(data, device): + if device is None or device == "": + return data + elif device == 'cpu': + return paddle.to_tensor(data, place=paddle.CPUPlace()) + elif device in ('gpu', 'cuda'): + return paddle.to_tensor(data, place=paddle.CUDAPlace()) + else: + device = device.replace("cuda", + "gpu") if "cuda" in device else device + return data.to(device) + # Tensor operations def to(self, device: str): """✅Moves all tensors contained in signal to the specified device. @@ -780,22 +793,11 @@ class AudioSignal( AudioSignal with all tensors moved to specified device. """ if self._loudness is not None: - self._loudness = self._loudness.to(device) + self._loudness = self.move_to_device(self._loudness, device) if self.stft_data is not None: - self.stft_data = self.stft_data.to(device) + self.stft_data = self.move_to_device(self.stft_data, device) if self.audio_data is not None: - if device is None or "" == device: - return self - elif 'cpu' == device: - self.audio_data = paddle.to_tensor( - self.audio_data, place=paddle.CPUPlace()) - elif 'gpu' == device or 'cuda' == device: - self.audio_data = paddle.to_tensor( - self.audio_data, place=paddle.CUDAPlace()) - else: - device = device.replace("cuda", - "gpu") if "cuda" in device else device - self.audio_data = self.audio_data.to(device) + self.audio_data = self.move_to_device(self.audio_data, device) return self def float(self):