RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor)

I am trying to implement CUDA for this graph neural network model “https://github.com/VincLee8188/GMAN-PyTorch” to use the nvidia GPU to be able to speed up the process since with the CPU it takes about 10 hours, all this I am using kaggle and I wanted to be able to use their GPU. I started to implement it but I get this error as soon as the training model is started. I am inserting below the main, model and train. Do you have any ideas on how I can fix this error?

Error: RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

Main:

import argparse
import time
import torch.optim as optim
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import sys
import os
from IPython.display import clear_output


# Add the GMAN project directory to the Python path
current_dir = '/kaggle/working'
gman_dir = os.path.join(current_dir, 'GMAN.')
sys.path.append(gman_dir)

# Debug: Print directory information
print("Current directory:", current_dir)
print("Contents of current directory:", os.listdir(current_dir))
print("GMAN directory:", gman_dir)
if os.path.exists(gman_dir):
    print("Contents of GMAN. directory:", os.listdir(gman_dir))
else:
    print("GMAN. directory does not exist")
print("Python path:", sys.path)

from utils.utils_ import log_string, metric
from utils.utils_ import count_parameters, load_data

from model.model_ import GMAN
from model.train import train
from model.test import test

def parse_args():
    parser = argparse.ArgumentParser(description='GMAN')
    parser.add_argument('--time_slot', type=int, default=5,
                        help='a time step is 5 mins')
    parser.add_argument('--num_his', type=int, default=12,
                        help='history steps')
    parser.add_argument('--num_pred', type=int, default=12,
                        help='prediction steps')
    parser.add_argument('--L', type=int, default=1,
                        help='number of STAtt Blocks')
    parser.add_argument('--K', type=int, default=8,
                        help='number of attention heads')
    parser.add_argument('--d', type=int, default=8,
                        help='dims of each head attention outputs')
    parser.add_argument('--train_ratio', type=float, default=0.7,
                        help='training set [default : 0.7]')
    parser.add_argument('--val_ratio', type=float, default=0.1,
                        help='validation set [default : 0.1]')
    parser.add_argument('--test_ratio', type=float, default=0.2,
                        help='testing set [default : 0.2]')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='batch size')
    parser.add_argument('--max_epoch', type=int, default=1,
                        help='epoch to run')
    parser.add_argument('--patience', type=int, default=10,
                        help='patience for early stop')
    parser.add_argument('--learning_rate', type=float, default=0.001,
                        help='initial learning rate')
    parser.add_argument('--decay_epoch', type=int, default=10,
                        help='decay epoch')
    parser.add_argument('--traffic_file', default='/kaggle/working/GMAN./data/pems-bay.h5',
                        help='traffic file')
    parser.add_argument('--SE_file', default='/kaggle/working/GMAN./data/SE(PeMS).txt',
                        help='spatial embedding file')
    parser.add_argument('--model_file', default='/kaggle/working/GMAN./data/GMAN.pkl',
                        help='save the model to disk')
    parser.add_argument('--log_file', default='/kaggle/working/GMAN./data/log',
                        help='log file')
    
    # Parse only known args to avoid conflicts with IPython's arguments
    args, unknown = parser.parse_known_args()
    return args

