I have been scratching my head over this for the last hour.
Imagine I have a function that takes as its argument either a number or an array of dimension at most 1. I’d like it to return a scalar (not a 0-d array) in the former case, and an array of the same shape in the latter case, like ufuncs do.
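For reference, this is the ufunc behaviour being imitated, using plain `np.sin` (standard NumPy, nothing hypothetical here). Note that a 0-d array input also comes back as a NumPy scalar:

```python
import numpy as np

print(type(np.sin(0.5)))               # <class 'numpy.float64'>
print(type(np.sin(np.asarray(0.5))))   # <class 'numpy.float64'> (0-d in, scalar out)
print(np.sin(np.array([0.5])).shape)   # (1,)
print(np.sin(np.zeros((2, 3))).shape)  # (2, 3)
```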
The current implementation of my function does something along the lines of:

```python
from numbers import Real
import numpy as np

def func(x: Real | np.ndarray, arr: np.ndarray):
    """Illustrate an actually more complicated function"""
    return arr @ np.sin(arr[:, None] * x)
```
and I know `arr` is a 1-d array. It is promoted to 2-d so that there is no broadcasting issue with the element-wise multiplication. The downside is that a 1-d array is always returned, even for a scalar `x`. The upside is that the cases
- `x` scalar and `len(arr) == 1`;
- `x` scalar and `len(arr) > 1`;
- `x` array and `len(arr) == 1`;
- `x` array and `len(arr) >= 1` and (`len(x) != len(arr)` or `len(x) == len(arr)`)

are all covered by a one-liner (the tautological last condition is there for emphasis).
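A minimal sketch of one way to keep the one-liner and still get a scalar back, assuming `x` is a scalar or a NumPy array: put `x`'s axes first and contract the trailing axis with `arr`. This reordering is my own suggestion, not from the post. It computes the same sum, and since `@` between two 1-d arrays returns a scalar, a scalar `x` needs no post-processing:

```python
import numpy as np

def func(x, arr):
    """Same sum as arr @ np.sin(arr[:, None] * x), with x's axes first."""
    x = np.asarray(x)  # 0-d for a scalar, unchanged shape for an array
    # x[..., None] has shape x.shape + (1,); broadcasting against arr (n,)
    # gives x.shape + (n,), and "@ arr" contracts that trailing axis.
    return np.sin(x[..., None] * arr) @ arr

arr = np.array([1.0, 2.0, 3.0])
print(type(func(0.5, arr)))                   # <class 'numpy.float64'>
print(func(np.array([0.5]), arr).shape)       # (1,)
print(func(np.linspace(0, 1, 4), arr).shape)  # (4,)
```

As a side effect, the same expression also handles `x` of higher dimension, returning an array of shape `x.shape`.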
I tried:

```python
import numpy as np

@np.vectorize
def func(x, arr):
    return arr @ np.sin(arr * x)
```

which also consistently returns a 1-d array, and is probably not the best choice performance-wise.
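One detail worth noting: as written, `@np.vectorize` vectorizes over `arr` too, so the inner function would receive single elements of `arr` rather than the whole array. A sketch assuming the intent was to vectorize over `x` only uses the `excluded` parameter. For a scalar `x` the result is then 0-d (in the NumPy versions I know, a 0-d array rather than a scalar), and the NumPy docs describe `vectorize` as essentially a for loop, which supports the performance concern. The name `_func_elementwise` is hypothetical:

```python
import numpy as np

def _func_elementwise(x, arr):
    # Called once per element of x, with the whole 1-d arr.
    return arr @ np.sin(arr * x)

# Exclude arr (positional index 1, or keyword "arr") from vectorization so it
# is passed through whole; otypes skips the trial call used to infer dtype.
func = np.vectorize(_func_elementwise, excluded={1, "arr"}, otypes=[float])

arr = np.array([1.0, 2.0, 3.0])
out = func(0.5, arr)
print(np.ndim(out))                           # 0
print(func(np.linspace(0, 1, 4), arr).shape)  # (4,)
```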
I looked at `functools.singledispatch`, but it leads to a lot of duplication, and I would probably forget some corner cases.
A solution would be:

```python
import numpy as np

def func(x, arr):
    res = arr @ np.sin(arr[:, None] * x)
    if len(res) == 1:
        return res.item()
    return res
```
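A caveat on the `len(res) == 1` check: it also fires when `x` is an array of length 1, so `func(np.array([0.5]), arr)` would come back as a scalar instead of an array of shape `(1,)`. A sketch that decides from the input instead, using `np.ndim(x)` (which is 0 for Python scalars, NumPy scalars and 0-d arrays):

```python
import numpy as np

def func(x, arr):
    res = arr @ np.sin(arr[:, None] * x)
    # Decide from the input's dimensionality, not the output's length:
    # a length-1 array x should still give an array of shape (1,).
    if np.ndim(x) == 0:
        return res.item()
    return res

arr = np.array([1.0, 2.0, 3.0])
print(type(func(0.5, arr)))              # <class 'float'>
print(func(np.array([0.5]), arr).shape)  # (1,)
```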
but I have many such functions, and it doesn’t feel very Pythonic. I can write a decorator that does this check:

```python
import functools

def give_me_a_scalar(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        res = f(*args, **kwargs)
        if len(res) == 1:
            return res.item()
        return res
    return wrapper
```

This seems to do what I want, but I struggle to believe that nothing like it already exists. Am I missing something simpler?
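A decorator variant that reshapes the result to the shape of the input and then indexes with an empty tuple: on a 0-d array `[()]` returns a NumPy scalar, and on any other array it returns a view, so there is no length check at all. This is a sketch; the name `match_input_shape` is hypothetical, and it assumes the ufunc-like input is the first positional argument and that `f` returns one value per element of it:

```python
import functools

import numpy as np

def match_input_shape(f):
    """Hypothetical helper: give f's result the shape of its first argument."""
    @functools.wraps(f)
    def wrapper(x, *args, **kwargs):
        res = np.asarray(f(x, *args, **kwargs))
        # For scalar x, np.shape(x) == (), so the reshape gives a 0-d array and
        # [()] unwraps it to a NumPy scalar; otherwise [()] is a plain view.
        return res.reshape(np.shape(x))[()]
    return wrapper

@match_input_shape
def func(x, arr):
    return arr @ np.sin(arr[:, None] * x)

arr = np.array([1.0, 2.0, 3.0])
print(type(func(0.5, arr)))              # <class 'numpy.float64'>
print(func(np.array([0.5]), arr).shape)  # (1,)
```

Unlike a `len(res) == 1` check, this keeps a length-1 array input as a length-1 array output.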