I have a C++ library function `addTensor(tensor, type)` in the `Tensor` class, where `tensor` is an `xarray` and `type` is an enum whose possible values are `FP32`, `FP16`, `INT`, `I8` and `U8F`. I also have a type-specific function for each of these types, e.g. `addTensorFP32(tensor)`.

I'm trying to write Python bindings so that users only need to call `n.addTensor(a.astype(np.float16))` to add a tensor. I've bound `addTensorFP32()` and the other type-specific functions without any problem, and I'm now writing a dispatch function that detects the element type of a NumPy array and calls the matching function. My progress so far:
```cpp
// Part of the py::class_<Tensor> binding chain
.def("addTensor", [](Tensor& obj, py::array tensor) {
    if (py::isinstance<py::array_t<float>>(tensor) ||
        py::isinstance<py::array_t<double>>(tensor)) {
        // float64 input is converted to float32 by the cast
        obj.addTensorFP32(tensor.cast<xt::xarray<float>>());
    } else if (py::isinstance<py::array_t<int>>(tensor)) {
        obj.addTensorINT(tensor.cast<xt::xarray<int>>());
    } else if (py::isinstance<py::array_t<float16>>(tensor)) {  // Problem
        obj.addTensorFP16(tensor.cast<xt::xarray<float16>>());   // Also problematic
    } else {
        throw std::runtime_error("Unsupported datatype");
    }
})
```
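Part of the difficulty on the FP16 branch is that standard C++ has no half-precision type (C++23 adds `std::float16_t`, but only as an optional type), so `float16` compiles only if the library itself defines one. A minimal sketch of one workaround, assuming the FP16 data can be handled as raw 16-bit patterns: decode each IEEE 754 binary16 value, which is what NumPy stores for `float16`, into a `float`. `halfToFloat` is a hypothetical helper, not part of pybind11, xtensor or the library.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Hypothetical helper: decodes one IEEE 754 binary16 value (the storage
// format NumPy uses for float16) from its raw bit pattern into a float.
float halfToFloat(std::uint16_t bits) {
    const bool negative = (bits & 0x8000u) != 0;  // 1 sign bit
    const int exponent = (bits >> 10) & 0x1F;     // 5-bit biased exponent (bias 15)
    const int fraction = bits & 0x3FF;            // 10-bit fraction

    float magnitude;
    if (exponent == 0) {
        // Zero or subnormal: fraction * 2^-24
        magnitude = std::ldexp(static_cast<float>(fraction), -24);
    } else if (exponent == 0x1F) {
        // All-ones exponent: infinity if the fraction is zero, otherwise NaN
        magnitude = fraction == 0 ? std::numeric_limits<float>::infinity()
                                  : std::numeric_limits<float>::quiet_NaN();
    } else {
        // Normal: (1 + fraction / 1024) * 2^(exponent - 15),
        // rewritten as (1024 + fraction) * 2^(exponent - 25)
        magnitude = std::ldexp(static_cast<float>(1024 + fraction), exponent - 25);
    }
    return negative ? -magnitude : magnitude;
}
```

To apply this to an array, its buffer would first have to be viewed as 16-bit unsigned integers, for example through NumPy's own `view` method via `tensor.attr("view")("uint16")`. Whether the library's FP16 function then wants the widened floats or the raw bits depends on its actual signature, which isn't shown here.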
Checking the instance of NumPy arrays works well for C++ built-in types, but not for non-standard types such as `float16`. Is there an alternative way to check the underlying type? I also tried using `dtype`, but I couldn't find a way to detect `float16` with it either.