def main():
    args = parse_args()
    
    # Redirect stdout to capture output
    class OutputCapture:
        def __init__(self):
            self.value = ""
        def write(self, string):
            self.value += string
            sys.__stdout__.write(string)
            sys.__stdout__.flush()
        def flush(self):
            sys.__stdout__.flush()

    output_capture = OutputCapture()
    sys.stdout = output_capture

    log = open(args.log_file, 'w')
    log_string(log, str(args)[10: -1])
    T = 24 * 60 // args.time_slot  # Number of time steps in one day

     # Check if CUDA is available and set the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log_string(log, f"Using device: {device}")

    # load data
    log_string(log, 'loading data...')
    (trainX, trainTE, trainY, valX, valTE, valY, testX, testTE,
     testY, SE, mean, std) = load_data(args)
    log_string(log, f'trainX: {trainX.shape}tt trainY: {trainY.shape}')
    log_string(log, f'valX:   {valX.shape}ttvalY:   {valY.shape}')
    log_string(log, f'testX:   {testX.shape}tttestY:   {testY.shape}')
    log_string(log, f'mean:   {mean:.4f}ttstd:   {std:.4f}')
    log_string(log, 'data loaded!')

    # build model
    log_string(log, 'compiling model...')
    model = GMAN(SE, args, bn_decay=0.1).to(device)
    loss_criterion = nn.MSELoss().to(device)
    optimizer = optim.Adam(model.parameters(), args.learning_rate)
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=args.decay_epoch,
                                          gamma=0.9)
    parameters = count_parameters(model)
    log_string(log, 'trainable parameters: {:,}'.format(parameters))

    # train model
    start = time.time()
    loss_train, loss_val = train(model, args, log, loss_criterion, optimizer, scheduler, device)
    log_string(log, 'Training completed')

    # test model
    log_string(log, 'Testing model...')
    trainPred, valPred, testPred = test(args, log, device)
    
    # Print final results
    train_mae, train_rmse, train_mape = metric(trainPred, trainY)
    val_mae, val_rmse, val_mape = metric(valPred, valY)
    test_mae, test_rmse, test_mape = metric(testPred, testY)
    log_string(log, '                MAEttRMSEttMAPE')
    log_string(log, 'train            %.2ftt%.2ftt%.2f%%' %
               (train_mae, train_rmse, train_mape * 100))
    log_string(log, 'val              %.2ftt%.2ftt%.2f%%' %
               (val_mae, val_rmse, val_mape * 100))
    log_string(log, 'test             %.2ftt%.2ftt%.2f%%' %
               (test_mae, test_rmse, test_mape * 100))

    # Print performance for each prediction step
    log_string(log, 'Performance in each prediction step')
    for step in range(args.num_pred):
        mae, rmse, mape = metric(testPred[:, step], testY[:, step])
        log_string(log, f'Step {step + 1:02d}: MAE {mae:.2f}, RMSE {rmse:.2f}, MAPE {mape * 100:.2f}%')

    end = time.time()
    log_string(log, f'Total time: {(end - start) / 60:.1f} minutes')
    log.close()

    # Reset stdout
    sys.stdout = sys.__stdout__

    # Clear output and print final results
    clear_output(wait=True)
    print("Execution completed. Showing results:")
    print(output_capture.value)

if __name__ == '__main__':
    main()

Train:

import time
import datetime
import torch
from utils.utils_ import log_string
from model.model_ import *
from utils.utils_ import load_data

