Let me first explain the scenario. In astronomy, we have data in multiple filters (bands) taken at different time epochs. I am trying to classify time series of images using ConvLSTM2D, but my question applies to LSTM as well, if you replace the 2D image by 1D data and the convolution in ConvLSTM2D by matmul (Dense) operations.
The image below explains the situation for two bands, r and g.
The following code works fine for the individual bands separately.
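import tensorflow as tf
from tensorflow.keras.layers import (Input, ConvLSTM2D, BatchNormalization, TimeDistributed,
                                     MaxPooling2D, Flatten, Dense, LSTM, Concatenate)
# Variable number of time steps; each step is a 72x72 single-channel image.
input_shape = (None, 72, 72, 1)
# One scalar (the time stamp) per time step.
additional_input_shape = (None, 1)
image_input = Input(shape=input_shape, name='image_input')
additional_input = Input(shape=additional_input_shape, name='additional_input')
# Image branch: per-step feature vector of size 64.
x = ConvLSTM2D(32, (9, 9), activation='relu', padding='valid', return_sequences=True, data_format='channels_last')(image_input)
x = BatchNormalization()(x)
x = TimeDistributed(MaxPooling2D((2, 2), data_format='channels_last'))(x)
x = TimeDistributed(Flatten())(x)
x = TimeDistributed(Dense(64, activation='relu'))(x)
# Time-stamp branch: per-step feature vector of size 32.
additional_x = LSTM(32, return_sequences=True)(additional_input)
# Merge both branches along the feature axis and predict at every time step.
combined_x = Concatenate()([x, additional_x])
combined_x = TimeDistributed(Dense(64, activation='relu'))(combined_x)
output = TimeDistributed(Dense(1, activation='sigmoid'))(combined_x)
Imodel2 = tf.keras.Model(inputs=[image_input, additional_input], outputs=output)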
Note: additional_input gives the values of the time stamps (epochs), so that the network knows the gap between two images.
But I want to combine the results from the two bands: specifically, concatenate the combined_x layers from the g and r bands, followed by the dense layers. The bottom panel of the above figure shows the images from bands r and g on the same time axis. I want to make a prediction at any given time. For example, after getting an image in the r band, say at time=t1, I want to concatenate combined_r with the latest combined_g (up to t1). Could you please help me out here?
Please let me know if you need additional information.
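To make "the latest combined_g up to t1" concrete, here is a minimal sketch of the index alignment I have in mind. It is plain Python rather than part of the network, and the helper name latest_index_at_or_before as well as the epoch values are made up for illustration. It assumes the g time stamps are sorted in increasing order.

```python
from bisect import bisect_right

def latest_index_at_or_before(source_times, query_times):
    """For each query time, return the index of the latest source time
    that is <= the query time, or -1 if there is none yet.

    Assumes source_times is sorted in increasing order.
    """
    return [bisect_right(source_times, t) - 1 for t in query_times]

# Hypothetical epochs, for illustration only (not real data).
g_times = [1.0, 4.0, 6.5]
r_times = [0.5, 2.0, 4.0, 7.0]

# Entry i is the g step whose features would be paired with r step i.
g_index_for_r = latest_index_at_or_before(g_times, r_times)
print(g_index_for_r)
```

An index of -1 means no g image exists yet at that r epoch, so that step would need a mask or a placeholder vector. Indices like these could presumably be used to pick the matching steps of combined_g (for example with tf.gather along the time axis) before concatenating with combined_r, but that part is exactly what I am unsure about.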
Extra complications
Since different samples can have different numbers of time stamps, I am using dynamic batches, meaning that I batch together the samples that have the same number of time stamps. We can figure this out once the network part is designed properly.
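Roughly, the dynamic batching works like the sketch below. The names group_by_length and samples are placeholders for illustration, and the toy sequences stand in for the real per-sample time series.

```python
from collections import defaultdict

def group_by_length(samples):
    """Map each sequence length to the indices of the samples with that length."""
    buckets = defaultdict(list)
    for index, sequence in enumerate(samples):
        buckets[len(sequence)].append(index)
    return dict(buckets)

# Toy sequences with 3, 2, 3 and 2 time stamps respectively.
samples = [[0.1, 0.2, 0.3], [1.0, 2.0], [0.5, 0.6, 0.7], [3.0, 4.0]]

# Each bucket can be stacked into one batch without padding.
buckets = group_by_length(samples)
print(buckets)
```

With two bands, the grouping key would probably have to be the pair (number of r epochs, number of g epochs) instead of a single length, which is one reason I call this an extra complication.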