// 导入必要的库
import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;
import ai.onnxruntime.OrtSession.Result;
import ai.onnxruntime.OrtException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
public class JavaONNXExample {
public static void main(String[] args) {
// 创建ONNX Runtime环境
OrtEnvironment env = OrtEnvironment.getEnvironment();
try {
// 加载ONNX模型
OrtSession session = env.createSession("model.onnx");
// 准备输入数据
float[][] input = {{1.0f, 2.0f}, {3.0f, 4.0f}};
OnnxTensor tensor = OnnxTensor.createTensor(env, input);
// 创建输入映射
Map<String, ? extends OnnxTensor> inputs = new HashMap<>();
inputs.put("input", tensor);
// 运行推理
Result result = session.run(inputs);
// 获取输出结果
float[][] output = (float[][]) result.get(0).getValue();
// 打印输出结果
System.out.println(Arrays.deepToString(output));
// 关闭会话和环境
session.close();
env.close();
} catch (OrtException e) {
e.printStackTrace();
}
}
}
OrtEnvironment.getEnvironment()创建一个ONNX Runtime环境。env.createSession("model.onnx")加载ONNX模型文件。OnnxTensor.createTensor将其转换为ONNX张量。Map中,键为模型的输入名称。session.run(inputs)执行推理操作。Result对象中提取输出数据,并打印出来。上一篇:java map获取key值
下一篇:java 包含字符串
Laravel PHP 深圳智简公司。版权所有©2023-2043 LaravelPHP 粤ICP备2021048745号-3
Laravel 中文站