I am currently implementing a source-separation algorithm to separate ECG from EMG signals, so that I can better evaluate the activity of a muscle. I have tried different approaches, and the one below is the closest I have gotten to a successful one.
I was wondering if anyone has a better idea on how to process this data?
the data:
- EMG bipolar recording
- fs 12500
- muscle: diaphragm
My goal is to first get a basic hard approach working, and later feed its output to a KAN NN.
Here is the most recent approach I have attempted ->
I am currently stuck on the algorithm for separating the ECG signal from the EMG. The EMG signal is also contaminated by 60 Hz interference and its harmonics.
- I first pass the signal through the basic filters used in this case (band-pass, notch)
- then compute the continuous wavelet transform (CWT) using ssqueezepy (https://github.com/OverLordGoldDragon/ssqueezepy)
- then I apply a Bayesian threshold using the abs values of the Wx matrix and get my “clean” EMG signal.
After that I have been trying to apply the idea available here: https://www.sciencedirect.com/science/article/pii/S1746809421004584#f0015
Basically:
- band-pass filter the base signal (4-50 Hz)
- Normalize the signal
- get the peaks
- and build a window where the ECG signal is located with 0.3s at each side -> [peak-0.3s, peak+0.3s]
Then, using those locations, grab the raw signal, extract the template from the windows by calculating the mean of the "ECG" segments, and subtract the template from the EMG signal.
"""
Another method to extract the linear envelope
Overall method:
"""
# %%
# load the libraries required to run this script
# filesystem paths
from pathlib import Path
# Plotting
import matplotlib.pyplot as plt
# Data loading and manipulation
from neo.io import CedIO
import numpy as np
from scipy.signal import filtfilt, butter, iirnotch, sosfilt, find_peaks
from ssqueezepy.experimental import scale_to_freq
from ssqueezepy import Wavelet, cwt, imshow, icwt
# %%
# Envelope functions
# moving RMS
def moving_window(tsx, window_size, step=1):
    """
    Lazily slide a fixed-length window along a sequence.

    Only complete windows are produced: when `tsx` holds fewer than
    `window_size` elements, nothing is yielded.

    Args:
        tsx (list): Sliceable input sequence.
        window_size (int): Number of elements in every window.
        step (int, optional): Offset between the starts of two consecutive
            windows. Defaults to 1.

    Yields:
        list: The slice `tsx[start : start + window_size]` for each start.
    """
    n_starts = len(tsx) - window_size + 1
    yield from (
        tsx[start : start + window_size] for start in range(0, n_starts, step)
    )
def rolling_mean(x: np.ndarray, N: int):
    """
    Moving-window mean computed from a running (cumulative) sum.

    Args:
        x (np.ndarray): Input signal.
        N (int): Window size in samples.

    Returns:
        np.ndarray: One mean per complete window, i.e. `len(x) - N + 1`
        values; the output is therefore shorter than the input.
    """
    # A leading zero lets a window sum be written as the difference of two
    # running totals that lie N entries apart.
    zero_led = np.insert(x, 0, 0)
    running_total = np.cumsum(zero_led)
    window_sums = running_total[N:] - running_total[:-N]
    return window_sums / float(N)
def rolling_rms(x: np.ndarray, N: int, skip: int = 1):
    """
    Square root of the moving-window mean, NaN-padded to the input length.

    The function does not square its input: it returns
    ``sqrt(mean(x[window]))``. To obtain an RMS envelope, pass the *squared*
    signal (instantaneous power) as `x`.

    Parameters:
        x (numpy array): One-dimensional input, expected to be non-negative
            (e.g. a squared signal).
        N (int): Window size in samples; must be at least 1.
        skip (int, optional): Keep every `skip`-th output sample. Defaults to 1.

    Returns:
        numpy array: Before decimation the result has exactly ``len(x)``
        samples. Positions for which no complete window exists are NaN. For
        even N the window cannot be centred on a sample, so the extra NaN is
        placed at the end.

    Raises:
        ValueError: If `N` is smaller than 1.
    """
    if N < 1:
        raise ValueError("N must be a positive integer")
    n_samples = np.size(x)
    if N > n_samples:
        # No complete window fits: every output sample is undefined.
        return np.full(n_samples, np.nan)[::skip]
    xc = np.cumsum(np.insert(x, 0, 0))
    rms_values = np.sqrt((xc[N:] - xc[:-N]) / N)
    # There are len(x) - N + 1 complete windows, so exactly N - 1 NaNs are
    # needed in total. Padding N // 2 on both sides (as before) gave one
    # sample too many whenever N was even.
    pad_left = (N - 1) // 2
    pad_right = N - 1 - pad_left
    rms_values = np.pad(
        rms_values, (pad_left, pad_right), mode="constant", constant_values=np.nan
    )
    # Return RMS values with the desired skip
    return rms_values[::skip]
def butter_lowpass_filter(tsx, cutOff, fs_=3, order=4):
    """
    Low-pass a time series with a Butterworth filter run through `filtfilt`.

    Args:
        tsx (array-like): Input time series.
        cutOff (float): Cutoff frequency (Hz).
        fs_ (int, optional): Sampling frequency (Hz). Defaults to 3.
        order (int, optional): Order of the Butterworth design. Defaults to 4.

    Returns:
        array-like: Filtered time series with the same length as `tsx`.
    """
    # butter() takes the cutoff as a fraction of the Nyquist frequency.
    half_rate = fs_ * 0.5
    numerator, denominator = butter(
        order, cutOff / half_rate, btype="low", analog=False
    )
    smoothed = filtfilt(numerator, denominator, tsx)
    return smoothed
def butter_bandpass(lowcut, highcut, fs, order=5):
    """
    Design a Butterworth band-pass filter as second-order sections.

    Args:
        lowcut (float): Lower band edge (Hz).
        highcut (float): Upper band edge (Hz).
        fs (float): Sampling frequency (Hz).
        order (int, optional): Order passed to `butter`. Defaults to 5.

    Returns:
        The second-order-sections array returned by `butter`.
    """
    # Both band edges are expressed as fractions of the Nyquist frequency.
    nyquist = fs * 0.5
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band_edges, analog=False, btype="band", output="sos")
def butter_bandpass_filter(tsx, lowcut, highcut, fs, order=5):
    """
    Band-pass a time series with the design returned by `butter_bandpass`.

    The sections are applied with `sosfilt`, i.e. in a single pass over the
    data (the other filter helpers in this file use `filtfilt` instead).

    Args:
        tsx (array-like): Input time series.
        lowcut (float): Lower band edge (Hz).
        highcut (float): Upper band edge (Hz).
        fs (float): Sampling frequency (Hz).
        order (int, optional): Order passed to the design. Defaults to 5.

    Returns:
        array-like: Filtered time series.
    """
    sections = butter_bandpass(lowcut, highcut, fs, order=order)
    return sosfilt(sections, tsx)