def train(model, args, log, loss_criterion, optimizer, scheduler, device):
    model = model.to(device)

    (trainX, trainTE, trainY, valX, valTE, valY, testX, testTE,
     testY, SE, mean, std) = load_data(args)

    # Move data to the specified device
    trainX, trainTE, trainY = trainX.to(device), trainTE.to(device), trainY.to(device)
    valX, valTE, valY = valX.to(device), valTE.to(device), valY.to(device)
    SE = SE.to(device)
    mean, std = mean.to(device), std.to(device)

    num_train, _, num_vertex = trainX.shape
    log_string(log, '**** training model ****')
    num_val = valX.shape[0]
    train_num_batch = math.ceil(num_train / args.batch_size)
    val_num_batch = math.ceil(num_val / args.batch_size)

    wait = 0
    val_loss_min = float('inf')
    best_model_wts = None
    train_total_loss = []
    val_total_loss = []

    # Train & validation
    for epoch in range(args.max_epoch):
        if wait >= args.patience:
            log_string(log, f'early stop at epoch: {epoch:04d}')
            break
        # shuffle
        permutation = torch.randperm(num_train)
        trainX = trainX[permutation]
        trainTE = trainTE[permutation]
        trainY = trainY[permutation]
        # train
        start_train = time.time()
        model.train()
        train_loss = 0
        for batch_idx in range(train_num_batch):
            start_idx = batch_idx * args.batch_size
            end_idx = min(num_train, (batch_idx + 1) * args.batch_size)
            X = trainX[start_idx: end_idx]
            TE = trainTE[start_idx: end_idx]
            label = trainY[start_idx: end_idx]
            optimizer.zero_grad()
            pred = model(X, TE)
            pred = pred * std.to(device) + mean.to(device)
            loss_batch = loss_criterion(pred, label)
            train_loss += float(loss_batch) * (end_idx - start_idx)
            loss_batch.backward()
            optimizer.step()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            if (batch_idx+1) % 5 == 0:
                print(f'Training batch: {batch_idx+1} in epoch:{epoch}, training batch loss:{loss_batch:.4f}')
            del X, TE, label, pred, loss_batch
        train_loss /= num_train
        train_total_loss.append(train_loss)
        end_train = time.time()

        # val loss
        start_val = time.time()
        val_loss = 0
        model.eval()
        with torch.no_grad():
            for batch_idx in range(val_num_batch):
                start_idx = batch_idx * args.batch_size
                end_idx = min(num_val, (batch_idx + 1) * args.batch_size)
                X = valX[start_idx: end_idx]
                TE = valTE[start_idx: end_idx]
                label = valY[start_idx: end_idx]
                pred = model(X, TE)
                pred = pred * std.to(device) + mean.to(device)
                loss_batch = loss_criterion(pred, label)
                val_loss += loss_batch * (end_idx - start_idx)
                del X, TE, label, pred, loss_batch
        val_loss /= num_val
        val_total_loss.append(val_loss)
        end_val = time.time()
        log_string(
            log,
            '%s | epoch: %04d/%d, training time: %.1fs, inference time: %.1fs' %
            (datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), epoch + 1,
             args.max_epoch, end_train - start_train, end_val - start_val))
        log_string(
            log, f'train loss: {train_loss:.4f}, val_loss: {val_loss:.4f}')
        if val_loss <= val_loss_min:
            log_string(
                log,
                f'val loss decrease from {val_loss_min:.4f} to {val_loss:.4f}, saving model to {args.model_file}')
            wait = 0
            val_loss_min = val_loss
            best_model_wts = model.state_dict()
        else:
            wait += 1
        scheduler.step()

    model.load_state_dict(best_model_wts)
    torch.save(model, args.model_file)
    log_string(log, f'Training and validation are completed, and model has been stored as {args.model_file}')
    return train_total_loss, val_total_loss

Model:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class conv2d_(nn.Module):
    def __init__(self, input_dims, output_dims, kernel_size, stride=(1, 1),
                 padding='SAME', use_bias=True, activation=F.relu,
                 bn_decay=None):
        super(conv2d_, self).__init__()
        self.activation = activation
        if padding == 'SAME':
            self.padding_size = math.ceil(kernel_size)
        else:
            self.padding_size = [0, 0]
        self.conv = nn.Conv2d(input_dims, output_dims, kernel_size, stride=stride,
                              padding=0, bias=use_bias)
        self.batch_norm = nn.BatchNorm2d(output_dims, momentum=bn_decay)
        torch.nn.init.xavier_uniform_(self.conv.weight)

        if use_bias:
            torch.nn.init.zeros_(self.conv.bias)


    def forward(self, x):
        x = x.permute(0, 3, 2, 1)
        x = F.pad(x, ([self.padding_size[1], self.padding_size[1], self.padding_size[0], self.padding_size[0]]))
        x = self.conv(x)
        x = self.batch_norm(x)
        if self.activation is not None:
            x = F.relu_(x)
        return x.permute(0, 3, 2, 1)


