I’m attempting to use the onnxruntime
wrapper for Go to inference an ONNX model.
Model visualization:
The model inference works fine when using the onnxruntime
in Python, which does not require specifying tensor sizes. However, when attempting to run in Go, I get the following:
Error running the model: Error running network: Invalid rank for
output: predictions Got: 2 Expected: 1 Please fix either the
inputs/outputs or the model.
package main
import (
"fmt"
ort "github.com/yalue/onnxruntime_go"
)
func main() {
ort.SetSharedLibraryPath("/opt/homebrew/Cellar/onnxruntime/1.17.1/lib/libonnxruntime.dylib")
err := ort.InitializeEnvironment()
defer ort.DestroyEnvironment()
if err == nil {
fmt.Println("Initialized")
}
inputData := [][]float32{
{546.0, 410, 4.63, 0, 8592, 5192, 34728, 9000, 3, 0, 2, 55461, 3, 20, 3},
{386.0, 398, 4.52, 0, 6699, 4217, 34728, 9000, 3, 0, 2, 55253, 3, 20, 3},
{69.0, 2, 5.0, 0, 2191, 1105, 34728, 9000, 3, 0, 2, 55152, 3, 20, 3},
{359.0, 387, 4.48, 0, 5262, 3053, 34728, 9000, 3, 0, 2, 55101, 3, 20, 3},
{213.0, 396, 4.63, 0, 4545, 2614, 34728, 9000, 3, 0, 2, 55013, 3, 20, 3},
{165.0, 398, 4.63, 0, 3715, 2522, 34728, 9000, 3, 0, 2, 54449, 3, 20, 3},
}
rows := len(inputData)
cols := len(inputData[0])
// Flatten the two-dimensional slice into a one-dimensional slice
flatData := make([]float32, 0, rows*cols)
for _, row := range inputData {
flatData = append(flatData, row...)
}
inputShape := ort.NewShape(int64(rows), int64(cols))
inputTensor, err := ort.NewTensor(inputShape, flatData)
if err != nil {
fmt.Println(err)
}
fmt.Println(inputTensor)
outputShape := ort.NewShape(int64(rows), 1)
outputTensor, err := ort.NewEmptyTensor[float32](outputShape)
if err != nil {
fmt.Println("Failed to create output tensor:", err)
return
}
defer outputTensor.Destroy()
// Create the session with the model and specify input/output details
session, err := ort.NewSession[float32]("CatBoost_RMSE.onnx",
[]string{"features"}, []string{"predictions"},
[]*ort.Tensor[float32]{inputTensor}, []*ort.Tensor[float32]{outputTensor})
if err != nil {
fmt.Println("Failed to create session:", err)
return
}
defer session.Destroy()
// Run the model
err = session.Run()
if err != nil {
fmt.Println("Error running the model:", err)
return
}
// Retrieve and print the output data
outputData := outputTensor.GetData()
fmt.Println("Model output data:", outputData)
}
I have tried updating ort.NewShape(int64(rows), 1)
to ort.NewShape(int64(rows))
but that results in a new error:
Error running the model: Error running network: Non-zero status code
returned while running TreeEnsembleRegressor node. Name:” Status
Message:
/tmp/onnxruntime-20240227-5139-v94cng/onnxruntime/core/framework/execution_frame.cc:173
Status
onnxruntime::IExecutionFrame::GetOrCreateNodeOutputMLValue(const int,
int, const TensorShape *, OrtValue *&, const Node &) shape &&
tensor.Shape() == *shape was false. OrtValue shape verification
failed. Current shape:{6} Requested shape:{6,1}