def notch_filter(tsx, f0=60, Q=30, fs_: float = 3):
    """
    Remove a narrow band around `f0` with an IIR notch run through `filtfilt`.

    Args:
        tsx (array-like): Input time series.
        f0 (float, optional): Frequency to remove (Hz). Defaults to 60.
        Q (float, optional): Quality factor handed to `iirnotch`. Defaults to 30.
        fs_ (float, optional): Sampling frequency (Hz). Defaults to 3.
            NOTE(review): the default is far below twice the default `f0`,
            so callers are presumably expected to always pass the real
            sampling rate -- confirm.

    Returns:
        array-like: Filtered time series with the same length as `tsx`.
    """
    coefficients = iirnotch(f0, Q, fs_)
    return filtfilt(*coefficients, tsx)
def ffc_filter(
    tsx: np.ndarray, alpha: float = -1.0, cutoff: float = 60.0, fs_: float = 24000.0
) -> np.ndarray:
    """
    Apply the feed-forward comb (FFC) filter described in [1].

        y(k) = x(k) + alpha * x(k - N)

    x = input signal
    N = delay in samples, fs / f, where f is the target frequency to remove
    alpha = regulates the filter behavior (-1 cancels f and its harmonics)

    The delayed copy is causal: for the first N samples there is no past
    sample, so x(k - N) is taken as 0 and those samples pass through
    unchanged. Nothing wraps around from the other end of the array.

    [1] D. Esposito, J. Centracchio, P. Bifulco, and E. Andreozzi, "A smart
    approach to EMG envelope extraction and powerful denoising for
    human-machine interfaces," Sci. Rep., vol. 13, no. 1, p. 7768, 2023,
    doi: 10.1038/s41598-023-33319-4.

    Parameters:
        tsx (np.ndarray): One-dimensional input signal.
        alpha (float, optional): Filter behavior regulator. Defaults to -1.0.
        cutoff (float, optional): Target frequency to remove (Hz). Defaults to 60.0.
        fs_ (float, optional): Sampling frequency (Hz). Defaults to 24000.0.

    Returns:
        np.ndarray: Filtered signal with the same length as `tsx`.
    """
    assert tsx.ndim == 1, "Array must be one-dimensional"
    # Delay in samples. Rounding (instead of truncating) picks the integer
    # delay closest to fs / f when the ratio is not a whole number.
    delay = int(round(fs_ / cutoff))
    # np.roll(tsx, -delay) would give x(k + N) and wrap the start of the
    # signal into its end, so the delayed copy is built explicitly.
    delayed = np.zeros_like(tsx)
    if delay < tsx.size:
        delayed[delay:] = tsx[: tsx.size - delay]
    return tsx + alpha * delayed
