- TensorRT Plugin 介绍
- TensorRT Plugin 的工作流程
- Static Shape Plugin API
- EmbLayerNormPlugin Static Shape Demo
- Dynamic Shape Plugin API
- EmbLayerNormPlugin Dynamic Shape Demo
- PluginCreator 注册
- TensorRT 如何 debug – Debug Plugin
TensorRT Plugin 介绍
Plugin 存在的意义:
- TRT支持的算子有限,实现不支持的算子;
- 进行深度优化-合并算子。
对于不支持的算子:groupnorm,gelu,split等,我们有两种方法解决:
- 写plugin插件
- 用低级算子来替代
对于复杂的网络,合并算子是非常有意义的。比如,可以将下方的代码合并为一个plugin,也就是
一个kernel,可以有效提高性能。
- 简单网络:规整网络,使用基础的网络API,比如conv、pool、softmax、relu等
- 复杂网络:DPRNN、DCN等;在网络训练中,要对数据进行处理,然后在继续运行。
因为这4个kernel都涉及的是数据操作,是访存密集型的。因此,可以将4个kernel合并成一个kernel,4次访存4次写入变为1次访存1次写入,速度会大幅度提升。
官方github给出了很多plugin demo,大都跟计算机视觉和BERT模型相关。
TensorRT Plugin的工作流程
Static Shape Plugin API
Dynamic Shape:输入维度是动态的;
Static Shape:输入维度是定死的
IPluginV2IOExt( / IPluginV2DynamicExt:插件类,用于写插件的具体实现;- IPluginCreator:插件工厂类,用于根据需求创建该插件。
注意:
IPluginV2Ext(不用,除非trt5.0以下)- 编写plugin,需要继承TRT的base class(不同的base class特性如上表)
- Static Shape 用
IPluginV2IOExt
;Dynamic Shape,则使用IPluginV2DynamicExt
MyCustomPlugin(int in_channel, nvinfer1::Weights const& weight, nvinfer1::Weights const& bias);
MyCustomPlugin(void const* serialData, size_t serialLength);
int getNbOutputs() const;
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims);
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const;
size_t getSerializationSize() const;
void serialize(void* buffer) const;
const char* getPluginType() const;
const char* getPluginVersion() const;
int initialize();
void terminate();
void destroy();
void configurePlugin(const nvinfer1::PluginTensorDesc* in, int nbInput, const
nvinfer1::PluginTensorDesc* out, int nbOutput);
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int
nbOutputs) const;
size_t getWorkspaceSize(int maxBatchSize) const;
int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t
stream);
构造函数
用于network definition阶段,PluginCreator创建该插件时调用的构造函数,需要传递权重信息以及参数。也可用于clone阶段,或者再写一个clone构造函数。
MyCustomPlugin(int in_channel, nvinfer1::Weights const& weight, nvinfer1::Weights const& bias);
用于在deserialize阶段,用于将序列化好的权重和参数传入该plugin并创建。
MyCustomPlugin(void const* serialData, size_t serialLength);
注意需要把默认构造函数删掉:
MyCustomPlugin() = delete;
析构函数
析构函数则需要执行terminate,terminate函数就是释放这个op之前开辟的一些显存空间:
MyCustomPlugin::~MyCustomPlugin() { terminate(); }
输出相关函数
获得layer的输出个数
int getNbOutputs() const;
根据输入个数和输入维度,获得第index个输出的维度
nvinfer1::Dims getOutputDimensions(int index, const nvinfer1::Dims* inputs, int nbInputDims);
根据输入个数和输入类型,获得第index个输出的类型
nvinfer1::DataType getOutputDataType(int index, const nvinfer1::DataType* inputTypes, int nbInputs) const;
序列化和反序列化相关函数
- 返回序列化时需要写多少字节到buffer中
size_t getSerializationSize() const;
序列化函数,将plugin的参数权值写入到buffer中
void serialize(void* buffer) const;
初始化函数,在这个插件准备开始run之前执行。一般申请权值显存空间并copy权值
int initialize();
terminate函数就是释放initialize开辟的一些显存空间
void terminate();
释放整个plugin占用的资源
void destroy();
配置这个插件op,判断输入和输出类型数量是否正确
void configurePlugin(const nvinfer1::PluginTensorDesc* in, int nbInput, const nvinfer1::PluginTensorDesc* out, int nbOutput);
判断pos索引的输入/输出是否支持inOut[pos].format和inOut[pos].type指定的格式/数据类型
bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs, int nbOutputs) const;
运行相关函数
获得plugin所需要的显存大小。最好不要在plugin enqueue中使用cudaMalloc申请显存。
size_t getWorkspaceSize(int maxBatchSize) const;
inference函数
int enqueue(int batchSize, const void* const* inputs, void** outputs, void* workspace, cudaStream_t stream);
IPluginCreator 相关函数
获得pluginname和version,用于辨识creator
const char* getPluginName() const; const char* getPluginVersion() const;
通过PluginFieldCollection去创建plugin,将op需要的权重和参数一个一个取出来,然后调用上文提到的第一个构造函数:
const nvinfer1::PluginFieldCollection* getFieldNames(); nvinfer1::IPluginV2* createPlugin(const char* name, const nvinfer1::PluginFieldCollection* fc);
反序列化,调用反序列化那个构造函数,生成plugin
nvinfer1::IPluginV2* deserializePlugin(const char* name, const void* serialData, size_t serialLength);
EmbLayerNormPlugin Static Shape Demo
:::info EmbLayerNormPlugin 是 BERT 模型Embedding + Layernorm的合并 ::: BERT 的 EmbLayerNormPlugin 层,主要有以下5个参数:
三个 Embedding 参数矩阵,分别是语义的 Embedding,位置的 Embedding, token type 的 Embedding。
- Embedding 操作除上面3个 embedding 做对应位置的求和,同时还要过一个 LayerNorm 操作,即对Embedding 方向的维度做一个归一化,所以还需要LayerNorm 的 beta 和 gamma 参数。