1 Preparing the OpenCV Environment
This project uses OpenCV to read and process images, so we first need to set up OpenCV so it can be used from Java.
First, go to the official OpenCV website and download opencv-4.6 (download here); after downloading, you only need to choose an extraction path to complete the installation.
Then copy opencv_java460.dll from the opencv\build\java\x64 directory into C:\Windows\System32, and put opencv-460.jar from the build\java directory (here D:\Tools\opencv\opencv\build\java) into the lib folder under the resources folder of our Spring Boot project.
The ONNX file required by this article can be downloaded here.
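With the dll and jar in place, a minimal sketch like the following (the class name is arbitrary) can be used to confirm that the Java bindings load correctly:
import org.opencv.core.Core;
import org.opencv.core.CvType;
import org.opencv.core.Mat;

public class OpenCvCheck {
    public static void main(String[] args) {
        // The native library must be loaded before any OpenCV call.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
        System.out.println("OpenCV version: " + Core.VERSION);
        // Print a small identity matrix to prove the native bindings work.
        System.out.println(Mat.eye(2, 2, CvType.CV_8UC1).dump());
    }
}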
2 Maven Configuration
Just add the onnxruntime and opencv dependencies. Note that the systemPath of the opencv dependency must point to the location of the opencv-460.jar described above.
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.12.1</version>
</dependency>
<dependency>
    <groupId>org.opencv</groupId>
    <artifactId>opencv</artifactId>
    <version>4.6.0</version>
    <scope>system</scope>
    <systemPath>${project.basedir}/src/main/resources/lib/opencv-460.jar</systemPath>
</dependency>
3 Utils
3.1 Letterbox.java
This class resizes and pads the image so that it satisfies the stride constraint, and records the parameters (scale ratio and padding) for later use.
package cn.halashuo.yolov8.utils;

import org.opencv.core.Core;
import org.opencv.core.Mat;
import org.opencv.core.Size;
import org.opencv.imgproc.Imgproc;

public class Letterbox {

    private Size newShape = new Size(1280, 1280);
    private final double[] color = new double[]{114, 114, 114};
    private final Boolean auto = false;
    private final Boolean scaleUp = true;
    private Integer stride = 32;

    private double ratio;
    private double dw;
    private double dh;

    public double getRatio() {
        return ratio;
    }

    public double getDw() {
        return dw;
    }

    public double getDh() {
        return dh;
    }

    public Integer getWidth() {
        return (int) this.newShape.width;
    }

    public Integer getHeight() {
        return (int) this.newShape.height;
    }

    // Setters so the target size and stride can be changed by the caller
    // (YoloV8.java below sets 640x640 and stride 64).
    public void setNewShape(Size newShape) {
        this.newShape = newShape;
    }

    public void setStride(Integer stride) {
        this.stride = stride;
    }

    public Mat letterbox(Mat im) { // resize and pad the image to satisfy the stride constraint, recording the parameters
        int[] shape = {im.rows(), im.cols()}; // current shape [height, width]
        // Scale ratio (new / old)
        double r = Math.min(this.newShape.height / shape[0], this.newShape.width / shape[1]);
        if (!this.scaleUp) { // only scale down, never up (for better val mAP)
            r = Math.min(r, 1.0);
        }
        // Compute padding
        Size newUnpad = new Size(Math.round(shape[1] * r), Math.round(shape[0] * r));
        double dw = this.newShape.width - newUnpad.width, dh = this.newShape.height - newUnpad.height; // wh padding
        if (this.auto) { // minimum rectangle
            dw = dw % this.stride;
            dh = dh % this.stride;
        }
        dw /= 2; // split the padding evenly between the two sides so the image stays centred
        dh /= 2;
        if (shape[1] != newUnpad.width || shape[0] != newUnpad.height) { // resize
            Imgproc.resize(im, im, newUnpad, 0, 0, Imgproc.INTER_LINEAR);
        }
        int top = (int) Math.round(dh - 0.1), bottom = (int) Math.round(dh + 0.1);
        int left = (int) Math.round(dw - 0.1), right = (int) Math.round(dw + 0.1);
        // pad the image into a square
        Core.copyMakeBorder(im, im, top, bottom, left, right, Core.BORDER_CONSTANT, new org.opencv.core.Scalar(this.color));
        this.ratio = r;
        this.dh = dh;
        this.dw = dw;
        return im;
    }
}
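A short usage sketch (class and variable names are hypothetical): the recorded ratio and padding are what later let us map boxes detected on the letterboxed image back onto the original image, exactly as YoloV8.java does below.
import cn.halashuo.yolov8.utils.Letterbox;
import org.opencv.core.Mat;

public class LetterboxUsage {
    // Map an x coordinate measured on the letterboxed image back to the original image.
    static double toOriginalX(double xLetterboxed, Letterbox letterbox) {
        return (xLetterboxed - letterbox.getDw()) / letterbox.getRatio();
    }

    public static void main(String[] args) {
        System.loadLibrary(org.opencv.core.Core.NATIVE_LIBRARY_NAME);
        Mat original = org.opencv.imgcodecs.Imgcodecs.imread("other/test.jpg");
        Letterbox letterbox = new Letterbox();
        Mat padded = letterbox.letterbox(original.clone()); // keep the original Mat untouched
        System.out.println("ratio=" + letterbox.getRatio() + ", dw=" + letterbox.getDw() + ", dh=" + letterbox.getDh());
        System.out.println("x=100 on the padded image maps back to x=" + toOriginalX(100, letterbox));
    }
}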
3.2 Lable.java
This class stores the label names: the class the model outputs is just an index, and the class name for each index lives here. For convenience, the colour used when drawing each class's box is also generated randomly here.
package cn.halashuo.yolov8.utils;

import java.util.Random;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.HashMap;

public class Lable {

    private final List<String> names = new ArrayList<>(Arrays.asList(
            "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train",
            "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter",
            "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear",
            "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase",
            "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat",
            "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
            "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple",
            "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut",
            "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
            "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave",
            "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors",
            "teddy bear", "hair drier", "toothbrush"));

    private final Integer nc = 80;

    private final Map<String, double[]> colors;

    public Lable() {
        this.colors = new HashMap<>();
        names.forEach(name -> {
            Random random = new Random();
            double[] color = {random.nextDouble() * 256, random.nextDouble() * 256, random.nextDouble() * 256};
            colors.put(name, color);
        });
    }

    public String getName(int clsId) {
        return names.get(clsId);
    }

    public double[] getColor(int clsId) {
        return colors.get(getName(clsId));
    }

    public int getNc() {
        return nc;
    }
}
3.3 ModelResultV8.java
Entity class for a single object-detection result from the model.
package cn.halashuo.yolov8.utils;

public class ModelResultV8 {

    private final float x0;
    private final float x1;
    private final float y0;
    private final float y1;
    private final Integer clsId;
    private final Integer nc = 80;
    private final float score;

    // result is one column of the 1x84x8400 output: [x, y, w, h, 80 class scores].
    // Convert the centre/width/height box to corner coordinates and keep the class
    // with the highest score.
    public ModelResultV8(float[] result) {
        float x = result[0];
        float y = result[1];
        float w = result[2] / 2f;
        float h = result[3] / 2f;
        this.x0 = x - w;
        this.x1 = x + w;
        this.y0 = y - h;
        this.y1 = y + h;
        float maxClsScore = result[4];
        int clsId = 0;
        for (int i = 0; i < nc; i++) {
            if (result[4 + i] > maxClsScore) {
                maxClsScore = result[4 + i];
                clsId = i;
            }
        }
        this.score = maxClsScore;
        this.clsId = clsId;
    }

    public float getX0() {
        return x0;
    }

    public float getX1() {
        return x1;
    }

    public float getY0() {
        return y0;
    }

    public float getY1() {
        return y1;
    }

    public Float getScore() {
        return score;
    }

    public int getClsId() {
        return clsId;
    }

    @Override
    public String toString() {
        return "Detection: " +
                "x0=" + x0 +
                ", x1=" + x1 +
                ", y0=" + y0 +
                ", y1=" + y1 +
                ", clsId=" + clsId +
                ", score=" + score;
    }
}
3.4 NMS.java
package cn.halashuo.yolov8.utils;

import java.util.ArrayList;
import java.util.List;

public class NMS {

    public static List<ModelResultV8> nms(List<ModelResultV8> boxes, float iouThreshold) {
        // Sort the list by score, highest first
        boxes.sort((b1, b2) -> Float.compare(b2.getScore(), b1.getScore()));
        List<ModelResultV8> resultList = new ArrayList<>();
        for (ModelResultV8 box : boxes) {
            boolean keep = true;
            // Discard this box if it overlaps an already-kept (higher-scoring) box
            // by more than the IoU threshold.
            for (ModelResultV8 keptBox : resultList) {
                float iou = getIntersectionOverUnion(box, keptBox);
                if (iou > iouThreshold) {
                    keep = false;
                    break;
                }
            }
            if (keep) {
                resultList.add(box);
            }
        }
        return resultList;
    }

    private static float getIntersectionOverUnion(ModelResultV8 box1, ModelResultV8 box2) {
        float x1 = Math.max(box1.getX0(), box2.getX0());
        float y1 = Math.max(box1.getY0(), box2.getY0());
        float x2 = Math.min(box1.getX1(), box2.getX1());
        float y2 = Math.min(box1.getY1(), box2.getY1());
        float intersectionArea = Math.max(0, x2 - x1) * Math.max(0, y2 - y1);
        float box1Area = (box1.getX1() - box1.getX0()) * (box1.getY1() - box1.getY0());
        float box2Area = (box2.getX1() - box2.getX0()) * (box2.getY1() - box2.getY0());
        float unionArea = box1Area + box2Area - intersectionArea;
        return intersectionArea / unionArea;
    }
}
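A tiny sketch of how the suppression behaves, with made-up boxes and scores: two heavily overlapping detections of the same object collapse into the single higher-scoring one, while a distant box is untouched.
import cn.halashuo.yolov8.utils.ModelResultV8;
import cn.halashuo.yolov8.utils.NMS;
import java.util.ArrayList;
import java.util.List;

public class NmsDemo {
    // Build a dummy 84-element column: x, y, w, h followed by 80 class scores.
    static ModelResultV8 box(float x, float y, float w, float h, float score) {
        float[] column = new float[84];
        column[0] = x; column[1] = y; column[2] = w; column[3] = h;
        column[4] = score; // put all the confidence on class 0 ("person")
        return new ModelResultV8(column);
    }

    public static void main(String[] args) {
        List<ModelResultV8> boxes = new ArrayList<>();
        boxes.add(box(100, 100, 80, 80, 0.9f));
        boxes.add(box(104, 102, 80, 80, 0.6f)); // almost the same box, lower score
        boxes.add(box(300, 300, 60, 60, 0.8f)); // a separate object
        // With iouThreshold = 0.65 only the 0.9 and 0.8 boxes survive.
        NMS.nms(boxes, 0.65f).forEach(System.out::println);
    }
}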
4 YoloV8.java
Just set the path to the ONNX file and the image to detect. If needed, CUDA can also be enabled as the execution provider, which raises FPS substantially; a sketch of this follows below.
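A minimal sketch of enabling CUDA on the session options, assuming the GPU build of ONNX Runtime (the com.microsoft.onnxruntime:onnxruntime_gpu artifact) replaces the CPU dependency above:
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtSession;

public class CudaSessionSketch {
    public static void main(String[] args) throws Exception {
        OrtEnvironment environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        // Register the CUDA execution provider (GPU device 0) before creating the session.
        sessionOptions.addCUDA(0);
        OrtSession session = environment.createSession("other\\yolov8s.onnx", sessionOptions);
        System.out.println("inputs: " + session.getInputInfo().keySet());
    }
}
The full class below uses the default CPU provider.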
package cn.halashuo.yolov8;

import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import cn.halashuo.yolov8.utils.Lable;
import cn.halashuo.yolov8.utils.Letterbox;
import cn.halashuo.yolov8.utils.ModelResultV8;
import cn.halashuo.yolov8.utils.NMS;
import org.opencv.core.*;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;

import java.nio.FloatBuffer;
import java.util.*;

public class YoloV8 {

    static {
        // The OpenCV native library must be loaded before using any OpenCV class, otherwise an error is thrown.
        System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
    }

    public static void main(String[] args) throws Exception {
        // Load the ONNX model
        OrtEnvironment environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
        OrtSession session = environment.createSession("other\\yolov8s.onnx", sessionOptions);
        // Print basic model information
        session.getInputInfo().keySet().forEach(x -> {
            try {
                System.out.println("input name = " + x);
                System.out.println(session.getInputInfo().get(x).getInfo().toString());
            } catch (OrtException e) {
                throw new RuntimeException(e);
            }
        });
        // Load the labels and colours
        Lable lable = new Lable();
        float iouThreshold = 0.65f;
        float scoreThreshold = 0.1f;
        // Read the image
        Mat img = Imgcodecs.imread("other/test.jpg");
        Imgproc.cvtColor(img, img, Imgproc.COLOR_BGR2RGB);
        Mat image = img.clone();
        // Define the box thickness, font size, font face and font colour up front
        // (scaling them with the image size works better)
        int minDwDh = Math.min(img.width(), img.height());
        int thickness = minDwDh / 333;
        double fontSize = minDwDh / 1145.14;
        int fontFace = Imgproc.FONT_HERSHEY_SIMPLEX;
        Scalar fontColor = new Scalar(255, 255, 255);
        // Resize and pad the image
        Letterbox letterbox = new Letterbox();
        letterbox.setNewShape(new Size(640, 640));
        letterbox.setStride(64);
        image = letterbox.letterbox(image);
        double ratio = letterbox.getRatio();
        double dw = letterbox.getDw();
        double dh = letterbox.getDh();
        int rows = letterbox.getHeight();
        int cols = letterbox.getWidth();
        int channels = image.channels();
        // Copy the Mat pixel values into a float[]
        float[] pixels = new float[channels * rows * cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double[] pixel = image.get(i, j);
                for (int k = 0; k < channels; k++) {
                    // Writing to index k*rows*cols + i*cols + j is equivalent to image.transpose((2, 0, 1))
                    pixels[rows * cols * k + i * cols + j] = (float) pixel[k] / 255.0f;
                }
            }
        }
        // Create the OnnxTensor
        long[] shape = {1L, (long) channels, (long) rows, (long) cols};
        OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
        HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
        stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
        // Run the model
        OrtSession.Result output = session.run(stringOnnxTensorHashMap);
        // Collect the results: the output has shape 1x84x8400
        float[][][] outputData = (float[][][]) output.get(0).getValue();
        float[][] results = outputData[0];
        List<ModelResultV8> modelResults = new ArrayList<>();
        for (int i = 0; i < 8400; i++) {
            // Gather column i (x, y, w, h plus 80 class scores) into one array
            float[] result = new float[lable.getNc() + 4];
            for (int j = 0; j < lable.getNc() + 4; j++) {
                result[j] = results[j][i];
            }
            ModelResultV8 modelResult = new ModelResultV8(result);
            if (modelResult.getScore() > scoreThreshold) {
                modelResults.add(modelResult);
            }
        }
        modelResults = NMS.nms(modelResults, iouThreshold);
        modelResults.forEach(modelResult -> {
            System.out.println(modelResult);
            // Draw the box
            Point topLeft = new Point((modelResult.getX0() - dw) / ratio, (modelResult.getY0() - dh) / ratio);
            Point bottomRight = new Point((modelResult.getX1() - dw) / ratio, (modelResult.getY1() - dh) / ratio);
            Scalar color = new Scalar(lable.getColor(modelResult.getClsId()));
            Imgproc.rectangle(img, topLeft, bottomRight, color, thickness);
            // Write the label above the box
            String boxName = lable.getName(modelResult.getClsId()) + ": " + modelResult.getScore();
            Point boxNameLoc = new Point((modelResult.getX0() - dw) / ratio, (modelResult.getY0() - dh) / ratio - 3);
            Imgproc.putText(img, boxName, boxNameLoc, fontFace, fontSize, fontColor, thickness);
        });
        Imgproc.cvtColor(img, img, Imgproc.COLOR_RGB2BGR);
        Imgcodecs.imwrite("C:\\Users\\pbh0612\\Desktop\\imagev8.jpg", img);
        HighGui.imshow("Display Image", img);
        // Wait for any key press before exiting
        HighGui.waitKey();
    }
}
Output:
input name = images
TensorInfo(javaType=FLOAT,onnxType=ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,shape=[1, 3, 640, 640])
Detection: x0=406.35785, x1=520.4848, y0=304.06427, y1=513.8902, clsId=0, score=0.9011777
Detection: x0=187.34064, x1=277.30844, y0=329.81458, y1=516.4944, clsId=0, score=0.8873602
Detection: x0=272.56705, x1=306.77542, y0=208.52237, y1=228.92984, clsId=29, score=0.5956715
Detection: x0=275.60977, x1=370.35672, y0=248.59772, y1=516.78577, clsId=0, score=0.13012922
Detection: x0=0.06917572, x1=10.5588, y0=430.62335, y1=508.3105, clsId=0, score=0.12150359
YOLOv8's 3 detection heads produce 80x80 + 40x40 + 20x20 = 8400 output cells in total. Each cell holds the 4 values x, y, w, h plus the confidences for 80 classes, 84 values altogether, so after training with the official YOLOv8 model and converting it to ONNX, the output has shape 1x84x8400. We then need to apply a confidence threshold, an IoU threshold and so on to this data.