# %%
# define home (the user's home directory, under which the data is located)
home = Path.home()
# define the path to the data file, relative to the home directory
# NOTE(review): this string contains no path separators, so it names a single
# file directly under the home directory -- presumably the separators were
# lost when the script was pasted; confirm the real location.
data_path = r"DocumentsdataMethods-Paper_EMGtest.smrx"
# build the path to the data file
working_path = home.joinpath(data_path)
print(working_path)
# %%
# load the data using CedIO
reader = CedIO(filename=working_path)
data = reader.read()
# first segment of the first block returned by the reader
data_block = data[0].segments[0]
# %%
# first analog signal of that segment; every later cell works on this channel
ch01_emg = data_block.analogsignals[0]
# sampling rate as stored on the signal object (converted to float further down)
fs = ch01_emg.sampling_rate
# %%
width = 10 # Width in inches
height = 4 # Height in inches
# Create a figure with the specified size and plot the whole raw recording
fig, ax = plt.subplots(figsize=(width, height))
ax.plot(ch01_emg.times, ch01_emg.magnitude)
# %%
# filter
# sample indices of the 75 s - 80 s excerpt used for the plots below
start = int(fs * 75)
end = int(fs * 80)
# plain-float sampling rate for the filter helpers
fs = float(ch01_emg.sampling_rate)
# band-pass 0.1 - 1000 Hz (order 5) applied to the whole recording
ch01_emg_bp = butter_bandpass_filter(
ch01_emg.magnitude.squeeze(), 0.1, 1000, fs, order=5
)
# notch at the helper's default frequency, then the comb filter at its default
# target frequency; each stage takes the previous stage's output
ch01_emg_notch = notch_filter(tsx=ch01_emg_bp, fs_=fs)
ch01_emg_ffc = ffc_filter(ch01_emg_notch, fs_=fs)
plt.plot(ch01_emg_ffc[start:end]) # plot just a window of the dataset
# %%
# extract the envelope using three methods, all computed on the squared signal
ch01_emg_p = np.power(ch01_emg_ffc, 2)
# window length: 0.051 s expressed in samples
N = int(0.051 * fs)
# NOTE(review): rolling_mean returns len - N + 1 samples, so y_ma[start:end]
# is not sample-aligned with the other two envelopes -- confirm whether that
# matters for the comparison plot.
y_ma = rolling_mean(ch01_emg_p, N) * 10 # moving average filter
y_rms = rolling_rms(ch01_emg_p, N)
y_lp = butter_lowpass_filter(ch01_emg_p, 4, fs) * 10
# %%
# plot the envelopes on top of the band-passed signal
# NOTE(review): y_ma and y_lp were already multiplied by 10 above, so they are
# drawn at x100 while y_rms is drawn at x10; presumably display-only scaling.
plt.plot(ch01_emg_bp[start:end])
plt.plot(y_ma[start:end] * 10) # moving average
plt.plot(y_rms[start:end] * 10) # moving RMS
plt.plot(y_lp[start:end] * 10) # low pass filter
plt.legend(["original", "moving average", "moving rms", "low-pass filter"])
# %%
# ======================extracting the ECG=================================
# grab a segment of the data to work with (75 s - 80 s of the recording)
start = int(fs * 75)
end = int(fs * 80)
x_emg = ch01_emg_notch[start:end]  # use the notch filter data
# %%# Extract the wavelet transform
N = len(x_emg)
# Time stamp of every sample, spaced exactly 1 / fs apart.
# (np.linspace(0, N / fs, N) includes the end point and therefore stretches
# the axis by a factor N / (N - 1).)
t = np.arange(N) / fs
wavelet = Wavelet()
Wx, scales = cwt(x_emg, wavelet)
freqs_cwt = scale_to_freq(scales, wavelet, len(x_emg), fs=fs)
ikw = dict(abs=1, xticks=t, xlabel="Time [sec]", ylabel="Frequency [Hz]")
imshow(Wx, **ikw, yticks=freqs_cwt)
plt.plot(x_emg)
# %%
## Other methods
# # Universal threshold threshold = np.sqrt(2 * np.log(magnitude.size))
# thresholded_Wx = np.where(magnitude > threshold, Wx - threshold, 0) # soft threshold
##
# Filter the Wx values
magnitude = np.abs(Wx)
# Threshold of the form sqrt(2 * ln(n)) * sigma (labelled "Bayesian" by the
# author), with n = number of rows of Wx and sigma = std of all magnitudes.
# NOTE(review): this is the form of the universal threshold, which is normally
# defined with n = number of samples; the rows of Wx are presumably scales --
# confirm which n is intended.
threshold = np.sqrt(2 * np.log(magnitude.shape[0])) * np.std(magnitude)
# hard threshold: coefficients at or below the threshold are zeroed, the
# others are kept unchanged
thresholded_Wx = np.where(magnitude > threshold, Wx, 0)
# %%
# plot the thresholded Wx matrix
ikw = dict(abs=1, xticks=t, xlabel="Time [sec]", ylabel="Frequency [Hz]")
imshow(thresholded_Wx, **ikw, yticks=freqs_cwt)
# %%
# reconstruct the signal from the thresholded Wx matrix
x_emg_thr = icwt(thresholded_Wx, wavelet)
plt.plot(x_emg_thr)  # plot the new signal
# %% # find ecg peaks
# band-pass the wavelet-thresholded signal to 0.1 - 150 Hz before peak picking
# NOTE(review): the write-up at the top of this file mentions a 4-50 Hz band;
# confirm which band is intended.
ecg_signal = butter_bandpass_filter(x_emg_thr, 0.1, 150, fs, order=5)
plt.plot(ecg_signal)
# normalize the ecg signal (z-score: zero mean, unit standard deviation)
signal_normalized = (ecg_signal - np.mean(ecg_signal)) / np.std(ecg_signal)
plt.plot(signal_normalized)
# %%
# Find the peaks of the ecg signal
# standard deviation of the z-scored signal, used as the minimum peak height
signal_std = np.std(signal_normalized)
# minimum spacing between two detected peaks: 0.15 s expressed in samples
hrt_max_freq = int(fs * 0.15)
peaks, _ = find_peaks(
signal_normalized, distance=hrt_max_freq, height=signal_std
) # Adjust the distance value as needed
# overlay the detected peaks on the normalized signal
plt.plot(signal_normalized)
plt.plot(peaks, signal_normalized[peaks], "x")
plt.show()
# %%
def get_edges(arr, window_size):
    """
    Build a `[lower, upper]` index pair around every position in `arr`.

    Args:
        arr (iterable): Centre positions (e.g. peak indices).
        window_size (int): Half-width of the window in samples.

    Returns:
        list: One `[lower, upper]` pair per position. The lower edge is
        clamped at 0; the upper edge is not bounded.
    """
    edges = []
    for center in arr:
        lower = max(0, center - window_size)
        upper = center + window_size
        edges.append([lower, upper])
    return edges
