I’m trying to implement this 2020 paper: Latent neural source recovery via transcoding of simultaneous EEG-fMRI.
I’ve got a simplified version working, and am now trying to stick precisely to the paper. But I don’t understand how they compute spatial transpose convolutional layers on a (batch_size, 15, 18, 15, n_time_points) tensor, where 15x18x15 is the shape of my volume.
PyTorch’s ConvTranspose3d is happy dealing with a single volume, but not a time series of volumes.
Should I just apply this sequentially to each time point?
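One workaround I’m considering, since the transpose conv is purely spatial: fold the time axis into the batch dimension, run ConvTranspose3d once over all time points, and fold back. A minimal sketch of the idea (not necessarily what the paper does):

import torch
import torch.nn as nn

x = torch.randn(11, 9, 5, 300)          # (D, H, W, T) time series of volumes
conv = nn.ConvTranspose3d(1, 16, (5, 10, 11))

x = x.permute(3, 0, 1, 2).unsqueeze(1)  # (300, 1, 11, 9, 5): time as batch
y = conv(x)                             # one call instead of a loop over t
print(y.shape)                          # torch.Size([300, 16, 15, 18, 15])
y = y.permute(1, 2, 3, 4, 0)            # back to (16, 15, 18, 15, 300)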
Here is a diagram of the part of their architecture I am having trouble with:
The only difference in my setup is that instead of (63, 300) I have (34, 300) (these represent 34 electrodes on the scalp and 300 time points).
Step 1: “assign to volume”. I’ve implemented that; the code is at the bottom. This brings us from (34, 300) to (11, 9, 5, 300), where I assume the 1@ in the diagram just means a batch size of 1.
Step 2: “strided transpose conv” to go from (11, 9, 5, 300) to (15, 18, 15, 300). Here I’m not sure how to make this work without applying it at each time point individually:
In [20]: x = torch.randn(11,9,5,300)
In [21]: nn.ConvTranspose3d(1, 16, (5,10,11))(x[:,:,:,0].unsqueeze(0)).shape
Out[21]: torch.Size([16, 15, 18, 15])
So at least this takes me from 1x11x9x5 to 16x15x18x15. Does the paper imply just doing this 300 times, once per time point? Also, why “strided”? I hit the target shape with the default stride of 1 and a large kernel.
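My best guess on “strided” (the paper’s text doesn’t spell it out for me): they may mean stride > 1 with a smaller kernel, which reaches the same output shape. For example:

In [22]: nn.ConvTranspose3d(1, 16, (5, 2, 3), stride=(1, 2, 3))(x[:,:,:,0].unsqueeze(0)).shape
Out[22]: torch.Size([16, 15, 18, 15])

The kernel and stride here are my guesses, picked only to map 11x9x5 to 15x18x15; the paper may use different values.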
Step 3: “spatial transpose conv layers”. Is this correct:
In [28]: x = torch.randn(16,15,18,15) #pretend output of previous step
In [29]: nn.ConvTranspose3d(16, 16, (1,1,1))(x).shape
Out[29]: torch.Size([16, 15, 18, 15])
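For what it’s worth, with a 1x1x1 kernel and stride 1 a transpose conv is just a per-voxel linear map across the 16 channels, so a plain Conv3d produces the identical shape; I read “spatial” as meaning the layer acts only on the spatial dims:

In [30]: nn.Conv3d(16, 16, (1, 1, 1))(x).shape
Out[30]: torch.Size([16, 15, 18, 15])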
Would this be a correct forward loop then:
import torch
import torch.nn as nn

relu = nn.ReLU()

x = torch.randn(34, 300)
x = eeg_to_volume(x)
print("Volume shape: ", x.shape)

new_x = torch.empty((16, 15, 18, 15, x.shape[-1]))
conv1 = nn.ConvTranspose3d(1, 16, (5, 10, 11))
for t in range(x.shape[-1]):
    x_t = x[:, :, :, t]
    new_x_t = relu(conv1(x_t.unsqueeze(0)))  # add channel dim: (1, 11, 9, 5)
    new_x[:, :, :, :, t] = new_x_t
x = new_x
print("Shape post-expansion: ", x.shape)

new_x = torch.empty((16, 15, 18, 15, x.shape[-1]))
conv2 = nn.ConvTranspose3d(16, 16, (1, 1, 1))
for t in range(x.shape[-1]):
    x_t = x[:, :, :, :, t]
    new_x_t = relu(conv2(x_t.unsqueeze(0)))  # add batch dim: (1, 16, 15, 18, 15)
    new_x[:, :, :, :, t] = new_x_t.squeeze(0)
x = new_x
print("Shape post spatial conv 1: ", x.shape)

new_x = torch.empty((16, 15, 18, 15, x.shape[-1]))
conv3 = nn.ConvTranspose3d(16, 16, (1, 1, 1))
for t in range(x.shape[-1]):
    x_t = x[:, :, :, :, t]
    new_x_t = relu(conv3(x_t.unsqueeze(0)))  # add batch dim
    new_x[:, :, :, :, t] = new_x_t.squeeze(0)
x = new_x
print("Shape post spatial conv 2: ", x.shape)

new_x = torch.empty((16, 15, 18, 15, x.shape[-1]))
conv4 = nn.ConvTranspose3d(16, 16, (1, 1, 1))
for t in range(x.shape[-1]):
    x_t = x[:, :, :, :, t]
    new_x_t = relu(conv4(x_t.unsqueeze(0)))  # add batch dim
    new_x[:, :, :, :, t] = new_x_t.squeeze(0)
x = new_x
print("Shape post spatial conv 3: ", x.shape)

new_x = torch.empty((1, 15, 18, 15, x.shape[-1]))
conv5 = nn.ConvTranspose3d(16, 1, (1, 1, 1))
for t in range(x.shape[-1]):
    x_t = x[:, :, :, :, t]
    new_x_t = relu(conv5(x_t.unsqueeze(0)))  # add batch dim
    new_x[:, :, :, :, t] = new_x_t.squeeze(0)
x = new_x
print("Shape post-shrinking: ", x.shape)
"""
Output:
Volume shape: torch.Size([11, 9, 5, 300])
Shape post-expansion: torch.Size([16, 15, 18, 15, 300])
Shape post spatial conv 1: torch.Size([16, 15, 18, 15, 300])
Shape post spatial conv 2: torch.Size([16, 15, 18, 15, 300])
Shape post spatial conv 3: torch.Size([16, 15, 18, 15, 300])
Shape post-shrinking: torch.Size([1, 15, 18, 15, 300])
"""
I’m suspicious about:
(1) doing time points individually seems inefficient; is this really what the paper suggests?
(2) should each step be conv(x) or relu(conv(x))? I went with relu(conv(x)) everywhere, including the final layer; is that right?
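For question (1), here is how I think the whole loop above collapses if the time-into-batch trick from the top is valid: same five layers, same output shapes, one conv call per layer instead of 300 (where the ReLUs go, including after the final layer, is exactly question (2)):

import torch
import torch.nn as nn

# Same layers as the loop above; time folded into the batch dimension.
layers = nn.Sequential(
    nn.ConvTranspose3d(1, 16, (5, 10, 11)), nn.ReLU(),  # expansion
    nn.ConvTranspose3d(16, 16, (1, 1, 1)), nn.ReLU(),   # spatial conv 1
    nn.ConvTranspose3d(16, 16, (1, 1, 1)), nn.ReLU(),   # spatial conv 2
    nn.ConvTranspose3d(16, 16, (1, 1, 1)), nn.ReLU(),   # spatial conv 3
    nn.ConvTranspose3d(16, 1, (1, 1, 1)), nn.ReLU(),    # shrink to 1 channel
)

x = torch.randn(11, 9, 5, 300)          # output of eeg_to_volume: (D, H, W, T)
x = x.permute(3, 0, 1, 2).unsqueeze(1)  # (300, 1, 11, 9, 5): time as batch
x = layers(x)                           # (300, 1, 15, 18, 15) in one pass
x = x.permute(1, 2, 3, 4, 0)            # (1, 15, 18, 15, 300), as before
print(x.shape)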
Appendix:
Function for assigning EEG data to a volume:
import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
"""
Roughly correct, based on Fig.10, Appendix E.
"""
channel_mapping = {
    1: {"name": "Fp1", "coords": (2, 9, 2)},
    2: {"name": "Fp2", "coords": (10, 9, 2)},
    3: {"name": "AF3", "coords": (4, 8, 3)},
    4: {"name": "AF4", "coords": (8, 8, 3)},
    5: {"name": "F7", "coords": (2, 7, 2)},
    6: {"name": "F3", "coords": (4, 7, 4)},
    7: {"name": "Fz", "coords": (6, 7, 4)},
    8: {"name": "F4", "coords": (8, 7, 4)},
    9: {"name": "F8", "coords": (10, 7, 2)},
    10: {"name": "FC5", "coords": (3, 6, 3)},
    11: {"name": "FC1", "coords": (5, 6, 5)},
    12: {"name": "FC2", "coords": (7, 6, 5)},
    13: {"name": "FC6", "coords": (9, 6, 3)},
    14: {"name": "T7", "coords": (2, 5, 2)},
    15: {"name": "C4", "coords": (7, 5, 4)},
    16: {"name": "Cz", "coords": (6, 5, 5)},
    17: {"name": "C3", "coords": (5, 5, 4)},
    18: {"name": "T8", "coords": (10, 5, 2)},
    19: {"name": "CP5", "coords": (3, 4, 3)},
    20: {"name": "CP1", "coords": (5, 4, 5)},
    21: {"name": "CP2", "coords": (7, 4, 5)},
    22: {"name": "CP6", "coords": (9, 4, 3)},
    23: {"name": "P7", "coords": (2, 3, 2)},
    24: {"name": "P3", "coords": (4, 3, 4)},
    25: {"name": "Pz", "coords": (6, 3, 4)},
    26: {"name": "P4", "coords": (8, 3, 4)},
    27: {"name": "P8", "coords": (10, 3, 2)},
    28: {"name": "PO7", "coords": (4, 2, 2)},
    29: {"name": "PO3", "coords": (4, 2, 3)},
    30: {"name": "PO4", "coords": (8, 2, 3)},
    31: {"name": "PO8", "coords": (8, 2, 2)},
    32: {"name": "O1", "coords": (5, 1, 2)},
    33: {"name": "Oz", "coords": (6, 1, 2)},
    34: {"name": "O2", "coords": (7, 1, 2)},
}
def eeg_to_volume(eeg_data):
    """Scatter (num_channels, num_timepoints) EEG into an (11, 9, 5, num_timepoints) volume."""
    if isinstance(eeg_data, np.ndarray):
        eeg_data = torch.from_numpy(eeg_data)
    assert eeg_data.dim() == 2
    num_channels, num_timepoints = eeg_data.shape
    assert num_channels == len(channel_mapping)
    volume = torch.zeros((11, 9, 5, num_timepoints), device=eeg_data.device)
    for channel_ix, channel_data in channel_mapping.items():
        x, y, z = channel_data["coords"]  # 1-indexed coordinates from Fig. 10
        volume[x - 1, y - 1, z - 1, :] = eeg_data[channel_ix - 1, :]
    return volume
def visualize_eeg_volume(channel_mapping):
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection="3d")
    # Plot the volume boundaries
    ax.plot([0, 11, 11, 0, 0], [0, 0, 9, 9, 0], [0, 0, 0, 0, 0], "k-")
    ax.plot([0, 11, 11, 0, 0], [0, 0, 9, 9, 0], [5, 5, 5, 5, 5], "k-")
    ax.plot([0, 0], [0, 0], [0, 5], "k-")
    ax.plot([11, 11], [0, 0], [0, 5], "k-")
    ax.plot([11, 11], [9, 9], [0, 5], "k-")
    ax.plot([0, 0], [9, 9], [0, 5], "k-")
    # Plot electrode positions
    for channel_ix, channel_data in channel_mapping.items():
        x, y, z = channel_data["coords"]
        ax.scatter(x, y, z - 1, c="r", s=100)
        ax.text(x, y, z - 1, channel_data["name"], fontsize=8)
    # Set labels and title
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.set_title("EEG Electrode Positions in 3D Volume (11x9x5)")
    # Set axis limits
    ax.set_xlim(0, 11)
    ax.set_ylim(0, 9)
    ax.set_zlim(0, 5)
    # Adjust the view angle
    ax.view_init(elev=20, azim=45)
    plt.tight_layout()
    plt.show()
if __name__ == "__main__":
    num_channels = 34
    num_timepoints = 1000
    dummy_eeg_data = np.random.rand(num_channels, num_timepoints)
    volume_data = eeg_to_volume(dummy_eeg_data)
    print(volume_data.shape)  # torch.Size([11, 9, 5, 1000])
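Since the coordinates above are hand-copied from Fig. 10, I also run a quick sanity check (my addition, not from the paper) that no two electrodes were assigned the same voxel, which would make one channel silently overwrite another in eeg_to_volume:

from collections import Counter

# How many electrodes land on each (x, y, z) voxel
voxel_counts = Counter(ch["coords"] for ch in channel_mapping.values())
collisions = {xyz: n for xyz, n in voxel_counts.items() if n > 1}
assert not collisions, f"Electrodes sharing a voxel: {collisions}"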