In PyTorch (Lightning) I would like to train parameters that come from a class. However, I have multiple instances, and some of them have a functional dependence on one another: is it possible to simply combine those in the forward? Is the following possible/correct, or will the behavior be different from what I think?
In PyTorch I have:
import torch
from torch.nn import Module

class A(Module):
    def __init__(self):
        super().__init__()
        self.oe_11 = OEE(1)
        self.oe_22 = OEE(1)
        self.oe_12 = OEE(1)

    def forward(self, values, ...):
        value_types = torch.zeros_like(values)
        contribs_11 = self.oe_11(value_types, values)
        contribs_22 = self.oe_22(value_types, values)
        contribs_12 = self.oe_12(value_types, values)
        contribs_total = contribs_11 + contribs_22 + contribs_12
        return contribs_total
which, as I understand it, would result in two trainable parameters for each of the OEEnergy instances (6 in total).
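That count can be checked with the standard parameters() API, for example:

model = A()
param_tensors = list(model.parameters())
# I expect 3 parameter tensors (one per OEE instance) with 2 entries each -> 6 scalars
print(len(param_tensors), sum(p.numel() for p in param_tensors))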
class OEE(tb):
    def __init__(self, type_count: int) -> None:
        super().__init__(type_count, 2)

    def get_contribution(self, parameters: Tensor, val_d: Tensor) -> Tensor:
        p1 = parameters[0]
        p2 = parameters[1]
        # function of p1 and p2; the product is only a simplification
        return p1 * p2
where tb is just a more general base class that declares the abstract method
@abstractmethod
def get_contribution(self, parameters: Tensor, val_d: Tensor) -> Tensor: ...
so that multiple subclasses can be used for different contributions. (I don't think that part is really relevant here.) Restricting myself to the single subclass OEE, I want the parameters of the third set to depend on the first two sets.
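To make the question self-contained, a simplified stand-in for tb could look like this (not my real implementation, just a sketch with the same interface; the random initialization, the (type_count, parameter_count) layout and the forward signature are simplifications):

from abc import ABC, abstractmethod

import torch
from torch import Tensor, nn


class tb(nn.Module, ABC):
    """Simplified base class: registers a (type_count, parameter_count) block of
    trainable parameters and delegates the actual math to get_contribution."""

    def __init__(self, type_count: int, parameter_count: int) -> None:
        super().__init__()
        # one row of parameters per type; OEE passes parameter_count = 2
        self.params = nn.Parameter(torch.rand(type_count, parameter_count))

    @abstractmethod
    def get_contribution(self, parameters: Tensor, val_d: Tensor) -> Tensor: ...

    def forward(self, value_types: Tensor, values: Tensor) -> Tensor:
        # simplification: with type_count == 1 every value uses parameter row 0;
        # real per-type selection via value_types is omitted in this sketch
        return self.get_contribution(self.params[0], values)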
So I had the idea of inserting the following additional block into the forward method
params_11 = list(self.oe_11.parameters())
params_22 = list(self.oe_22.parameters())
params_12 = list(self.oe_12.parameters())
params_12[0].data = params_11[0].data * params_22[0].data
before I run
contribs_12 = self.oe_12(value_types, values)
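Put together, the forward I have in mind would look like this (just the pieces from above assembled in one place; the extra forward arguments are omitted):

import torch
from torch.nn import Module


class A(Module):
    def __init__(self):
        super().__init__()
        self.oe_11 = OEE(1)
        self.oe_22 = OEE(1)
        self.oe_12 = OEE(1)

    def forward(self, values):  # extra arguments from the original signature omitted
        value_types = torch.zeros_like(values)
        contribs_11 = self.oe_11(value_types, values)
        contribs_22 = self.oe_22(value_types, values)
        # overwrite the third parameter set from the first two before it is used
        params_11 = list(self.oe_11.parameters())
        params_22 = list(self.oe_22.parameters())
        params_12 = list(self.oe_12.parameters())
        params_12[0].data = params_11[0].data * params_22[0].data
        contribs_12 = self.oe_12(value_types, values)
        return contribs_11 + contribs_22 + contribs_12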
Will that result in 6 trainable parameters for the model, where two sets of 2 are independent and the third set of 2 depends on the others, or am I doing something else?