# %%
# find the edges of the ECG signal
window_size = int(0.12 * fs)  # half-width of each ECG window: 0.12 s in samples
# Keep only the windows that lie completely inside the signal. get_edges clamps
# the lower edge at 0 and does not bound the upper edge, so a peak close to
# either end would give a shorter slice that cannot be stacked with the others.
# (Dropping just the last window, as before, missed the first one and always
# discarded the last beat even when it was complete.)
edges = [
    edge
    for edge in get_edges(peaks, window_size)
    if edge[1] - edge[0] == 2 * window_size and edge[1] <= len(ecg_signal)
]
if not edges:
    raise ValueError("no complete ECG window found; check the peak detection")
edges_array = np.array(edges)
diff = edges_array[:, 1] - edges_array[:, 0]
print(diff)  # print the length of the edges
# %%
# extract the template: one row per detected beat
# NOTE(review): the windows are cut from ecg_signal (the band-passed signal)
# but the template is later subtracted from x_emg_thr; the band-pass is
# applied in a single pass, so the two are presumably shifted in time --
# confirm the alignment.
template = np.array([ecg_signal[lo:hi] for lo, hi in edges])
plt.plot(np.mean(template, axis=0))
# %%
# normalize the template array
# Each beat is z-scored with its OWN mean and standard deviation, the same
# statistics used further down to scale the template back. (Normalising with
# the global mean/std and rescaling with per-beat values did not invert.)
beat_mean = np.mean(template, axis=1)
beat_std = np.std(template, axis=1)
safe_std = np.where(beat_std > 0, beat_std, 1.0)  # avoid dividing by zero
template_normalized = (template - beat_mean[:, np.newaxis]) / safe_std[:, np.newaxis]
# %%
for template_s in template_normalized:
    plt.plot(template_s, alpha=0.5)
