深入理解 KMP 算法

 

字符串匹配是非常常用的算法. 它要做的是, 对于模式串 t, 找到它在目标串 s 中第一个出现的位置. 例如 t = "abcab", s = "ababcabd", 算法就应该返回 2. 说起字符串匹配, 不得不提到 KMP 算法. KMP 是一个很厉害的算法, 它非常巧妙, 能在线性时间复杂度内完成字符串匹配. 本文我们来了解它.

本文会用到 Python 的切片语法表示子串. 如果你不熟悉 Python, 只需要记住: s[:k] 表示 s 长度为 k 的前缀, s[-k:] 表示 s 长度为 k 的后缀. 此外字符串的下标从 0 开始, 因此 s[:k] 的最后一个字符为 s[k-1].

匹配算法

粗略地说, KMP 算法利用了字符串的一个性质: 字符串的尾部有可能与自己的首部相匹配. 比如说字符串 abcab, 它的尾部有两个字符可以与自己的首部相匹配:

   abcab
   ||
abcab

用形式化的语言描述: 对于一个字符串 t, 存在一个最大的数 0 <= m < len(t), 使得 t[:m] == t[-m:] 成立. 我们称 m最大首尾匹配长度. 注意 m 必须小于 t 的长度, 不然没有意义 – 任何字符串总是与自身相等.

KMP 首先会为模式串的所有前缀求得最大首尾匹配长度, 存储在数组 pi 中. 前缀 t[:i] 的最大首尾匹配长度为 pi[i-1]. 例如字符串 p = "abcab"pi 数组为 [0, 0, 0, 1, 2], 因为:

  • p[:1] = "a", 最大首尾匹配长度为 0, pi[0] = 0
  • p[:2] = "ab", 最大首尾匹配长度为 0, pi[1] = 0
  • p[:3] = "abc", 最大首尾匹配长度为 0, pi[2] = 0
  • p[:4] = "abca", 最大首尾匹配长度为 1, pi[0] = 1, 因为有
         abca
         |
      abca
    
  • p[:5] = "abcab", 最大首尾匹配长度为 2, pi[0] = 2, 因为
         abcab
         ||
      abcab
    

我们先不管 KMP 是怎样求 pi 数组的, 先看看 KMP 是怎样利用 pi 数组做匹配的.

假设字符串 s 与模式串 t 匹配. 我们从第 0 个字符开始匹配, 接着是第 1 个, 第 2 个, 一直到第 k 个字符, 都匹配成功了; 然而第 k + 1 个字符匹配失败了.

kmp_1

匹配失败了, 那怎么办呢? 重点来了. 第 0 至 k 个字符匹配成功了, 这段子串等于 t[:k+1]. pi 数组告诉我们, 这段子串的后 pi[k] 个字符正好等于它的前 pi[k] 个字符.

kmp_2

这样下一步我们就让 s[k+1]t[pi[k]] 相比较, 如果相等, 就继续匹配后面的字符; 如果不相等 – 没关系, t[:pi[k]] 已经匹配成功, 我们再去查询 pi 数组得到这段子串的最大首尾匹配长度, 再按照同样的方式比较相应的字符即可.

有了这个思路, 我们就不难写出 KMP 算法的代码了:

def kmp(s, t):
    pi = calc_pi(t) # 先不管如何计算 pi
    j = 0
    for i, c in enumerate(s):
        while j > 0 and t[j] != c: # t[j] 匹配失败, 但是 t[:j] 匹配成功
            j = pi[j-1] # t[:j] 的后 pi[j-1] 个字符与前 pi[j-1] 个字符相等, 下一步匹配 t[pi[j-1]]
        if t[j] == c:
            j += 1
        if j == len(t):
            return i - j + 1 # 返回起始位置

    return -1

求 pi 数组

既然 pi 数组这么好用, 怎么求它呢? 首先很容易想到 pi[0] = 0, 因为最大首尾匹配长度需小于字符串长度. 如果我们能够用 pi[0], pi[1], ..., pi[k] 推出 pi[k+1], 我们就能求出整个 pi 数组了.

假设我们求出了 pi[k], 也就是 t[:k+1] 的最大首尾匹配长度.

kmp_3

那么 pi[k+1] 也就是 t[:k+2] 的最大首位匹配长度是多少呢? 这需要分两种情况讨论. 假设 t[k+1] == t[pi[k]], 这种情况很简单, pi[k+1] = pi[k] + 1.

kmp_4

要是不相等呢? 没关系, 前面那几个字符, 也就是 t[:pi[k]], 不是匹配上了么? 根据前面刚求出来的 pi 数组, 我们知道对于 t[:pi[k]] 这个子串, 后 pi[pi[k]-1] 个字符与前 pi[pi[k]-1] 个字符相等! 这就回到前面匹配算法的情况了.

kmp_5

接下来我们只需要让 t[k+1]t[pi[pi[k]-1]] 相比较. 如果相等, 那么 pi[k+1] = pi[pi[k]-1] + 1; 如果不相等, 那就再重复这个操作: 查询 pi 数组, 获得前面相等部分的最大首尾匹配长度, 然后再比较相应的字符即可.

这样, 计算 pi 数组的代码也不难理解了.

def calc_pi(t):
    pi = [0] * len(t) # pi[0] = 0
    j = 0 # j 等于 pi[-1]
    for i in range(1, len(t)):
        while j > 0 and t[i] != t[j]: # t[j] 匹配失败, 但是 t[:j] 匹配成功. 注意这里 j = pi[i-1]
            j = pi[j-1] # t[:j] 的后 pi[j-1] 个字符与前 pi[j-1] 个字符相等, 下一步匹配 t[pi[j-1]]
        if t[i] == t[j]:
            j += 1
        pi[i] = j

    return pi

计算 pi 的代码与 KMP 匹配算法非常像. 匹配算法是将模式串与目标串匹配, 而计算 pi 则是将模式串与模式串自己匹配. 在与自己匹配的过程中只会依赖之前已经计算好的 pi 值, 所以说这是一个动态规划算法.

总结

下面的代码将算法的两部分写在了一起. 概括一下, KMP 利用匹配算法, 逐渐推导出 pi 数组, 然后再使用同样的匹配算法, 利用前面求出的 pi 匹配模式串与目标串. 所以说 KMP 算法是一个非常厉害的算法.

def kmp(s, t):
    if not t: return 0
    pi = [0] * len(t)
    j = 0
    for i in range(1, len(t)):
        while j > 0 and t[i] != t[j]:
            j = pi[j-1]
        if t[i] == t[j]:
            j += 1
        pi[i] = j

    j = 0
    for i, c in enumerate(s):
        while j > 0 and t[j] != c:
            j = pi[j-1]
        if t[j] == c:
            j += 1
        if j == len(t):
            return i - j + 1

    return -1