传统的分类损失计算输入数据和每个类别中心的距离,来优化模型的训练。KNN softmax通过选择和输入数据最相关的top-K个类别,仅计算输入数据和top-K个类别中心的距离,以减小计算量。
KNN softmax首次诞生于达摩院机器智能技术实验室发表的SIGKDD 2020年《Large-Scale Training System for 100-Million Classification at Alibaba》
简单说下论文作者:
- Pan Pan:潘攀,拍立淘创始人,著有《深度学习图像搜索与识别》
- Liuyihan Song、Kang Zhao、Yiming Chen、Yingya Zhang均来自拍立淘团队
- Yinghui Xu:徐盈辉,徐盈辉-复旦大学人工智能创新与产业(AI³)研究院 (fudan.edu.cn)
- Rong Jin:金榕(阿里巴巴原副总裁、达摩副院长)_百度百科 (baidu.com)
问题建模
一个常见的图像分类任务整体流程如下:
输入图像 x i x_i xi送入Convolutional Feature Learning模块,提取图像表征 f x i ∈ R 1 × D f_{x_i}\in \mathbb{R}^{1\times D} fxi∈R1×D(其中 D D D表示维度),再通过Fully Connected Layer,将图像表征维度 f x i f_{x_i} fxi映射到类别数 C C C上,紧接着通过Softmax Function获取 [ 0 , 1 ] [0,1] [0,1]的概率值,计算分类损失。
我们来进行公式化定义,
(1)图像表征 f x i f_{x_i} fxi通过Fully Connected Layer将维度映射到类别数 C C C,可以建模成: f x i W ∈ R C f_{x_i}W \in \mathbb{R}^C fxiW∈RC,其中 W ∈ R D × C W \in \mathbb{R}^{D\times C} W∈RD×C。一般情况下,Fully Connected Layer会有偏置 b b b,将偏置 b b b设置为0。
(2)通过Softmax Function获取 [ 0 , 1 ] [0,1] [0,1]的概率值,得到 f x i W ∑ j e x p ( f x i W j ) \frac{f_{x_i}W}{\sum_j{exp(f_{x_i}W_j)}} ∑jexp(fxiWj)fxiW,其中 W j ∈ R D × 1 W_j \in \mathbb{R}^{D\times 1} Wj∈RD×1,表示第 j j j列数据,也指类别表征。
(3)分类损失的定义为: L = − log ( e x p ( f x i W y i ) ∑ j e x p ( f x i W j ) ) = − log ( e x p ( ∥ f x i ∥ ⋅ ∥ W y i ∥ ⋅ c o s ( θ y j ) ) ∑ j e x p ( ∥ f x i ∥ ⋅ ∥ W j ∥ ⋅ c o s ( θ j ) ) ) (1) \begin{equation}\begin{aligned} L&=-\log\left(\frac{exp(f_{x_i}W_{y_i})}{\sum_j{exp(f_{x_i}W_j)}}\right)\\ &=-\log\left(\frac{exp(\|f_{x_i}\|\cdot \|W_{y_i}\|\cdot cos(\theta_{y_j}))}{\sum_j{exp(\|f_{x_i}\|\cdot\|W_j\|\cdot cos(\theta_{j}))}}\right)\\ \end{aligned} \end{equation}\tag{1} L=−log(∑jexp(fxiWj)exp(fxiWyi))=−log(∑jexp(∥fxi∥⋅∥Wj∥⋅cos(θj))exp(∥fxi∥⋅∥Wyi∥⋅cos(θyj)))(1),其中 y i y_i yi指的是输入图像 x i x_i xi对应的类别下标,等式上下成立的原因是向量的内积公式 a ⋅ b = ∥ a ∥ ⋅ ∥ b ∥ ⋅ cos θ \mathbf{a} \cdot \mathbf{b} = \|\mathbf{a}\| \cdot \|\mathbf{b}\| \cdot \cos\theta a⋅b=∥a∥⋅∥b∥⋅cosθ。
(4)在常规实践中,图像表征
f
x
i
f_{x_i}
fxi和类别表征
W
j
W_j
Wj一般都事先归一化好,仅需要考虑两个表征间的余弦距离。同时,需要乘上一个缩放因子,用于控制训练的激进程度,例如
L
=
−
log
(
e
x
p
(
α
⋅
c
o
s
(
θ
y
j
)
)
∑
j
e
x
p
(
α
⋅
c
o
s
(
θ
j
)
)
)
=
−
log
(
e
x
p
(
α
⋅
f
x
i
∥
f
x
i
∥
⋅
W
y
i
∥
W
y
i
∥
)
∑
j
e
x
p
(
α
⋅
f
x
i
∥
f
x
i
∥
⋅
W
j
∥
W
j
∥
)
=
−
log
(
e
x
p
(
α
⋅
f
x
i
n
o
r
m
⋅
W
i
n
o
r
m
)
∑
j
e
x
p
(
α
⋅
f
x
i
n
o
r
m
⋅
W
j
n
o
r
m
)
(2)
\begin{equation}\begin{aligned} L&=-\log\left(\frac{exp(\alpha \cdot cos(\theta_{y_j}))}{\sum_j{exp(\alpha \cdot cos(\theta_{j}))}}\right)\\ &=-\log\left(\frac{exp(\alpha \cdot \frac{f_{x_i}}{\|f_{x_i}\|}\cdot \frac{W_{y_i}}{\|W_{y_i}\|})}{\sum_j{exp(\alpha \cdot \frac{f_{x_i}}{\|f_{x_i}\|}\cdot \frac{W_{j}}{\|W_{j}\|}}}\right)\\ &=-\log\left(\frac{exp(\alpha \cdot f_{x_i}^{norm} \cdot W_{_i}^{norm})}{\sum_j{exp(\alpha \cdot f_{x_i}^{norm}\cdot W_{j}^{norm}}}\right)\\ \end{aligned} \end{equation}\tag{2}
L=−log(∑jexp(α⋅cos(θj))exp(α⋅cos(θyj)))=−log
∑jexp(α⋅∥fxi∥fxi⋅∥Wj∥Wjexp(α⋅∥fxi∥fxi⋅∥Wyi∥Wyi)
=−log(∑jexp(α⋅fxinorm⋅Wjnormexp(α⋅fxinorm⋅Winorm))(2)
,这个就是CLIP用的损失函数的形式了。
KNN softmax
全连接层的模型并行
如果特征维度是512维,分类1个亿的全连接层参数有 512 × 100000000 = 5.12 ∗ 1 0 10 512\times 100000000=5.12*10^{10} 512×100000000=5.12∗1010。若参数存储形式为fp32,即1个参数需要4个字节,那么占用的显存为 5.12 × 1 0 10 ∗ 4 1024 × 1024 × 1024 = 191.1 G B \frac{5.12\times 10^{10}*4}{1024\times 1024\times 1024}=191.1GB 1024×1024×10245.12×1010∗4=191.1GB。
很显然,单块显卡装不下。于是,本文将全连接层参数均分到每一块显卡上。假设我们有256块V100显卡,每块显卡只需要装 191.1 G B 256 = 0.74 G B \frac{191.1 GB}{256}=0.74GB 256191.1GB=0.74GB,很显然,每块显卡的负担小得多了。
做法如上图所示,包括数据并行和模型并行。
- 数据并行指的是Convolutional Feature Learning模块参数复制到每块GPU上,只有数据均分成 N N N份,送入不同GPU中。
- 模型并行特指全连接层参数均分成
N
N
N份,存储到不同GPU中。
具体流程如下:
(1)数据均分成 N N N份,送到不同GPU中。
(2)每块GPU上,通过Convolutional Feature Learning模块提取图像表征,再执行all-gather操作,将不同GPU的表征汇聚到每一块GPU上。(假设有3块GPU,每块GPU提取了 R 2 × 512 \mathbb{R}^{2\times 512} R2×512表征,执行all-gather操作后,将3块GPU的表征汇聚起来,分发到所有GPU上,每块GPU提取的表征变为 R 6 × 512 \mathbb{R}^{6\times 512} R6×512)
(3)第 i i i块GPU将图像表征送到第 i i i份全连接层参数上
(4)执行分布式softmax计算,以及损失的计算
(5)每块GPU参数反向传播,在反向传播至Convolutional Feature Learning模块前,汇聚梯度,再进一步向前传播。
(6)参数更新时,第 i i i份全连接层参数仅通过第 i i i块GPU的梯度进行更新;Convolutional Feature Learning模块则通过全GPU的梯度进行更新。
尽管做了全连接层的模型并行,但是全连接层的计算量级实在太大,越80%的训练时间消耗在全连接层的操作上(全连接层前向传播,softmax前向传播,softmax反向传播,全连接层反向传播)
top-K类别选择
在公式(2)中,有 L = − log ( e x p ( α ⋅ f x i n o r m ⋅ W y i n o r m ) ∑ j e x p ( α ⋅ f x i n o r m ⋅ W j n o r m ) L=-\log\left(\frac{exp(\alpha \cdot f_{x_i}^{norm} \cdot W_{y_i}^{norm})}{\sum_j{exp(\alpha \cdot f_{x_i}^{norm}\cdot W_{j}^{norm}}}\right) L=−log(∑jexp(α⋅fxinorm⋅Wjnormexp(α⋅fxinorm⋅Wyinorm)),分类损失需要计算输入表征 f x i n o r m f_{x_i}^{norm} fxinorm和所有类别表征的余弦距离。由于类别数特别大,计算难度特别高,所以选择从中挑选 K K K个类别,进行分母的计算。
这是一个典型的检索场景,文中利用输入数据类别 y i y_i yi的类别表征 W y i W_{y_i} Wyi去检索所有类别中心表征,得到top-K个相似度最高的类别,用于分类损失的分母计算。
分布式KNN图构建
KNN图的建立可以理解为:给定query集合,以及doc集合,建立每个query到doc内最相近top-k个样本的关系。
在1亿类别分类场景,query和doc集合都等于1亿类别,建KNN图流程就特指:将每1个类别中心作为query,检索1亿个类别中心内,最相似的top-k个类别中心,构成 1 亿 × k 1亿\times k 1亿×k的相似度矩阵。
大规模检索场景常用的策略为ANN检索(Approximate Nearest Neighbor,近似最近邻检索)。但作者发现ANN对召回影响较大,导致损失偏差较大,效果不好,推荐采用暴力检索(brute-force)。
暴力检索不影响召回率,但很耗时,所以无法每个iteration更新一次,本文是每隔一个epoch更新一次KNN图。
因为模型并行,已经将全连接层均分到每块GPU上,建立KNN图是需要考虑该因素。传统的建图策略是:将所有GPU上的类别表征聚合到每块GPU上,得到完整的doc集合。计算每块GPU上的类别表征与完整doc集合的相似度矩阵,很显然,对显存消耗很高。
采用分布式建图,策略为:假设将GPU(id=0)作为query,计算KNN图,流程有:
- 在GPU(id=0)上,计算query到GPU(id=0)上类别表征的top-k,结果传播到GPU(id=1)上
- 在GPU(id=1)上,计算query到GPU(id=1)上类别表征的top-k,结果传播到GPU(id=2)上
- …
- 最后,将最终结果返回到GPU(id=0)上
这样的处理方式对显存消耗非常小,并且GPU间的通信量也少。
具体实现时,类别中心的存储由fp32改为fp16,并且采用TensorCore进行相似度计算加速(较原方法能加速3倍)。fp16的精度低于fp32,为平衡速度和效果,首先用fp16精度从全类别中心里搜top- k ′ k^{'} k′,再利用fp32精度从top- k ′ k^{'} k′中搜出top-k。
经过上述一通操作,1亿类别中心的KNN建图时间仅需0.75h。
采用和全连接层模型并行类似的策略,将KNN图按照query维度均分到每块GPU上,平均每块GPU仅需承担 372 G B / 256 = 1.45 G B 372GB/256=1.45GB 372GB/256=1.45GB,在可承受范围内。
效果比较
分别用了1百万类、1千万类、1亿类的数据进行训练,统计分类准确率和吞吐量,结果如下:
分类准确率:
- selective softmax:分母中通过Hashing Forest来选择k个类别,未采用KNN方式选择
- MACH:一种加速策略,速度快,但效果不好
- Full Softmax指的是分类损失中,分布用全类别表征计算得到
吞吐量:
,表明KNN Softmax能够有效提升吞吐量,类别越多,提升幅度越大。