利用AVX2指令集加速推荐系统MMR层余弦相似度计算

商道如狼道  论坛元老 | 2024-10-11 11:26:58 | 来自手机 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 1038|帖子 1038|积分 3114

原文:blog.fanscore.cn/a/62/
1. 配景

前一段时间公司上线了一套Go实现的推荐系统,上线后发现MMR层固然只有纯计算但耗时十分离谱,通过pprof定位问题所在之后进行了优化,固然降低了非常多但是我们以为其中还有优化空间。

可以看到一样平常平均耗时126ms,P95 360ms。
MMR层主要耗时会合在了余弦相似度的计算部分,这部分我们利用的gonum库进行计算,其底层在x86平台上利用了SSE指令集进行了加速。
SSE指令集已经非常古老了,xmm寄存器只能存储两个双精度浮点数,每次只能并行进行两个双精度浮点数的计算,而AVX2指令集可以并行计算四个,理论上可以得到两倍的性能提升,因此我们决定本身利用AVX2指令集手写汇编的方式替代掉gonum库。
1.1 余弦相似度算法

余弦相似度的计算公式为

对应的代码为
  1. import "gonum.org/v1/gonum/floats"
  2. func CosineSimilarity(a, b []float64) float64 {
  3.     dotProduct := floats.Dot(a, b) // 计算a和b的点积
  4.     normA := floats.Norm(a, 2) // 计算向量a的L2范数
  5.     normB := floats.Norm(b, 2) // 计算向量b的L2范数
  6.     return dotProduct / (normA * normB)
  7. }
复制代码
2. Dot点积计算加速

gonum点积计算Dot的部分汇编代码如下:
  1. TEXT ·DotUnitary(SB), NOSPLIT, $0
  2.     ...
  3. loop_uni:
  4.         // sum += x[i] * y[i] unrolled 4x.
  5.         MOVUPD 0(R8)(SI*8), X0
  6.         MOVUPD 0(R9)(SI*8), X1
  7.         MOVUPD 16(R8)(SI*8), X2
  8.         MOVUPD 16(R9)(SI*8), X3
  9.         MULPD  X1, X0
  10.         MULPD  X3, X2
  11.         ADDPD  X0, X7
  12.         ADDPD  X2, X8
  13.         ADDQ $4, SI   // i += 4
  14.         SUBQ $4, DI   // n -= 4
  15.         JGE  loop_uni // if n >= 0 goto loop_uni
  16.     ...
  17. end_uni:
  18.         ADDPD    X8, X7
  19.         MOVSD    X7, X0
  20.         UNPCKHPD X7, X7
  21.         ADDSD    X0, X7
  22.         MOVSD    X7, sum+48(FP) // Return final sum.
  23.         RET
复制代码
可以看到其中利用xmm寄存器并行计算两个双精度浮点数,并且还采用了循环展开的优化本领,一个循环中同时进行4个元素的计算。
我们利用AVX2指令集并行计算四个双精度浮点数进行加速
  1. loop_uni:
  2.         // sum += x[i] * y[i] unrolled 8x.
  3.         VMOVUPD 0(R8)(SI*8), Y0 // Y0 = x[i:i+4]
  4.         VMOVUPD 0(R9)(SI*8), Y1 // Y1 = y[i:i+4]
  5.         VMOVUPD 32(R8)(SI*8), Y2 // Y2 = x[i+4:i+8]
  6.         VMOVUPD 32(R9)(SI*8), Y3 // Y3 = x[i+4:i+8]
  7.         VMOVUPD 64(R8)(SI*8), Y4 // Y4 = x[i+8:i+12]
  8.         VMOVUPD 64(R9)(SI*8), Y5 // Y5 = y[i+8:i+12]
  9.         VMOVUPD 96(R8)(SI*8), Y6 // Y6 = x[i+12:i+16]
  10.         VMOVUPD 96(R9)(SI*8), Y7 // Y7 = x[i+12:i+16]
  11.         VFMADD231PD Y0, Y1, Y8 // Y8 = Y0 * Y1 + Y8
  12.         VFMADD231PD Y2, Y3, Y9
  13.         VFMADD231PD Y4, Y5, Y10
  14.         VFMADD231PD Y6, Y7, Y11
  15.         ADDQ $16, SI   // i += 16
  16.         CMPQ DI, SI
  17.         JG  loop_uni // if len(x) > i goto loop_uni
复制代码
可以看到我们每个循环中同时用到8个ymm寄存器即一次循环计算16个数,而且还用到了VFMADD231PD指令同时进行乘法累积的计算。
最终Benchmark结果:
  1. BenchmarkDot 一个循环中计算8个数
  2. BenchmarkDot-2          14994770                78.85 ns/op
  3. BenchmarkDot16 一个循环中计算16个数
  4. BenchmarkDot16-2        22867993                53.46 ns/op
  5. BenchmarkGonumDot Gonum点积计算
  6. BenchmarkGonumDot-2      8264486               144.4 ns/op
