SIMD 指令集友好的子串匹配算法

概述

标题简单明了,这是一篇看了之后也没什么用的文章。

这是我在看 Building even faster interpreters in Rust 时发现的,至于我为什么要看,可能是因为我也开发了一个防火墙模块

什么是 SIMD?

SIMD(Single Instruction Multiple Data)即单指令流多数据流,翻译成人话就是用一条指令操作多组数据。举个例子就是通常一条指令只能完成一对寄存器间的一次加法,现在我们可以用一条指令完成多对寄存器间的多次加法。

常规子串匹配算法的缺点

常规的子串匹配算法基本是下列流程。

k = needle.length
for i = 0 to haystack.length - k - 1 {
    if haystack[i] == needle[0] {
        for j = 1 to k - 1 {
            if haystack[i + j] != needle[j] {
                break
            }
        }
        if j == k {
            return true
        }
    }
}
return false

原理很简单,就是一个一个字符地匹配,但是有下列缺点。

  • 对于现代处理器来说,对 8 bit、16 bit、32 bit 甚至 64 bit 进行一次运算的耗时是一样的。如果还这样逐字节比较就浪费了硬件资源。
  • 逐字节比较也要逐次访问内存,浪费 CPU 周期。
  • 错误的分支预测则会浪费更多的 CPU 周期。
  • 这类代码难以乱序执行。

寄存器:可以认为是变量。

沃·兹基硕德

* There is no difference in comparing one, two, four or 8 bytes on a 64-bit CPU. When a processor supports SIMD instructions, then comparing vectors (it means 16, 32 or even 64 bytes) is as cheap as comparing a single byte.

* Thus comparing short sequences of chars can be faster than fancy algorithms which avoids such comparison.

* Looking up in a table costs one memory fetch, so at least a L1 cache round (~3 cycles). Reading char-by-char also cost as much cycles.

* Mispredicted jumps cost several cycles of penalty (~10-20 cycles).

* There is a short chain of dependencies: read char, compare it, conditionally jump, which make hard to utilize out-of-order execution capabilities present in a CPU.

SIMD-friendly algorithms for substring searching

优化算法

那么如何使用 SMID 来优化算法性能呢?

假设我们有一些 8-byte 的寄存器,我们要在字符串 “a_cat_tries” 中搜索子串 “cat”。

首先我们将 “cat” 的第一个字节和最后一个字节填充到两个寄存器中,并尽可能地重复知道寄存器被填满。

F    = [ c | c | c | c | c | c | c | c ]
L    = [ t | t | t | t | t | t | t | t ]

然后我们将字符串 “a_cat_tries” 加载到另外两个寄存器中,其中一个寄存器从第二个字符开始加载。

A    = [ a | _ | c | a | t | _ | t | r ]
B    = [ c | a | t | _ | t | r | i | e ]

然后比较两组寄存器的内容,对应位置的内容相同为 1,反之则为 0。

AF   = (A == F)
     = [ 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 ]

BL   = (B == L)
     = [ 0 | 0 | 1 | 0 | 1 | 0 | 0 | 0 ]

最后我们将两个寄存器的内容合并,即 “位与” 运算。

mask = [ 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 ]

此时我们发现只有第 2 位(从零开始)是 1,这说明只有从此处开始搜索才有可能搜索到子串。着就减少了我们的搜索次数,下面是一种实现。

size_t avx2_strstr_anysize(const char* s, size_t n, const char* needle, size_t k) {

    // 向寄存器中填充 needle 的第一个字节
    const __m256i first = _mm256_set1_epi8(needle[0]);
    // 向寄存器中填充 needle 的最后一个字节
    const __m256i last  = _mm256_set1_epi8(needle[k - 1]); 

    for (size_t i = 0; i < n; i += 32) {

        // 向寄存器中填充 s 的部分内容
        const __m256i block_first = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(s + i));

        // 向寄存器中填充 s 的部分内容,相对于上一行,本次填充的内容有所偏移
        const __m256i block_last  = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(s + i + k - 1));
    

        // 比较两组寄存器
        const __m256i eq_first = _mm256_cmpeq_epi8(first, block_first);
        const __m256i eq_last  = _mm256_cmpeq_epi8(last, block_last);

        // 合并两个寄存器的比较结果
        uint32_t mask = _mm256_movemask_epi8(_mm256_and_si256(eq_first, eq_last));

        while (mask != 0) {

            // 找到第一个值为 1 的 bit 的下标
            const auto bitpos = bits::get_first_bit_set(mask);

            if (memcmp(s + i + bitpos + 1, needle + 1, k - 2) == 0) {
                return i + bitpos;
            }

            mask = bits::clear_leftmost_set(mask);
        }
    }

    return std::string::npos;
}

参考资料

本文作者:ADD-SP
本文链接https://www.addesp.com/archives/5486
版权声明:本博客所有文章除特别声明外,均默认采用 CC-BY-NC-SA 4.0 许可协议。
暂无评论

发送评论 编辑评论


				
上一篇