I have implemented STFT (Short-Time Fourier Transform) and inverse STFT in both TensorFlow and PyTorch. While the PyTorch version successfully restores the original signal, the TensorFlow version does not.
Here are my implementations for tensorflow:
def spectro(
signal: tf.Tensor, n_fft: int = 4096, hop_length: int = None
) -> tf.Tensor:
if hop_length is None:
hop_length = n_fft // 4
assert hop_length == n_fft // 4, "hop_length should be n_fft // 4"
frames = int(math.ceil(signal.shape[-1] / hop_length))
pad_left = (hop_length // 2 * 3)
pad_right = pad_left + frames * hop_length - signal.shape[-1]
pad_second = (n_fft // 2)
# shape (batch, channel, signal) -> (batch, channel, signal + pad_left + pad_right)
padded_signal = STFTUtils.pad_1d(signal, (pad_left, pad_right), "REFLECT")
padded_signal = STFTUtils.pad_1d(padded_signal, (pad_second, pad_second), "REFLECT")
window_fn = tf.signal.hann_window
*other, length = padded_signal.shape
padded_signal = tf.reshape(padded_signal,[-1,length])
stfts = tf.signal.stft(
padded_signal,
frame_length=n_fft,
frame_step=hop_length,
fft_length=n_fft,
window_fn=window_fn,
) # shape (batch, channel, fft_bins, fft_length // 2 + 1)
_,frames_sec, frequency = stfts.shape
stfts = stfts[..., : , :-1]
stfts = stfts[..., 2 : 2 + frames, : ]
return stfts # shape (batch, channel, fft_bins - 4, fft_length // 2)
def inverse_spectro(spectrograms, hop_length=None, signal_length=None) -> tf.Tensor:
spectrograms = tf.pad(spectrograms, [[0, 0], [0, 0], [2,2], [0,1]])
*other_dims, num_freqs = spectrograms.shape
n_fft = 2 * num_freqs - 2
if hop_length is None:
hop_length = n_fft // 4
window_fn = tf.signal.hann_window
signal = tf.signal.inverse_stft(
spectrograms,
frame_length=n_fft,
frame_step=hop_length,
fft_length=n_fft,
window_fn=window_fn
) # shape (batch, channel, signal_length + padding)
padding = (hop_length // 2 * 3) + (n_fft // 2) # Amount of padding that was applied
signal = signal[..., padding : signal_length + padding]
return signal # shape (batch, channel, signal_length)
Here are my implementations for torch:
def spectro(x, n_fft=512, hop_length=None, pad=0):
*other, length = x.shape
x = x.reshape(-1, length)
is_mps = x.device.type == 'mps'
if is_mps:
x = x.cpu()
import torch
#print("before stft", x, x.shape, torch.sum(torch.square(x)))
z = th.stft(x,
n_fft * (1 + pad),
hop_length or n_fft // 4,
window=th.hann_window(n_fft).to(x),
win_length=n_fft,
normalized=False,
center=True,
return_complex=True,
pad_mode='reflect')
_, freqs, frame = z.shape
return z.view(*other, freqs, frame)
def ispectro(z, hop_length=None, length=None, pad=0):
*other, freqs, frames = z.shape
n_fft = 2 * freqs - 2
z = z.view(-1, freqs, frames)
win_length = n_fft // (1 + pad)
is_mps = z.device.type == 'mps'
if is_mps:
z = z.cpu()
x = th.istft(z,
n_fft,
hop_length,
window=th.hann_window(win_length).to(z.real),
win_length=win_length,
normalized=False,
length=length,
center=True)
_, length = x.shape
return x.view(*other, length)
In the PyTorch code, I can restore the original signal without issues. However, the TensorFlow version does not restore the original signal properly. What could be the probable issues in my TensorFlow implementation?
Thank you for your help!