In all my functions that allow an overlapping union of types, I already have special handling that checks all parameters to ensure that there are no types which don’t support operations with each other.
In my specific use-case, I’m writing functions that accept any of int, float, fractions.Fraction, and decimal.Decimal as input arguments, and since decimal.Decimal only supports mathematical operations with instances of itself and int, I wrote a function that implements the aforementioned special handling by first checking if there’s any argument which is a decimal.Decimal object. If there is, then the function re-scans all arguments to confirm that all of them are either decimal.Decimal or int. If it finds a violation, then it returns False to notify the caller that incompatible arguments are present (else it returns True). Below are the two functions that implement this check. All my functions & methods call the latter in a guard clause at some point, and raise a TypeError if the return value is False:

    from fractions import Fraction
    import collections.abc
    import decimal as dec
    import typing as tp

    Real = tp.Union[int, float, Fraction]
    DecLike = tp.Union[int, dec.Decimal]

    def _check_has_decimals(args: tp.Tuple[tp.Any, ...]) -> collections.abc.Iterator[bool]:
        """
        Returns an iterator yielding whether each argument is an instance of 'decimal.Decimal'.
        Can be passed to 'any()'.
        """
        return map(lambda param: isinstance(param, dec.Decimal), args)

    def _check_arg_types(
        *args: tp.Any,
        has_decimals: tp.Union[tp.Callable[[tp.Tuple[tp.Any, ...]],
                                           collections.abc.Iterator[bool]],
                               bool] = _check_has_decimals
    ) -> bool:
        """Returns a boolean indicating if the types of the arguments are valid."""
        if not isinstance(has_decimals, bool):
            # Expand iterator to a simple 'bool'.
            has_decimals = any(has_decimals(args))
        if has_decimals:
            return all(map(lambda param: isinstance(param, tp.get_args(DecLike)), args))
        else:
            return all(map(lambda param: isinstance(param, tp.get_args(Real)), args))

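As background for the check above, the runtime incompatibility it guards against can be reproduced directly. The following is a minimal sketch; the helper name can_add is hypothetical and exists only for this illustration.

```python
from decimal import Decimal
from fractions import Fraction
import operator


def can_add(a, b) -> bool:
    """Return True if 'a + b' succeeds, False if it raises TypeError."""
    try:
        operator.add(a, b)
    except TypeError:
        return False
    return True


# One sample value per supported type.
samples = [1, 1.5, Fraction(1, 2), Decimal("0.5")]

# Print the full compatibility matrix for addition.
for left in samples:
    for right in samples:
        names = f"{type(left).__name__} + {type(right).__name__}"
        print(f"{names}: {can_add(left, right)}")
```

Every pairing that mixes Decimal with float or Fraction reports False here, while Decimal with int or with another Decimal reports True.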
The reason for the callable default argument is so that if I’m writing a class that accepts mixed arguments, it can manually call _check_has_decimals and store the result as an instance attribute, so that when any of its methods are called, it can supply _check_arg_types with the pre-computed value instead of letting the function re-compute it every time. But when it is called from a plain function, that function just calls _check_arg_types without the optional argument, as the same set of arguments will only have to be checked once anyway.
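The class-based pattern described above might look roughly like the sketch below. The Accumulator class and its total_plus method are hypothetical, and the sketch assumes that the stored value is the any() of the iterator rather than the iterator itself. The two helper functions are repeated (without annotations) only so that the block runs on its own.

```python
import decimal as dec
import typing as tp
from fractions import Fraction

Real = tp.Union[int, float, Fraction]
DecLike = tp.Union[int, dec.Decimal]


def _check_has_decimals(args):
    """Yield, per argument, whether it is a 'decimal.Decimal'."""
    return map(lambda param: isinstance(param, dec.Decimal), args)


def _check_arg_types(*args, has_decimals=_check_has_decimals):
    """Return True if the argument types are mutually compatible."""
    if not isinstance(has_decimals, bool):
        has_decimals = any(has_decimals(args))
    allowed = tp.get_args(DecLike) if has_decimals else tp.get_args(Real)
    return all(isinstance(param, allowed) for param in args)


class Accumulator:
    """Hypothetical class, used only to illustrate the caching pattern."""

    def __init__(self, *values):
        # Scan for Decimals once and remember the answer as a plain bool.
        self._has_decimals = any(_check_has_decimals(values))
        if not _check_arg_types(*values, has_decimals=self._has_decimals):
            raise TypeError("incompatible argument types")
        self._values = values

    def total_plus(self, extra):
        # Reuse the cached flag instead of re-scanning the stored values.
        if not _check_arg_types(extra, has_decimals=self._has_decimals):
            raise TypeError("incompatible argument type")
        return sum(self._values) + extra


acc = Accumulator(1, dec.Decimal("2.5"))
print(acc.total_plus(2))
```

Note that in this sketch the cached flag describes only the constructor arguments, so an instance built purely from int values would reject a Decimal passed to total_plus later, even though int and Decimal are compatible.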
So in a function, for example, I would do this:

    def add_stuff(a, b, c, d):
        if not _check_arg_types(a, b, c, d):
            err_msg = ("argument types other than 'int' or 'Decimal' aren't allowed if "
                       "any argument is a 'Decimal'!")
            raise TypeError(err_msg)
        return a + b + c + d

And now the problem. I can’t get the type checker (Pyright in my case) to understand that I’m preventing the incompatible types from colliding in my code! If I simply annotate the above function’s signature as:

    Numeric = tp.Union[Real, DecLike]

    def add_stuff(a: Numeric, b: Numeric, c: Numeric, d: Numeric) -> Numeric:
        ...

…then the nested union just gets flattened to tp.Union[int, float, Fraction, dec.Decimal], and when I type-check the code, I get an error on the return a + b + c + d line because the type checker thinks it’s possible for, say, a to be float and b to be Decimal. And nothing I do can make it understand better. The only solution seems to be to not type-annotate any of the functions/methods in my code that accept overlapping unions, and lose the benefit of static type-checking completely (for those functions/methods/classes, at least).
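The flattening mentioned above can be observed at runtime as well. This short sketch reuses the Real, DecLike and Numeric aliases from the question and inspects the resulting union with typing.get_args:

```python
import decimal as dec
import typing as tp
from fractions import Fraction

Real = tp.Union[int, float, Fraction]
DecLike = tp.Union[int, dec.Decimal]
Numeric = tp.Union[Real, DecLike]

# The nested unions are merged into one flat union; the duplicate 'int'
# that appears in both aliases is kept only once.
print(tp.get_args(Numeric))

# The flattened alias compares equal to the hand-written four-member union.
print(Numeric == tp.Union[int, float, Fraction, dec.Decimal])  # True
```

Once flattened, nothing in the annotation records that float and Decimal came from different groups, which is consistent with the checker treating every pairing of the four types as possible.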
And here is the list of things I’ve tried, with help from the Python chatroom. None of them work completely for my use-case:
- Overloading a function two times, with the first signature’s parameters only accepting Reals and with the return type annotated as Real, and the second signature with everything swapped out for DecLike. Also, make Real and DecLike generic types instead of Unions. It works if you suppress the overload-overlap error on the line of the first overload:

      Real = tp.TypeVar("Real", bound=tp.Union[int, float, Fraction])
      DecLike = tp.TypeVar("DecLike", bound=tp.Union[int, dec.Decimal])

      @tp.overload
      def add_stuff(a: Real, b: Real) -> Real:  # type: ignore[overload-overlap]
          ...
      @tp.overload
      def add_stuff(a: DecLike, b: DecLike) -> DecLike:
          ...
      def add_stuff(a, b):
          return a + b

      print(tp.reveal_type(add_stuff(1, 1)))
      print(tp.reveal_type(add_stuff(1, 1.5)))
      print(tp.reveal_type(add_stuff(Fraction(1, 2), 0.5)))
      print(tp.reveal_type(add_stuff(3, dec.Decimal("0.5"))))
      print(tp.reveal_type(add_stuff(dec.Decimal("1"), dec.Decimal("1"))))

  But the problem with this is that it doesn’t work for class instance methods which don’t take any arguments and return a value based on an instance attribute. If I tried to define, say:

      class Foo:
          ...
          @tp.overload
          def return_something(self) -> Real:
              ...
          @tp.overload
          def return_something(self) -> DecLike:
              ...
          def return_something(self):
              return self.some_instance_attribute

  then it would trigger the error "A function returning TypeVar should receive at least one argument containing the same TypeVar [type-var]".

- Make the class which supports overlapping unions a generic class, and then create an overloaded helper function which takes constructor arguments, uses them unchanged to create a new instance of the class, and has its return type annotated as a subscription of the generic class.

      T = tp.TypeVar("T")
      TReal = tp.TypeVar("TReal", bound=tp.Union[int, float, Fraction])
      TDec = tp.TypeVar("TDec", bound=tp.Union[int, dec.Decimal])

      class C(tp.Generic[T]):
          def __init__(self, foo: T, bar: T) -> None:
              self.foo, self.bar = foo, bar

          def return_something(self) -> T:
              return self.foo + self.bar

      @tp.overload
      def from_number(a: TReal, b: TReal) -> C[TReal]:
          ...
      @tp.overload
      def from_number(a: TDec, b: TDec) -> C[TDec]:
          ...
      def from_number(a, b):
          return C(a, b)

      print(tp.reveal_type(from_number(1, 1)))
      print(tp.reveal_type(from_number(1, 1.5)))
      print(tp.reveal_type(from_number(Fraction(1, 2), 0.5)))
      print(tp.reveal_type(from_number(3, dec.Decimal("0.5"))))
      print(tp.reveal_type(from_number(dec.Decimal("1"), dec.Decimal("1"))))

  This gives an Operator "X" not supported for types "T@C" and "T@C" error whenever I try to do any operations on the stored instance attributes. So this is basically the same as just annotating as tp.Union[int, float, Fraction, dec.Decimal], in that I have to add a # type: ignore[reportOperatorIssue] after every operation between the arguments.

Is there any solution (preferably simple) that actually works for my use case? Or am I trying to do something impossible?