From af73cc42b8f297908f20e4097f01d2cc6f8039eb Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Mon, 30 Dec 2024 11:57:11 +0000 Subject: [PATCH] fix .to() --- audio/audiotools/core/audio_signal.py | 30 ++++++++++++++------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/audio/audiotools/core/audio_signal.py b/audio/audiotools/core/audio_signal.py index 80a4130d6..acb0bcafe 100644 --- a/audio/audiotools/core/audio_signal.py +++ b/audio/audiotools/core/audio_signal.py @@ -764,6 +764,19 @@ class AudioSignal( self.sample_rate = sample_rate return self + @staticmethod + def move_to_device(data, device): + if device is None or device == "": + return data + elif device == 'cpu': + return paddle.to_tensor(data, place=paddle.CPUPlace()) + elif device in ('gpu', 'cuda'): + return paddle.to_tensor(data, place=paddle.CUDAPlace()) + else: + device = device.replace("cuda", + "gpu") if "cuda" in device else device + return data.to(device) + # Tensor operations def to(self, device: str): """✅Moves all tensors contained in signal to the specified device. @@ -780,22 +793,11 @@ class AudioSignal( AudioSignal with all tensors moved to specified device. """ if self._loudness is not None: - self._loudness = self._loudness.to(device) + self._loudness = self.move_to_device(self._loudness, device) if self.stft_data is not None: - self.stft_data = self.stft_data.to(device) + self.stft_data = self.move_to_device(self.stft_data, device) if self.audio_data is not None: - if device is None or "" == device: - return self - elif 'cpu' == device: - self.audio_data = paddle.to_tensor( - self.audio_data, place=paddle.CPUPlace()) - elif 'gpu' == device or 'cuda' == device: - self.audio_data = paddle.to_tensor( - self.audio_data, place=paddle.CUDAPlace()) - else: - device = device.replace("cuda", - "gpu") if "cuda" in device else device - self.audio_data = self.audio_data.to(device) + self.audio_data = self.move_to_device(self.audio_data, device) return self def float(self):