复制代码
可以看到点积部分我们得到了约莫2.7倍的性能提升
3. L2范数计算加速

gonum库中进行L2范数计算的算法并不是通例的a1^2 + a2^2 ... + aN^2这种计算,而是采用了Netlib算法,淘汰了溢出和下溢,其Go源码如下:
  1. func L2NormUnitary(x []float64) (norm float64) {
  2.         var scale float64
  3.         sumSquares := 1.0
  4.         for _, v := range x {
  5.                 if v == 0 {
  6.                         continue
  7.                 }
  8.                 absxi := math.Abs(v)
  9.                 if math.IsNaN(absxi) {
  10.                         return math.NaN()
  11.                 }
  12.                 if scale < absxi {
  13.                         s := scale / absxi
  14.                         sumSquares = 1 + sumSquares*s*s
  15.                         scale = absxi
  16.                 } else {
  17.                         s := absxi / scale
  18.                         sumSquares += s * s
  19.                 }
  20.         }
  21.         if math.IsInf(scale, 1) {
  22.                 return math.Inf(1)
  23.         }
  24.         return scale * math.Sqrt(sumSquares)
  25. }
复制代码
其汇编代码比较晦涩难懂,但管中窥豹再联合Go源码可以看出来没有用到并行能力,每次循环只计算一个数
  1. TEXT ·L2NormUnitary(SB), NOSPLIT, $0
  2.     ...
  3. loop:
  4.         MOVSD   (X_)(IDX*8), ABSX // absxi = x[i]
  5.         ...
复制代码
我们优化之后的焦点代码如下:
  1. loop:
  2.         VMOVUPD 0(R8)(SI*8), Y0 // Y0 = x[i:i+4]
  3.         VMOVUPD 32(R8)(SI*8), Y1 // Y1 = y[i+4:i+8]
  4.         VMOVUPD 64(R8)(SI*8), Y2 // Y2 = x[i+8:i+12]
  5.         VMOVUPD 96(R8)(SI*8), Y3 // Y3 = x[i+12:i+16]
  6.         VMOVUPD 128(R8)(SI*8), Y4 // Y4 = x[i+16:i+20]
  7.         VMOVUPD 160(R8)(SI*8), Y5 // Y5 = y[i+20:i+24]
  8.         VMOVUPD 192(R8)(SI*8), Y6 // Y6 = x[i+24:i+28]
  9.         VMOVUPD 224(R8)(SI*8), Y7 // Y7 = x[i+28:i+32]
  10.         VFMADD231PD Y0, Y0, Y8 // Y8 = Y0 * Y0 + Y8
  11.         VFMADD231PD Y1, Y1, Y9
  12.         VFMADD231PD Y2, Y2, Y10
  13.         VFMADD231PD Y3, Y3, Y11
  14.         VFMADD231PD Y4, Y4, Y12
  15.         VFMADD231PD Y5, Y5, Y13
  16.         VFMADD231PD Y6, Y6, Y14
  17.         VFMADD231PD Y7, Y7, Y15
  18.         ADDQ $32, SI // i += 32
  19.         CMPQ DI, SI
  20.         JG  loop // if len(x) > i goto loop
复制代码
我们采用原始的算法计算以利用到并行计算的能力,并且循环展开,一次循环中同时计算32个数,最终Benchmark结果:
  1. BenchmarkAVX2L2Norm
  2. BenchmarkAVX2L2Norm-2          29381442                40.99 ns/op
  3. BenchmarkGonumL2Norm
  4. BenchmarkGonumL2Norm-2           1822386               659.4 ns/op
复制代码
可以看到得到了约莫16倍的性能提升
4. 总结

通过这次优化我们在余弦相似度计算部分最终得到了(144.4 + 659.4 * 2) / (53.46 + 40.99 * 2) = 10.8倍的性能提升,效果照旧非常显著的。相较于《记一次SIMD指令优化计算的失败经历》这次失败的初次实验,本次照旧非常成功的,切实感受到了SIMD的威力。
别的在本次优化过程中也涨了不少姿势
AVX-512指令降频问题

AVX-512指令因为并行度更高理论上性能也更高,但AVX-512指令会造成CPU降频,因此业界利用非常慎重,这一点可以参考字节的json剖析库sonic的这个issue: https://github.com/bytedance/sonic/issues/319
循环展开优化

在一次循环中做更多的工作,优点有很多:

  • 淘汰循环控制的开销,循环变量的更新和条件判断次数更少,降低了分支预测失败的可能性
  • 增加指令并行性,更多的指令可以在流水线中并行执行
但一次循环利用过多的寄存器从实际Benchmark看性能确实更好,但是否存在隐患我没有看到相干的资料,希望这方面的专家可以指教一下。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

商道如狼道

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表