此原理对线性多分类和非线性多分类都适用。

多分类过程

我们在此以具有两个特征值的三分类举例。可以扩展到更多的分类或任意特征值,比如在ImageNet的图像分类任务中,最后一层全连接层输出给分类器的特征值有成千上万个,分类有1000个。

  1. 线性计算

z1 = x_1 w{11} + x2 w{21} + b_1 \tag{1}

z2 = x_1 w{12} + x2 w{22} + b_2 \tag{2}

z3 = x_1 w{13} + x2 w{23} + b_3 \tag{3}

  1. 2. 分类计算

a_1=\frac{e^{z_1}}{\sum_i e{z_1}}{e{z_2}+e^{z_3}} \tag{4}

a_2=\frac{e^{z_2}}{\sum_i e{z_2}}{e{z_2}+e^{z_3}} \tag{5}

a_3=\frac{e^{z_3}}{\sum_i e{z_3}}{e{z_2}+e^{z_3}} \tag{6}

  1. 3. 损失函数计算

单样本时,n表示类别数,j表示类别序号:

\begin{aligned} loss(w,b)&=-(y1 \ln a_1 + y_2 \ln a_2 + y_3 \ln a_3) \ =-\sum{j=1}^{n} y_j \ln a_j \end{aligned} \tag{7}

批量样本时,m表示样本数,i表示样本序号:

\begin{aligned} J(w,b) &=- \sum{i=1}^m (y{i1} \ln a{i1} + y{i2} \ln a{i2} + y{i3} \ln a{i3}) \ =- \sum{i=1}^m \sum{j=1}^n y{ij} \ln a_{ij} \end{aligned} \tag{8}

损失函数计算在交叉熵函数一节有详细介绍。

数值计算举例

假设对预测一个样本的计算得到的z值为:

$$
z=[z_1,z_2,z_3]=[3,1,-3]
$$

则按公式4、5、6进行计算,可以得出Softmax的概率分布是:

$$
a=[a_1,a_2,a_3]=[0.879,0.119,0.002]
$$

如果标签值表明此样本为第一类

即:

y=[1,0,0]

则损失函数为:

loss_1=-(1 \times \ln 0.879 + 0 \times \ln 0.119 + 0 \times \ln 0.002)=0.123

反向传播误差矩阵为:

a-y=[-0.121,0.119,0.002]

因为a_1=0.879,为三者最大,分类正确,所以a-y的三个值都不大。

如果标签值表明此样本为第二类

即:

y=[0,1,0]

则损失函数为:

loss_2=-(0 \times \ln 0.879 + 1 \times \ln 0.119 + 0 \times \ln 0.002)=2.128

可以看到由于分类错误,loss_2的值比loss_1的值大很多。

反向传播误差矩阵为:

a-y=[0.879,0.881,0.002]

本来是第二类,误判为第一类,所以前两个元素的值很大,反向传播的力度就大。

多分类的几何原理

在前面的二分类原理中,很容易理解为我们用一条直线分开两个部分。对于多分类问题,是否可以沿用二分类原理中的几何解释呢?答案是肯定的,只不过需要单独判定每一个类别。

假设一共有三类样本,蓝色为1,红色为2,绿色为3,那么Softmax的形式应该是:

$$
a_j = \frac{e3 e{z_j}}{e{z_2}+^{z_3}}
$$

当样本属于第一类时

把蓝色点与其它颜色的点分开。

如果判定一个点属于第一类,则a_1的概率值一定会比a_2、a_3大,表示为公式:

a_1 > a_2 且 a_1 > a_3 \tag{9}

由于Softmax的特殊形式,分母都一样,所以只比较分子就行了。而分子是一个自然指数,输出值域大于零且单调递增,所以只比较指数就可以了,因此,公式9等同于下式:

z_1 > z_2 且 z_1 > z_3 \tag{10}

把公式1、2、3引入到10:

x1 w{11} + x2 w{21} + b1 > x1 w{12} + x2 w{22} + b_2 \tag{11}

x1 w{11} + x2 w{21} + b1 > x1 w{13} + x2 w{23} + b_3 \tag{12}

变形:

(w{21} - w{22})x2 > (w{12} - w_{11})x_1 + (b_2 - b_1) \tag{13}

(w{21} - w{23})x2 > (w{13} - w_{11})x_1 + (b_3 - b_1) \tag{14}

我们先假设:

w{21} > w{22},且 w{21}> w{23} \tag{15}

所以公式13、14左侧的系数都大于零,两边同时除以系数:

x2 > {w{12} - w{11} \over w{21} - w{22}}x_1 + {b_2 - b_1 \over w{21} - w_{22}} \tag{16}

x2 > {w{13} - w{11} \over w{21} - w{23}} x_1 + {b_3 - b_1 \over w{21} - w_{23}} \tag{17}

简化:

y > W{12} \cdot x + B{12} \tag{18}

y > W{13} \cdot x + B{13} \tag{19}

此时y代表了第一类的蓝色点。

线性多分类原理 - 图1

借用二分类中的概念,公式18的几何含义是:有一条直线可以分开第一类(蓝色点)和第二类(红色点),使得所有蓝色点都在直线的上方,所有的红色点都在直线的下方。于是我们可以画出图7-9中的那条绿色直线

而公式19的几何含义是:有一条直线可以分开第一类(蓝色点)和第三类(绿色点),使得所有蓝色点都在直线的上方,所有的绿色点都在直线的下方。于是我们可以画出图7-9中的那条红色直线。

也就是说在图中画两条直线,所有蓝点都同时在红线和绿线这两条直线的上方。

当样本属于第二类时

即如何把红色点与其它两色点分开。

z_2 > z_1 且 z_2 > z_3 \tag{20}

同理可得

y < W{12} \cdot x + B{12} \tag{21}

y > W{23} \cdot x + B{23} \tag{22}

线性多分类原理 - 图2

此时y代表了第二类的红色点。

公式21和公式18几何含义相同,不等号相反,代表了图7-10中绿色直线的分割作用,即红色点在绿色直线下方。

公式22的几何含义是,有一条蓝色直线可以分开第二类(红色点)和第三类(绿色点),使得所有红色点都在直线的上方,所有的绿色点都在直线的下方。

当样本属于第三类时

即如何把绿色点与其它两色点分开。

z_3 > z_1 且 z_3 > z_2 \tag{22}

最后可得:

y < W{13} \cdot x + B{13} \tag{23}

y < W{23} \cdot x + B{23} \tag{24}

此时y代表了第三类的绿色点。

线性多分类原理 - 图3

公式23与公式19不等号相反,几何含义相同,代表了图7-11中红色直线的分割作用,绿色点在红色直线下方。

公式24与公式22不等号相反,几何含义相同,代表了图7-11中蓝色直线的分割作用,绿色点在蓝色直线下方。

综合效果

把三张图综合在一起,应该是图7-12的样子。

线性多分类原理 - 图4