深入理解 KMP 算法

字符串匹配是非常常用的算法. 它要做的是, 对于模式串 t, 找到它在目标串 s 中第一个出现的位置. 例如 t = "abcab", s = "ababcabd", 算法就应该返回 2. 说起字符串匹配, 不得不提到 KMP 算法. KMP 是一个很厉害的算法, 它非常巧妙, 能在线性时间复杂度内完成字符串匹配. 本文我们来了解它.

本文会用到 Python 的切片语法表示子串. 如果你不熟悉 Python, 只需要记住: s[:k] 表示 s 长度为 k 的前缀, s[-k:] 表示 s 长度为 k 的后缀. 此外字符串的下标从 0 开始, 因此 s[:k] 的最后一个字符为 s[k-1].

匹配算法

粗略地说, KMP 算法利用了字符串的一个性质: 字符串的尾部有可能与自己的首部相匹配. 比如说字符串 abcab, 它的尾部有两个字符可以与自己的首部相匹配:

1
2
3
   abcab
||
abcab

用形式化的语言描述: 对于一个字符串 t, 存在一个最大的数 0 <= m < len(t), 使得 t[:m] == t[-m:] 成立. 我们称 m最大首尾匹配长度. 注意 m 必须小于 t 的长度, 不然没有意义 – 任何字符串总是与自身相等.

KMP 首先会为模式串的所有前缀求得最大首尾匹配长度, 存储在数组 pi 中. 前缀 t[:i] 的最大首尾匹配长度为 pi[i-1]. 例如字符串 p = "abcab"pi 数组为 [0, 0, 0, 1, 2], 因为:

  • p[:1] = "a", 最大首尾匹配长度为 0, pi[0] = 0
  • p[:2] = "ab", 最大首尾匹配长度为 0, pi[1] = 0
  • p[:3] = "abc", 最大首尾匹配长度为 0, pi[2] = 0
  • p[:4] = "abca", 最大首尾匹配长度为 1, pi[0] = 1, 因为有
    1
    2
    3
       abca
    |
    abca
  • p[:5] = "abcab", 最大首尾匹配长度为 2, pi[0] = 2, 因为
    1
    2
    3
       abcab
    ||
    abcab

我们先不管 KMP 是怎样求 pi 数组的, 先看看 KMP 是怎样利用 pi 数组做匹配的.

假设字符串 s 与模式串 t 匹配. 我们从第 0 个字符开始匹配, 接着是第 1 个, 第 2 个, 一直到第 k 个字符, 都匹配成功了; 然而第 k + 1 个字符匹配失败了.

kmp_1

匹配失败了, 那怎么办呢? 重点来了. 第 0 至 k 个字符匹配成功了, 这段子串等于 t[:k+1]. pi 数组告诉我们, 这段子串的后 pi[k] 个字符正好等于它的前 pi[k] 个字符.

kmp_2

这样下一步我们就让 s[k+1]t[pi[k]] 相比较, 如果相等, 就继续匹配后面的字符; 如果不相等 – 没关系, t[:pi[k]] 已经匹配成功, 我们再去查询 pi 数组得到这段子串的最大首尾匹配长度, 再按照同样的方式比较相应的字符即可.

有了这个思路, 我们就不难写出 KMP 算法的代码了:

1
2
3
4
5
6
7
8
9
10
11
12
def kmp(s, t):
pi = calc_pi(t) # 先不管如何计算 pi
j = 0
for i, c in enumerate(s):
while j > 0 and t[j] != c: # t[j] 匹配失败, 但是 t[:j] 匹配成功
j = pi[j-1] # t[:j] 的后 pi[j-1] 个字符与前 pi[j-1] 个字符相等, 下一步匹配 t[pi[j-1]]
if t[j] == c:
j += 1
if j == len(t):
return i - j + 1 # 返回起始位置

return -1

求 pi 数组

既然 pi 数组这么好用, 怎么求它呢? 首先很容易想到 pi[0] = 0, 因为最大首尾匹配长度需小于字符串长度. 如果我们能够用 pi[0], pi[1], ..., pi[k] 推出 pi[k+1], 我们就能求出整个 pi 数组了.

假设我们求出了 pi[k], 也就是 t[:k+1] 的最大首尾匹配长度.

kmp_3

那么 pi[k+1] 也就是 t[:k+2] 的最大首位匹配长度是多少呢? 这需要分两种情况讨论. 假设 t[k+1] == t[pi[k]], 这种情况很简单, pi[k+1] = pi[k] + 1.

kmp_4

要是不相等呢? 没关系, 前面那几个字符, 也就是 t[:pi[k]], 不是匹配上了么? 根据前面刚求出来的 pi 数组, 我们知道对于 t[:pi[k]] 这个子串, 后 pi[pi[k]-1] 个字符与前 pi[pi[k]-1] 个字符相等! 这就回到前面匹配算法的情况了.

kmp_5

接下来我们只需要让 t[k+1]t[pi[pi[k]-1]] 相比较. 如果相等, 那么 pi[k+1] = pi[pi[k]-1] + 1; 如果不相等, 那就再重复这个操作: 查询 pi 数组, 获得前面相等部分的最大首尾匹配长度, 然后再比较相应的字符即可.

这样, 计算 pi 数组的代码也不难理解了.

1
2
3
4
5
6
7
8
9
10
11
def calc_pi(t):
pi = [0] * len(t) # pi[0] = 0
j = 0 # j 等于 pi[-1]
for i in range(1, len(t)):
while j > 0 and t[i] != t[j]: # t[j] 匹配失败, 但是 t[:j] 匹配成功. 注意这里 j = pi[i-1]
j = pi[j-1] # t[:j] 的后 pi[j-1] 个字符与前 pi[j-1] 个字符相等, 下一步匹配 t[pi[j-1]]
if t[i] == t[j]:
j += 1
pi[i] = j

return pi

计算 pi 的代码与 KMP 匹配算法非常像. 匹配算法是将模式串与目标串匹配, 而计算 pi 则是将模式串与模式串自己匹配. 在与自己匹配的过程中只会依赖之前已经计算好的 pi 值, 所以说这是一个动态规划算法.

总结

下面的代码将算法的两部分写在了一起. 概括一下, KMP 利用匹配算法, 逐渐推导出 pi 数组, 然后再使用同样的匹配算法, 利用前面求出的 pi 匹配模式串与目标串. 所以说 KMP 算法是一个非常厉害的算法.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def kmp(s, t):
if not t: return 0
pi = [0] * len(t)
j = 0
for i in range(1, len(t)):
while j > 0 and t[i] != t[j]:
j = pi[j-1]
if t[i] == t[j]:
j += 1
pi[i] = j

j = 0
for i, c in enumerate(s):
while j > 0 and t[j] != c:
j = pi[j-1]
if t[j] == c:
j += 1
if j == len(t):
return i - j + 1

return -1

深入理解 KMP 算法
https://luyuhuang.tech/2021/12/25/kmp.html
Author
Luyu Huang
Posted on
December 25, 2021
Licensed under