The problem
I have a common dataclass that contains two numpy arrays:
```python
from dataclasses import dataclass
import numpy as np
import numpy.typing as npt

@dataclass
class VectorCovariancePair:
    data: npt.NDArray[np.float64]  # expected size: N
    covariance: npt.NDArray[np.float64]  # expected size: N * N
```
Multiple other classes use this to represent values of different sizes, for example:
```python
@dataclass
class Pose:
    position: VectorCovariancePair  # expected size: 3
    orientation: VectorCovariancePair  # expected size: 4

# ... and other classes ...
```
Each of these has a method that I use to construct the corresponding message. The message constructor validates the vector size and throws an error if it doesn't match; in the snippets here, I'll replace that check with an assertion.
```python
class Pose:
    # ...
    def to_msg(self):
        assert self.position.data.size == 3
        msg = ...  # construct the message here
        return msg
```
My program is structured so that I construct the object in one module and call to_msg in a completely different module.
I have now constructed some of these objects with the wrong size: to_msg throws, but nothing happens at construction time. As a result, I can't easily find where exactly I constructed the object with the wrong size.
If the size were validated on construction, the exception would be thrown there instead, and I'd immediately be able to find the problem.
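To make the failure mode concrete, here is a minimal runnable sketch of the situation. It assumes the fields shown above and uses a hypothetical plain dict in place of the real message type: construction accepts a 2-element position without complaint, and the error only appears once to_msg is called.

```python
from dataclasses import dataclass

import numpy as np
import numpy.typing as npt


@dataclass
class VectorCovariancePair:
    data: npt.NDArray[np.float64]  # expected size: N
    covariance: npt.NDArray[np.float64]  # expected size: N * N


@dataclass
class Pose:
    position: VectorCovariancePair  # expected size: 3
    orientation: VectorCovariancePair  # expected size: 4

    def to_msg(self):
        # Stand-in for the size check done by the real message constructor.
        assert self.position.data.size == 3
        # Hypothetical message: a plain dict instead of the real message type.
        return {"position": self.position.data.tolist()}


# "Construction module": a 2-element position is accepted without complaint.
pose = Pose(
    position=VectorCovariancePair(np.zeros(2), np.zeros((2, 2))),
    orientation=VectorCovariancePair(np.zeros(4), np.zeros((4, 4))),
)

# "Conversion module": the error only surfaces here, far from the faulty construction.
try:
    pose.to_msg()
    failed_late = False
except AssertionError:
    failed_late = True

print(failed_late)  # True
```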
Of course, I could validate each VectorCovariancePair instance directly in the classes that use it:
```python
class Pose:
    # ...
    def __post_init__(self):
        assert self.position.data.size == 3
        assert self.orientation.data.size == 4
```
but this would result in a lot of boilerplate and repeated code.
My attempt
I tried to imitate what I'd do in C++ and simply strap a generic parameter onto the class. I looked up how to use typing.Generic and came up with this:
```python
from dataclasses import dataclass
from typing import Generic, TypeVar

N = TypeVar('N', bound=int)

@dataclass
class VectorCovariancePair(Generic[N]):
    # ... (data and covariance fields as above)
    def __post_init__(self):
        n = retrieve_n(self)
        assert self.data.size == n
        assert self.covariance.size == n * n
```
This would result in very elegant code, with the validation located exclusively in VectorCovariancePair:
```python
@dataclass
class Pose:
    position: VectorCovariancePair[3]
    orientation: VectorCovariancePair[4]
```
Unfortunately, I fail to retrieve the argument. I tried both of the following, but I get the same output:
```python
import typing

def retrieve_n(obj):
    print(obj.__orig_bases__[0].__args__)  # (~N,)
    print(typing.get_args(obj.__orig_bases__[0]))  # (~N,)
```
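A small sketch illustrating where the (~N,) output comes from, under the same assumptions as above: __orig_bases__ is a class attribute that records the bases exactly as written in the class statement, so it always contains Generic[N], and looking it up on an instance simply falls through to the class.

```python
import typing
from dataclasses import dataclass
from typing import Generic, TypeVar

N = TypeVar("N", bound=int)


@dataclass
class VectorCovariancePair(Generic[N]):
    data: object  # placeholder for the NumPy array
    covariance: object  # placeholder for the NumPy array


pair = VectorCovariancePair([0.0] * 3, [0.0] * 9)

# The instance has no __orig_bases__ of its own; the lookup returns the class attribute.
print(pair.__orig_bases__ is VectorCovariancePair.__orig_bases__)  # True
# It records the base exactly as written in the class statement: Generic[N].
print(VectorCovariancePair.__orig_bases__ == (Generic[N],))  # True
print(typing.get_args(VectorCovariancePair.__orig_bases__[0]))  # (~N,)
```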
I’d guess the reason is that this information is not contained in the instances of VectorCovariancePair. How would you approach this problem?