python pth文件是什么_嫌python慢?来这里用pytorch C++前端推理模型

news/2024/7/4 9:08:59

e2a6f00e66afcb1479385ebc89ef3a87.gif

本文面向对象是立足于工业应用领域的AI学习者,希望学以致用的同学们可以深入看这篇文章,我一直认为学而不用就是搞八股文,现实中很多模型很fancy指标很溜但实际上是根本用不了的.

很多时候python用来训练模型很方便的,但是用来推理就显得捉襟见肘了. 其实在我看来,慢不是python的问题,各种乱七八糟的脚本才是最大的烦恼之处。相比之下,我更希望我们的人工智能网络模型能够以一个二进制文件的形式运行。不能因为乱七八糟的文件缺失而影响我们装逼。

本文主题是:使用libtorch C++前端来推理我们的复杂模型(请注意,不是简单的分类,分类我们已经分烂了),我们将推理的模型是比较复杂的人体姿态检测。这是一个能够同时实现CPU和GPU(不用说) realtime的方法。为什么我要强调同时实时, 因为有时候我们买不起显卡,但又想用CPU跑起来装X,这个很有用。

ce5baca73e91dc7e8a3fd6894783eac7.gif

ef4a14f493d0f4f3ab375ef759bb558c.png


效果还是不错的,GPU的帧率测试中,包含了pose提取等操作,大概有50fps,已经足够快了。比Python速度提高约29.1%,效率提升明显。

Pytorch C++ API

本文主要讲一些实操的部分,关于这个方法的论文,大家可以看一下Danill的论文,我们的实现libtorch C++ demo很大程度上得到了Danill的帮助 (intel大佬)。论文链接here。

使用pytorch c++ API分为三步:

  • 下载libtorch(废话)
  • trace你python训练好的模型
  • C++ load

我们看一下一个复杂的模型是如何load的:

model_path 

这里其实有几个坑,实现跟大家说一下:

  • 你的模型里面不要有print之类的东西,否则会报warning,不过不影响模型生成;
  • libtorch在处理模型的输出时,如果输出有多个tensor,最好的方式将tensor concat起来,否则你可能会在输出遇到问题。

从代码的角度来看一下关键的操作:

if 

这是最基本的module导入方式,libtorch通过JIT来load模型,在tracemodel的时候生成的pt模型其实是可以解压的文件,如果你仔细看里面的文件会发现里面也不过是包含了网络的结构以及每个变量的权重,相对来说逻辑还是比较显而易见的。

HumanPose人体关键点C++ Demo

采用libtorch导入模型预测是整个过程最简单的部分,最复杂的反而是用C++将人体关键点提取出来。这篇教程处理的模型训练所采用的数据来自于coco,自然用的也是coco的标注方式:

62724e301297685dce9ba8a6a249de84.png


image.png

如图所示,每一个姿态都包含17个点,并且每个点都有相应的顺序。
其实在我们写C++ 接口的时候已经将一些骨骼姿态render相关代码封装到了这个库thor中,直接一个函数就可以将所有的Pose画在图片上。

目前开源的人体姿态检测数据集用的比较多的就两个,一个是coco keypoints,另外一个是AIChallenge的数据集,但后者的标注与coco不一样,而且是15个点而非17个点。

最后,人体姿态检测所采用的方法现在通常是基于热力图的方法做的,基本的处理步骤可以分为:

  • 生成热力图;
  • 生成paf;
  • 从热力图得到keypoints和每个点的score;
  • 从paf和keypoints得到group之后的结果,最后可以合成对应的结果。

c2f29f96cfad2a230799bed025f74c1a.png

上图展示的是网络最终输出的热力图的可视化效果图。理论上来讲,目前基于热力图的方式预测人体姿态存在两个缺点:

  • 如果网络后端复杂速度将会很慢;
  • 占用的现存会比较大。

但是不可否认,随着CenterNet等方法开始占据目标检测SOTA,热力图的方式将会演变出越来越多的使用场景。

