博客
关于我
pytorch argmax_从Pytorch 的ONNX到OpenVINO中IR中间层
阅读量:386 次
发布时间:2019-03-05

本文共 4199 字,大约阅读时间需要 13 分钟。

cb0238cbf335204d36cbcbcb11c63c89.png

点击上方蓝字关注我们

微信公众号:OpenCV学堂

关注获取更多计算机视觉与深度学习知识

Pytorch ONNX格式支持

ONNX是一种深度学习权重模型的表示格式,ONNX格式可以让AI开发者在不同框架之间相互转换模型,实现调用上的通用性。当前PyTorch*, Caffe2*, Apache MXNet*, Microsoft Cognitive Toolkit* 、百度飞桨都支持ONNX格式。OpenVINO的模型优化器支持把ONNX格式的模型转换IR中间层文件。

当前OpenVINO官方支持的ONNX模型主要包括:bert_large,bvlc_alexnet,bvlc_googlenet,bvlc_reference_caffenet,bvlc_reference_rcnn_ilsvrc13 model,inception_v1,inception_v2,resnet50,squeezenet,densenet121,emotion_ferplus,mnist,shufflenet,VGG19,zfnet512。需要注意的是这些模型升级版本并不被支持。

从OpenVINO的2019R04版本开始支持所有公开的Pytorch模型,支持的模型列表如下:

1670fe7192da4625643918d897e8ab8c.png

Pytorch ONNX到OpenVINO IR转换

下面的例子演示了如何从torchvision的公开模型中转换为ONNX,然后再转换为IR,使用OpenVINO完成调用的完整过程。我们将以resnet18为例来演示。

01

下载模型与转ONNX格式

要下载与使用torchvision的预训练模型,首选需要安装好pytorch,然后执行下面的代码就可以下载相关支持模型:

1import torchvision.models as models  2resnet18 = models.resnet18(pretrained=True)  3alexnet = models.alexnet(pretrained=True)  4squeezenet = models.squeezenet1_0(pretrained=True)  5vgg16 = models.vgg16(pretrained=True)  6densenet = models.densenet161(pretrained=True)  7inception = models.inception_v3(pretrained=True)  8googlenet = models.googlenet(pretrained=True)  9shufflenet = models.shufflenet_v2_x1_0(pretrained=True) 10mobilenet = models.mobilenet_v2(pretrained=True) 11resnext50_32x4d = models.resnext50_32x4d(pretrained=True) 12wide_resnet50_2 = models.wide_resnet50_2(pretrained=True) 13mnasnet = models.mnasnet1_0(pretrained=True)

这里,我们只需要执行resnet18 = models.resnet18(pretrained=True)就可以下载resnet18的模型。这些模型的输入格式要求如下:

大小都是224x224,

RGB三通道图像,
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

下载与转为为ONNX的代码如下:

model = torchvision.models.resnet18(pretrained=True).eval() dummy_input = torch.randn((1, 3, 224, 224)) torch.onnx.export(model, dummy_input, "resnet18.onnx")

02

转为IR格式

Cmd至打开安装好的OpenVINO:

deployment_tools\model_optimizer

目录下,执行下面的命令行语句:

python mo_onnx.py --input_model D:\python\pytorch_tutorial\resnet18.onnx

9fb72472184ca79c317b784854f365ad.png

可以看到resnet18模型已经成功转好!

03

OpenVINO SDK调用

对转换好的IR模型,就可以首先通过OpenVINO202R3的Python版本SDK完成加速推理预测,完整的代码实现如下:

