I’m using `torch`, `jaxtyping` and `beartype` to type-annotate my functions (with runtime type checking). I have used the `@jaxtyped(typechecker=beartype)` decorator.
I have received the following error at runtime:
jaxtyping.TypeCheckError: Type-check error whilst checking the parameters of display_features_from_tokens_and_feature_tensor.
The problem arose whilst typechecking parameter 'tokens'.
Actual value: tensor([41083, 531, 366, 31373, 612, 1, 290, 788, 3332, 866,
13], device='cuda:0')
Expected type: <class 'Int[Tensor, 'seq_len']'>.
As far as I can tell, the type of the actual value matches the expected type. Is there some nuance of `beartype` or `jaxtyping` that I am missing?
I suspected that it might have to do with my tensors being on the GPU, but I couldn’t find any mention of this in the `jaxtyped` docs.
I also tried using `Int64`, to no avail (same error).