class FC(nn.Module):
    def __init__(self, input_dims, units, activations, bn_decay, use_bias=True):
        super(FC, self).__init__()
        if isinstance(units, int):
            units = [units]
            input_dims = [input_dims]
            activations = [activations]
        elif isinstance(units, tuple):
            units = list(units)
            input_dims = list(input_dims)
            activations = list(activations)
        assert type(units) == list
        self.convs = nn.ModuleList([conv2d_(
            input_dims=input_dim, output_dims=num_unit, kernel_size=[1, 1], stride=[1, 1],
            padding='VALID', use_bias=use_bias, activation=activation,
            bn_decay=bn_decay) for input_dim, num_unit, activation in
            zip(input_dims, units, activations)])

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        return x


class STEmbedding(nn.Module):
    '''
    spatio-temporal embedding
    SE:     [num_vertex, D]
    TE:     [batch_size, num_his + num_pred, 2] (dayofweek, timeofday)
    T:      num of time steps in one day
    D:      output dims
    retrun: [batch_size, num_his + num_pred, num_vertex, D]
    '''

    def __init__(self, D, bn_decay):
        super(STEmbedding, self).__init__()
        self.FC_se = FC(
            input_dims=[D, D], units=[D, D], activations=[F.relu, None],
            bn_decay=bn_decay)

        self.FC_te = FC(
            input_dims=[295, D], units=[D, D], activations=[F.relu, None],
            bn_decay=bn_decay)  # input_dims = time step per day + days per week=288+7=295

    def forward(self, SE, TE, T=288):
        # spatial embedding
        SE = SE.unsqueeze(0).unsqueeze(0)
        SE = self.FC_se(SE)
        # temporal embedding
        dayofweek = torch.empty(TE.shape[0], TE.shape[1], 7)
        timeofday = torch.empty(TE.shape[0], TE.shape[1], T)
        for i in range(TE.shape[0]):
            dayofweek[i] = F.one_hot(TE[..., 0][i].to(torch.int64) % 7, 7)
        for j in range(TE.shape[0]):
            timeofday[j] = F.one_hot(TE[..., 1][j].to(torch.int64) % 288, T)
        TE = torch.cat((dayofweek, timeofday), dim=-1)
        TE = TE.unsqueeze(dim=2)
        TE = self.FC_te(TE)
        del dayofweek, timeofday
        return SE + TE