from __future__ import print_function import cv2 import numpy as np import logging as log from openvino.inference_engine import IECore with open('imagenet_classes.txt') as f:     labels = [line.strip() for line in f.readlines()] def image_classification():     model_xml = "resnet18.xml"     model_bin = "resnet18.bin"     # Plugin initialization for specified device and load extensions library if specified     log.info("Creating Inference Engine")     ie = IECore()     # Read IR     log.info("Loading network files:\n\t{}\n\t{}".format(model_xml, model_bin))     net = ie.read_network(model=model_xml, weights=model_bin)     log.info("Preparing input blobs")     input_blob = next(iter(net.inputs))     out_blob = next(iter(net.outputs))     # Read and pre-process input images     n, c, h, w = net.inputs[input_blob].shape     images = np.ndarray(shape=(n, c, h, w))     src = cv2.imread("D:/images/messi.jpg")     image = cv2.resize(src, (w, h))     image = np.float32(image) / 255.0     image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))     image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))     image = image.transpose((2, 0, 1))     # Loading model to the plugin     log.info("Loading model to the plugin")     exec_net = ie.load_network(network=net, device_name="CPU")     # Start sync inference     log.info("Starting inference in synchronous mode")     res = exec_net.infer(inputs={input_blob: [image]})     # Processing output blob     log.info("Processing output blob")     res = res[out_blob]     label_index = np.argmax(res, 1)     label_txt = labels[label_index[0]]     cv2.putText(src, label_txt, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 0, 255), 2, 8)     cv2.imshow("ResNet18-from Pytorch image classification", src)     cv2.waitKey(0)     cv2.destroyAllWindows() if __name__ == '__main__':     image_classification()

运行结果如下:

f2413d6fe095e2404bf06ca3d5882334.png

善始者实繁

克终者盖寡

 推荐阅读 

OpenCV4系统化学习路线图-视频版本!

OpenCV单应性矩阵发现参数估算方法详解

单应性矩阵应用-基于特征的图像拼接

OpenCV图像拼接改进算法之完美拼接

OpenCV | 二值图像分析的技巧都在这里

OpenCV二值图像分析之形态学应用技巧

图像色彩空间与应用转换

五分钟学会C++高效图表绘制神器调用

没想到图像直方图有这么多应用场景

基于灰度共生矩阵(GLCM)的图像纹理分析与提取

OpenCV中一个最容易搞错的形态学操作

OpenCV实现皮肤表面粗糙度3D显示

解密 | OpenCV加载图像大小是有限制的 ?

OpenCV中ORB特征提取与匹配

OpenCV SIFT特征算法详解与使用

HOG特征详解与行人检测

8b81d6486fd2dd7b6a02686478c8af45.png

转载地址:http://ujjg.baihongyu.com/

你可能感兴趣的文章
Nginx(2):Nginx配置server节点
查看>>
nginx:/usr/src/fastdfs-nginx-module/src/common.c:21:25:致命错误:fdfs_define.h:没有那个文件或目录 #include
查看>>
Nginx:NginxConfig可视化配置工具安装
查看>>
Nginx:现代Web服务器的瑞士军刀 | 文章末尾送典藏书籍
查看>>
ngModelController
查看>>
ngrok | 内网穿透,支持 HTTPS、国内访问、静态域名
查看>>
ngrok内网穿透可以实现资源共享吗?快解析更加简洁
查看>>
ngrok内网穿透可以实现资源共享吗?快解析更加简洁
查看>>
NHibernate动态添加表
查看>>
NHibernate学习[1]
查看>>
NHibernate异常:No persister for的解决办法
查看>>
Nhibernate的第一个实例
查看>>
NHibernate示例
查看>>
nid修改oracle11gR2数据库名
查看>>
NIFI1.21.0/NIFI1.22.0/NIFI1.24.0/NIFI1.26.0_2024-06-11最新版本安装_采用HTTP方式_搭建集群_实际操作---大数据之Nifi工作笔记0050
查看>>
NIFI1.21.0_java.net.SocketException:_Too many open files 打开的文件太多_实际操作---大数据之Nifi工作笔记0051
查看>>
NIFI1.21.0_Mysql到Mysql增量CDC同步中_日期类型_以及null数据同步处理补充---大数据之Nifi工作笔记0057
查看>>
NIFI1.21.0_Mysql到Mysql增量CDC同步中_补充_插入时如果目标表中已存在该数据则自动改为更新数据_Postgresql_Hbase也适用---大数据之Nifi工作笔记0058
查看>>
NIFI1.21.0_Mysql到Mysql增量CDC同步中_补充_更新时如果目标表中不存在记录就改为插入数据_Postgresql_Hbase也适用---大数据之Nifi工作笔记0059
查看>>
NIFI1.21.0_NIFI和hadoop蹦了_200G集群磁盘又满了_Jps看不到进程了_Unable to write in /tmp. Aborting----大数据之Nifi工作笔记0052
查看>>