plt.plot(np.median(template_normalized, axis=0), color="r")
# %%
# median over beats: the shape of the ECG template
template_final = np.median(template_normalized, axis=0)
# %%
# one copy of the template per beat, shape (window length, number of beats)
template_r = np.repeat(
    template_final[:, np.newaxis], template_normalized.shape[0], axis=1
)
# %%
# scale column i back to the amplitude and offset of beat i
template_scaled = (template_r * beat_std) + beat_mean
# %%
plt.plot(template_scaled)
# %%
# Subtract the per-beat template from the wavelet-filtered signal.
# (The previous `-= -template` added the template instead of removing it.)
ecg_filter = x_emg_thr.copy()
for i, (lo, hi) in enumerate(edges):
    ecg_filter[lo:hi] -= template_scaled[:, i]
plt.plot(x_emg_thr)
plt.plot(ecg_filter)
# %%
# plot individual template: template, signal and residual of the SAME beat
beat = min(5, len(edges) - 1)
idx = edges[beat]
plt.plot(template_scaled[:, beat])
plt.plot(x_emg_thr[idx[0] : idx[1]], alpha=0.5)
plt.plot(x_emg_thr[idx[0] : idx[1]] - template_scaled[:, beat], alpha=0.5)
# %%
# plot the wavelet-filtered data
# CWT of the thresholded/reconstructed signal, drawn with the same display
# options as the earlier scalograms
wavelet = Wavelet()
Wx_, scales_ = cwt(x_emg_thr, wavelet)
freqs_cwt = scale_to_freq(scales_, wavelet, len(x_emg_thr), fs=fs)
ikw = dict(abs=1, xticks=t, xlabel="Time [sec]", ylabel="Frequency [Hz]")
imshow(Wx_, **ikw, yticks=freqs_cwt)
# %%
# plot the ECG-filtered data
# CWT of the signal after the template step, for comparison with the plot above
wavelet = Wavelet()
Wx_filter, scales_filter = cwt(ecg_filter, wavelet)
freqs_cwt = scale_to_freq(scales_filter, wavelet, len(ecg_filter), fs=fs)
ikw = dict(abs=1, xticks=t, xlabel="Time [sec]", ylabel="Frequency [Hz]")
imshow(Wx_filter, **ikw, yticks=freqs_cwt)
# %%