class spatialAttention(nn.Module):
    '''
    spatial attention mechanism
    X:      [batch_size, num_step, num_vertex, D]
    STE:    [batch_size, num_step, num_vertex, D]
    K:      number of attention heads
    d:      dimension of each attention outputs
    return: [batch_size, num_step, num_vertex, D]
    '''

    def __init__(self, K, d, bn_decay):
        super(spatialAttention, self).__init__()
        D = K * d
        self.d = d
        self.K = K
        self.FC_q = FC(input_dims=2 * D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_k = FC(input_dims=2 * D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_v = FC(input_dims=2 * D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC = FC(input_dims=D, units=D, activations=F.relu,
                     bn_decay=bn_decay)

    def forward(self, X, STE):
        batch_size = X.shape[0]
        X = torch.cat((X, STE), dim=-1)
        # [batch_size, num_step, num_vertex, K * d]
        query = self.FC_q(X)
        key = self.FC_k(X)
        value = self.FC_v(X)
        # [K * batch_size, num_step, num_vertex, d]
        query = torch.cat(torch.split(query, self.K, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.K, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.K, dim=-1), dim=0)
        # [K * batch_size, num_step, num_vertex, num_vertex]
        attention = torch.matmul(query, key.transpose(2, 3))
        attention /= (self.d ** 0.5)
        attention = F.softmax(attention, dim=-1)
        # [batch_size, num_step, num_vertex, D]
        X = torch.matmul(attention, value)
        X = torch.cat(torch.split(X, batch_size, dim=0), dim=-1)  # orginal K, change to batch_size
        X = self.FC(X)
        del query, key, value, attention
        return X


class temporalAttention(nn.Module):
    '''
    temporal attention mechanism
    X:      [batch_size, num_step, num_vertex, D]
    STE:    [batch_size, num_step, num_vertex, D]
    K:      number of attention heads
    d:      dimension of each attention outputs
    return: [batch_size, num_step, num_vertex, D]
    '''

    def __init__(self, K, d, bn_decay, mask=True):
        super(temporalAttention, self).__init__()
        D = K * d
        self.d = d
        self.K = K
        self.mask = mask
        self.FC_q = FC(input_dims=2 * D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_k = FC(input_dims=2 * D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_v = FC(input_dims=2 * D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC = FC(input_dims=D, units=D, activations=F.relu,
                     bn_decay=bn_decay)

    def forward(self, X, STE):
        batch_size_ = X.shape[0]
        X = torch.cat((X, STE), dim=-1)
        # [batch_size, num_step, num_vertex, K * d]
        query = self.FC_q(X)
        key = self.FC_k(X)
        value = self.FC_v(X)
        # [K * batch_size, num_step, num_vertex, d]
        query = torch.cat(torch.split(query, self.K, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.K, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.K, dim=-1), dim=0)
        # query: [K * batch_size, num_vertex, num_step, d]
        # key:   [K * batch_size, num_vertex, d, num_step]
        # value: [K * batch_size, num_vertex, num_step, d]
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 3, 1)
        value = value.permute(0, 2, 1, 3)
        # [K * batch_size, num_vertex, num_step, num_step]
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        # mask attention score
        if self.mask:
            batch_size = X.shape[0]
            num_step = X.shape[1]
            num_vertex = X.shape[2]
            mask = torch.ones(num_step, num_step)
            mask = torch.tril(mask)
            mask = torch.unsqueeze(torch.unsqueeze(mask, dim=0), dim=0)
            mask = mask.repeat(self.K * batch_size, num_vertex, 1, 1)
            mask = mask.to(torch.bool)
            attention = torch.where(mask, attention, -2 ** 15 + 1)
        # softmax
        attention = F.softmax(attention, dim=-1)
        # [batch_size, num_step, num_vertex, D]
        X = torch.matmul(attention, value)
        X = X.permute(0, 2, 1, 3)
        X = torch.cat(torch.split(X, batch_size_, dim=0), dim=-1)  # orginal K, change to batch_size
        X = self.FC(X)
        del query, key, value, attention
        return X


class gatedFusion(nn.Module):
    '''
    gated fusion
    HS:     [batch_size, num_step, num_vertex, D]
    HT:     [batch_size, num_step, num_vertex, D]
    D:      output dims
    return: [batch_size, num_step, num_vertex, D]
    '''

    def __init__(self, D, bn_decay):
        super(gatedFusion, self).__init__()
        self.FC_xs = FC(input_dims=D, units=D, activations=None,
                        bn_decay=bn_decay, use_bias=False)
        self.FC_xt = FC(input_dims=D, units=D, activations=None,
                        bn_decay=bn_decay, use_bias=True)
        self.FC_h = FC(input_dims=[D, D], units=[D, D], activations=[F.relu, None],
                       bn_decay=bn_decay)

    def forward(self, HS, HT):
        XS = self.FC_xs(HS)
        XT = self.FC_xt(HT)
        z = torch.sigmoid(torch.add(XS, XT))
        H = torch.add(torch.mul(z, HS), torch.mul(1 - z, HT))
        H = self.FC_h(H)
        del XS, XT, z
        return H


class STAttBlock(nn.Module):
    def __init__(self, K, d, bn_decay, mask=False):
        super(STAttBlock, self).__init__()
        self.spatialAttention = spatialAttention(K, d, bn_decay)
        self.temporalAttention = temporalAttention(K, d, bn_decay, mask=mask)
        self.gatedFusion = gatedFusion(K * d, bn_decay)

    def forward(self, X, STE):
        HS = self.spatialAttention(X, STE)
        HT = self.temporalAttention(X, STE)
        H = self.gatedFusion(HS, HT)
        del HS, HT
        return torch.add(X, H)


class transformAttention(nn.Module):
    '''
    transform attention mechanism
    X:        [batch_size, num_his, num_vertex, D]
    STE_his:  [batch_size, num_his, num_vertex, D]
    STE_pred: [batch_size, num_pred, num_vertex, D]
    K:        number of attention heads
    d:        dimension of each attention outputs
    return:   [batch_size, num_pred, num_vertex, D]
    '''

    def __init__(self, K, d, bn_decay):
        super(transformAttention, self).__init__()
        D = K * d
        self.K = K
        self.d = d
        self.FC_q = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_k = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC_v = FC(input_dims=D, units=D, activations=F.relu,
                       bn_decay=bn_decay)
        self.FC = FC(input_dims=D, units=D, activations=F.relu,
                     bn_decay=bn_decay)

    def forward(self, X, STE_his, STE_pred):
        batch_size = X.shape[0]
        # [batch_size, num_step, num_vertex, K * d]
        query = self.FC_q(STE_pred)
        key = self.FC_k(STE_his)
        value = self.FC_v(X)
        # [K * batch_size, num_step, num_vertex, d]
        query = torch.cat(torch.split(query, self.K, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.K, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.K, dim=-1), dim=0)
        # query: [K * batch_size, num_vertex, num_pred, d]
        # key:   [K * batch_size, num_vertex, d, num_his]
        # value: [K * batch_size, num_vertex, num_his, d]
        query = query.permute(0, 2, 1, 3)
        key = key.permute(0, 2, 3, 1)
        value = value.permute(0, 2, 1, 3)
        # [K * batch_size, num_vertex, num_pred, num_his]
        attention = torch.matmul(query, key)
        attention /= (self.d ** 0.5)
        attention = F.softmax(attention, dim=-1)
        # [batch_size, num_pred, num_vertex, D]
        X = torch.matmul(attention, value)
        X = X.permute(0, 2, 1, 3)
        X = torch.cat(torch.split(X, batch_size, dim=0), dim=-1)
        X = self.FC(X)
        del query, key, value, attention
        return X


class GMAN(nn.Module):
    '''
    GMAN
        X:       [batch_size, num_his, num_vertx]
        TE:      [batch_size, num_his + num_pred, 2] (time-of-day, day-of-week)
        SE:      [num_vertex, K * d]
        num_his: number of history steps
        num_pred:number of prediction steps
        T:       one day is divided into T steps
        L:       number of STAtt blocks in the encoder/decoder
        K:       number of attention heads
        d:       dimension of each attention head outputs
        return:  [batch_size, num_pred, num_vertex]
    '''

    def __init__(self, SE, args, bn_decay):
        super(GMAN, self).__init__()
        L = args.L
        K = args.K
        d = args.d
        D = K * d
        self.num_his = args.num_his
        self.SE = SE.to(SE.device)  # Ensure SE is on the correct device
        self.STEmbedding = STEmbedding(D, bn_decay)
        self.STAttBlock_1 = nn.ModuleList([STAttBlock(K, d, bn_decay) for _ in range(L)])
        self.STAttBlock_2 = nn.ModuleList([STAttBlock(K, d, bn_decay) for _ in range(L)])
        self.transformAttention = transformAttention(K, d, bn_decay)
        self.FC_1 = FC(input_dims=[1, D], units=[D, D], activations=[F.relu, None],
                       bn_decay=bn_decay)
        self.FC_2 = FC(input_dims=[D, D], units=[D, 1], activations=[F.relu, None],
                       bn_decay=bn_decay)

    def forward(self, X, TE):
        # Move input tensors to the same device as the model
        X = X.to(self.SE.device)
        TE = TE.to(self.SE.device)

        # input
        X = torch.unsqueeze(X, -1)
        X = self.FC_1(X)
        # STE
        STE = self.STEmbedding(self.SE, TE)
        STE_his = STE[:, :self.num_his]
        STE_pred = STE[:, self.num_his:]
        # encoder
        for net in self.STAttBlock_1:
            X = net(X, STE_his)
        # transAtt
        X = self.transformAttention(X, STE_his, STE_pred)
        # decoder
        for net in self.STAttBlock_2:
            X = net(X, STE_pred)
        # output
        X = self.FC_2(X)
        del STE, STE_his, STE_pred
        return torch.squeeze(X, 3)

Trang chủ Giới thiệu Sinh nhật bé trai Sinh nhật bé gái Tổ chức sự kiện Biểu diễn giải trí Dịch vụ khác Trang trí tiệc cưới Tổ chức khai trương Tư vấn dịch vụ Thư viện ảnh Tin tức - sự kiện Liên hệ Chú hề sinh nhật Trang trí YEAR END PARTY công ty Trang trí tất niên cuối năm Trang trí tất niên xu hướng mới nhất Trang trí sinh nhật bé trai Hải Đăng Trang trí sinh nhật bé Khánh Vân Trang trí sinh nhật Bích Ngân Trang trí sinh nhật bé Thanh Trang Thuê ông già Noel phát quà Biểu diễn xiếc khỉ Xiếc quay đĩa Dịch vụ tổ chức sự kiện 5 sao Thông tin về chúng tôi Dịch vụ sinh nhật bé trai Dịch vụ sinh nhật bé gái Sự kiện trọn gói Các tiết mục giải trí Dịch vụ bổ trợ Tiệc cưới sang trọng Dịch vụ khai trương Tư vấn tổ chức sự kiện Hình ảnh sự kiện Cập nhật tin tức Liên hệ ngay Thuê chú hề chuyên nghiệp Tiệc tất niên cho công ty Trang trí tiệc cuối năm Tiệc tất niên độc đáo Sinh nhật bé Hải Đăng Sinh nhật đáng yêu bé Khánh Vân Sinh nhật sang trọng Bích Ngân Tiệc sinh nhật bé Thanh Trang Dịch vụ ông già Noel Xiếc thú vui nhộn Biểu diễn xiếc quay đĩa Dịch vụ tổ chức tiệc uy tín Khám phá dịch vụ của chúng tôi Tiệc sinh nhật cho bé trai Trang trí tiệc cho bé gái Gói sự kiện chuyên nghiệp Chương trình giải trí hấp dẫn Dịch vụ hỗ trợ sự kiện Trang trí tiệc cưới đẹp Khởi đầu thành công với khai trương Chuyên gia tư vấn sự kiện Xem ảnh các sự kiện đẹp Tin mới về sự kiện Kết nối với đội ngũ chuyên gia Chú hề vui nhộn cho tiệc sinh nhật Ý tưởng tiệc cuối năm Tất niên độc đáo Trang trí tiệc hiện đại Tổ chức sinh nhật cho Hải Đăng Sinh nhật độc quyền Khánh Vân Phong cách tiệc Bích Ngân Trang trí tiệc bé Thanh Trang Thuê dịch vụ ông già Noel chuyên nghiệp Xem xiếc khỉ đặc sắc Xiếc quay đĩa thú vị
Trang chủ Giới thiệu Sinh nhật bé trai Sinh nhật bé gái Tổ chức sự kiện Biểu diễn giải trí Dịch vụ khác Trang trí tiệc cưới Tổ chức khai trương Tư vấn dịch vụ Thư viện ảnh Tin tức - sự kiện Liên hệ Chú hề sinh nhật Trang trí YEAR END PARTY công ty Trang trí tất niên cuối năm Trang trí tất niên xu hướng mới nhất Trang trí sinh nhật bé trai Hải Đăng Trang trí sinh nhật bé Khánh Vân Trang trí sinh nhật Bích Ngân Trang trí sinh nhật bé Thanh Trang Thuê ông già Noel phát quà Biểu diễn xiếc khỉ Xiếc quay đĩa
Thiết kế website Thiết kế website Thiết kế website Cách kháng tài khoản quảng cáo Mua bán Fanpage Facebook Dịch vụ SEO Tổ chức sinh nhật