而通过C++ 前端来做推理无疑可以加速复杂模型的工业部署,对于我们实现ALL in AI理想又更近了一步.
那么问题来了, 很多对工业环境不太了解的人经常会问这玩意能做啥? 我python预测不挺好的吗? 从工业角度来分析一下用C++部署网络模型有啥好处:

  • 部署简单, 你只需要一个二进制文件执行;
  • 跨平台, 甚至通过一个动态链接库和NDK你可以轻而易举的在Android部署, 当然CoreML你可以直接转,但是得自己用swift写后处理代码;
  • 代码难度极高,一般人看不懂,适合装X.

整个项目代码已经开源至MANA平台:

神力AI(MANA)-国内最大的AI代码平台​manaai.cn
d698dc2bd6aeca06e291304ed82986d3.png


. 当然我们后续也会把项目merge到Danill的repo中.

最后打一个小广告, 国内最大的AI交流社区 MANA社区招募早期版主了哦! 使用github账号即可快速登录, 从此一个账号问遍所有AI问题! 并且可以与所有用户交流你的奇异想法!

ManaAI社区​talk.strangeai.pro
bacb188593843bfe566a64094d5ad21b.png

http://www.niftyadmin.cn/n/2747051.html

相关文章

三分钟入门Redux(Redux教程)

学习背景: 我最近在更新师兄之前用React写的项目,该项目中各组件的状态依赖关系非常复杂,为了便于管理组件的状态,师兄使用了Redux。我最近刚转React,此前没有用过Redux,借这个难得的机会,我学习…

字符串的切片操作与连接_Python中14个切片操作,你常用哪几个?

切片(Slice)是一个取部分元素的操作,是Python中特有的功能。它可以操作list、tuple、字符串。Python的切片非常灵活,一行代码就可以实现很多行循环才能完成的操作。切片操作的三个参数 [start: stop: step] ,其中start…

spring中的统一异常处理

在具体的SSM项目开发中,由于Controller层为处于请求处理的最顶层,再往上就是框架代码的。因此,肯定需要在Controller捕获所有异常,并且做适当处理,返回给前端一个友好的错误码。 不过,Controller一多&#…

arraylist从大到小排序_十大经典排序算法最强总结(内含代码实现),建议收藏!...

点击上方“Java之间”,选择“置顶或者星标”你关注的就是我关心的!来源:cnblogs.com/cndarren/p/11787368.html上一篇:ArrayList集合为什么不能使用foreach增删改01 算法分类 十种常见排序算法可以分为两大类:比较类…

ret和retf

ret指令用栈中的数据,修改IP的内容,从而实现近转移; retf指令用栈中的数据,修改CS和IP的内容,从而实现远转移。 CPU执行ret指令时,进行下面两步操作: (IP) ((ss)*16(sp))(sp)(sp)2CPU执行retf指令时,进行下…

opengl实现经纹理映射的旋转立方体_立方体纹理

立方体纹理就是包含6个2D纹理的纹理.6个纹理有序排列在立方体的6个面.其可以通过方向向量采样立方体纹理上的纹素.创建立方体贴图跟创建2D贴图一样,但是绑定到GL_TEXTURE_CUBE_MAP上.glGenTextures(1, &CubeMapID); glBindTexture(GL_TEXTURE_CUBE_MAP, CubeMapID);立方体纹…

rocketmq广播消息为什么不能重试_RocketMQ系列(五)广播与延迟消息

今天要给大家介绍RocketMQ中的两个功能,一个是“广播”,这个功能是比较基础的,几乎所有的mq产品都是支持这个功能的;另外一个是“延迟消费”,这个应该算是RocketMQ的特色功能之一了吧。接下来,我们就分别看…

数据结构 二叉树 根据后序和中序遍历输出先序遍历

根据后序和中序遍历输出先序遍历 题目描述: 本题要求根据给定的一棵二叉树的后序遍历和中序遍历结果,输出该树的先序遍历结果。 输入格式: 第一行给出正整数N(≤30),是树中结点的个数。随后两行,每行给出N个整数,分别…