浅谈 KMP 字符串匹配算法

因为上周末 MFE 群里有位成员提到了这个算法,所以那天晚上就稍微认识了一下这个算法,感觉 KMP 对于匹配字符串上的做法还是好神奇的,值得水篇文章整理一下。

很久没有更文了w 最近真的是非常非常非常忙qwq,刚刚送走了第二次阶段考就迎来了 NOIP 的初赛,然后接下来的一整个月都要在忙碌的准备中度过,11 月回来又是期中考qwq 所以你们会发现我的解题报告停更了半个月,这篇文章权当是补偿吧(

KMP 算法

KMP算法是一种改进的字符串匹配算法,由D.E.Knuth,J.H.Morris和V.R.Pratt同时发现,因此人们称它为克努特——莫里斯——普拉特操作(简称KMP算法)。KMP算法的关键是利用匹配失败后的信息,尽量减少模式串与主串的匹配次数以达到快速匹配的目的。具体实现就是实现一个next()函数,函数本身包含了模式串的局部匹配信息。时间复杂度O(m+n)。

以上摘自百科,这段话已经详尽地阐明了什么是 KMP 算法、它的应用场景和高效性的核心原因。下面就不再多做介绍了w。

原始的字符串匹配方法

首先回忆一下,如果我们要匹配两个字符串,一般是怎么做的?我想大部分的人应该会想到这样的答案:

  • 用两个下标指针 i, j 分别指向我们要匹配的目标字符串 (target) 和给我们的模式串 (pattern)。
  • 从 i = 0 作为起点开始,如果 target[i] == pattern[j],就一步一步向右移动 i, j 指针。
  • 当 target[i] != pattern[j] 的时候,就让 i 的起点右移一位,让 j = 0.
  • 重新开始尝试匹配,直到匹配成功(返回 target 与 pattern 匹配的起始为止)或者 i = target.length 时仍然没有结果为止(此时 target 与 pattern 不匹配,返回 -1)。

当然,这样的方法确实是正确的说,但是我们还需要考虑这个效率 efficiency。很明显这个方法的时间复杂度是 O(n*m) 的~ 要是数据量很大的话就可能会超时哟(

那么有没有什么好一点的方法来优化我们匹配两个字符串的方式呢?

优化的思想

下文中,我们把目标串 target 记作 T, 模式串 pattern 记作 P. 首先,让我们考虑这样两个字符串:

target (T): ABCABCABDAB
pattern(P): ABCABD

我们用上面的方法匹配的时候,让 i = j = 0,然后当 T[i] == P[j] 的时候,我们不断让 i, j 右移,直到 i = j = 5 的时候,我们发现 T[i](C) != P[j](D) 了。

按照我们上面的思路,我们会把 i 和 j 回溯:

 ↓ i = 1
ABCABCABDAB
 ABCABD
 ↑ j = 0

这时候第一个字符就不匹配了,所以 i 继续右移,直到 i = 3, j = 0 的时候往下匹配,最后 i = 8, j = 5 的时候匹配成功:

   ↓ i = 3
ABCABCABDAB
   ABCABD
   ↑ j = 0

......

        ↓ i = 8
ABCABCABDAB
   ABCABD
        ↑ j = 5

但是我们发现了一件有趣的事情,就是,如果这件事让我们人为来做的话,当 i, j 从 0 开始匹配到 5 失败的时候,我们不会让 i = 1, 2, … 这样一直试到匹配为止。 因为我们发现我们在 i = 5 的时候匹配失败,P[0] = A, 但是 i = 5 的前面只有 i = 3 的时候 T[i] = A,所以 i = 1, 2 的时候我们根本不需要去匹配。我们能这么想,是因为我们已经知道匹配过的 i = 0~5 的情况了——这也正是 KMP 算法的精髓所在,KMP 算法高效的原因之一正是由于它能够利用已经匹配到的有效信息。

所以呢,上面提到的三个人就想到了一个方案:i 指针可以不动呀,我们只要移动 j 指针就可以了。

     ↓ i = 5
ABCABCABDAB
ABCABD
     ↑ j = 5

这个时候让 i 不动,只移动 j:

     ↓ i = 5
ABCABCABDAB
   ABCABD
     ↑ j = 2

哎,你会发现这样比一起移动 i 和 j 快多了,对吗?这就是 KMP 算法的思想:

利用已经部分匹配这个有效信息,保持 i 指针不回溯,通过修改 j 指针,让模式串尽量地移动到有效的位置。

KMP 算法的灵魂 —— j 指针移动的位置

但是有人会问,我们要怎么知道 j 指针移动到哪个地方才合适呢?因为 j 的下一个位置很关键,如果 j 的下一个位置太提前,那么这个想法无异于没有优化;如果太靠后,那么两个字符串就无法匹配了。所以 j 指针的下一个位置是很重要的。KMP 算法中,用一个叫 next 数组 的东西来记录 j 指针的下一个位置。

next 数组也是 KMP 的难点所在。如果你觉得 next 数组的求解很难理解,这里有几篇文章可以辅助你理解;下面我也会理一下从这几篇文章里整理出来的比较容易理解的东西:

http://www.cnblogs.com/yjiyjige/p/3263858.html http://blog.csdn.net/henuyx/article/details/44653799

我们用 k 表示当 T[i] 与 P[j] 不匹配的时候,j 要移动的下一个位置,则 next[j] = k. 现在你可以猜到,next[j] 存储的是:当 T[i] 与 P[j] 匹配失败时,j 将要移动的下一个位置。

我们还是举上面那个例子吧:

ABCABD
     ↑ j = 5,匹配失败

ABCABD
  ↑ j = 2,把 j 移动到这个位置

有没有发现什么?没错, P[0~1] 和 P[3~4] 是重合的!

再比如说:

       ↓ i = 7
ABCDABCDNAIVE
ABCDABCB
       ↑ j = 7

这个时候又匹配失败了。虽然我们作为人类,很容易看出这两个字符串无论怎么移动都不能匹配。不过我们这个时候还是要假装尝试一下,我们会这样移动 j 指针:

移动前:
       ↓ i = 7
ABCDABCDNAIVE
ABCDABCB
       ↑ j = 7
移动后:
       ↓ i = 7
ABCDABCDNAIVE
    ABCDABCB
       ↑ j = 3

哎,是不是还是刚才的那个规律呢,模式串的 P[0~2] 和 P[4~6] 是重合的,而恰恰又是 j 移动的位置 3 把它们隔开了。

现在我们可以总结出这样的性质了:当 T[i] 与 P[j] 匹配失败的时候,j 移动到下一个位置 k,k 即为使得 P[0, k-1] = P[j-k, j-1] (或者说 P[k+1, j-1]) 的那个值。

引用一下一篇文章的证明:

当T[i] != P[j]时 有T[i-j ~ i-1] == P[0 ~ j-1] 由P[0 ~ k-1] == P[j-k ~ j-1] 必然:T[i-k ~ i-1] == P[0 ~ k-1]

进一步我们还可以发现,这个 k 的位置完全取决于模式串 P 自身的性质,与目标串 T 无关。所以我们就可以预处理计算出对于每一个 j 的 k 值了。就如我们刚才所说的我们把 k 存储在 next 数组当中,那么接下来我们就来计算 next 数组。

next 数组的计算

先温习一下 next 数组的含义和作用:

next[j] 存储的是:当 T[i] 与 P[j] 匹配失败时,j 将要移动的下一个位置。

接下来我们分类讨论一下:

  1. 当 j = 0 的时候,显然 j 已经不能再往前移动了,我们这个时候应该保持 j 不动,然后向右移动 i 指针继续匹配。所以我们让 next[0] = -1,告诉 KMP 主算法已经不能再移动 j 指针了。
  2. 当 j = 1 的时候,显然 j 只能移动到 0,所以 next[1] = 0.
  3. 那么其他情况呢?如果 P[j] = P[k] 的话,那么 next[j+1] = next[j] + 1.
  4. 最后一种情况就是 P[j] != P[k],那么我们就让 k = next[k], 重复这个步骤直到符合上面三个条件之一为止。

至于为什么这样做,这里主要介绍一下 3 和 4:

为什么 P[j] = P[k] 的时候,next[j+1] = next[j] + 1 呢?

考虑这样一个字符串:

index 0 1 2 3 4 5 6 7 8
str   A B C D A B C D E
next -1 0 0 0 0 0 0 3 4  

通过计算我们可以发现 next[7] = 3, 因为 3 的左边和右边都是 ABC,这个时候我们发现 P[7] = P[3] = D.

所以,我们发现,当 j = j+1 = 8 之后,next[8] = 4 = next[7] + 1. 这个规律也是可以证明的。这个证明同样是引用:

因为在P[j]之前已经有P[0 ~ k-1] == p[j-k ~ j-1]。(next[j] == k) 这时候现有P[k] == P[j],我们可以得到P[0 ~ k-1] + P[k] == p[j-k ~ j-1] + P[j]。 即:P[0 ~ k] == P[j-k ~ j],即next[j+1] == k + 1 == next[j] + 1。

那 P[j] != P[k] 的时候呢?

如果没有赋值的时候,next[j] 默认是等于 0 的,否则 k 就是上一次计算的 k 值。 这样,当 P[j] != P[k] 的时候,我们就一直让 k = next[k],直到 k 符合上面的条件之一为止。

鉴于……这个东西我也讲不清楚qwq 所以我找到了一篇总结的挺好的文章:https://www.zhihu.com/question/21474082。 虽然内容不长,但是我就不贴过来了。

那么,我们可以整理出计算 next 数组的整个过程了:

初始化 next[] = 0
next[0] ← -1
j ← 0, k ← -1, len ← (pattern 长度 - 1)
while j < len do
  if k = -1 || pattern[j] = pattern[k] then
    j ← j+1, k ← k+1
    next[j] = k
  else
    k ← next[k]

下面是一个用 C++ 实现的方式:

void calcNextArr()
{
  next[0] = -1;
  int j = 0, k = -1, len = strlen(pattern) - 1;
  while (j < len)
  {
    if (k == -1 || pattern[j] == pattern[k]) {
      next[++j] = ++k;
      // 如果要做时间复杂度优化,就加上下面的代码
      if (pattern[j] == pattern[k]) {
        next[j] = next[k];
      }
    } else {
      k = next[k];
    }
  }
}

优化 next 数组的求解

注意到我们在实现上面的伪代码的时候多了两句:

if (pattern[j] == pattern[k]) {
  next[j] = next[k];
}

为什么要加这一句呢?我们可以看出如果 P[j] = P[k] 的时候,让 next[j] = next[k]. 但是我们上文不是说到,如果 P[j] = P[k] 的时候,next[j+1] = next[j] + 1 嘛,为什么又回去了呐?考虑这样的一个模式串:

ABCABDNAIVE
ABCABCABDAB
     ↑ j = 5

如上,当我们在 j = 5 的地方匹配失败的时候,按照我们上面的算法,next[j] = 2,所以 j 会回溯到 2 的地方,指向 C. 但是我们发现就算 j 回溯到 j = 2 的时候,P[j] 仍然等于 C,依旧不能和目标串匹配啊。所以我们还是要把 j 继续回溯,那么这时候 next[j] = 0. 相信大家想到了,只要我们让 j = 5 的时候直接跳回 0,就可以省去多跳一步的步骤了,这就是这段代码的作用。

KMP 主算法

了解了 next[] 数组之后,你已经完全了解 KMP 不远了。 接下来是 KMP 算法的整个流程 :

T ← 目标串,P ← 模式串
i ← 0, j ← 0
next[] ← calcNext(P)                   // 对模式串 P 计算其 next 数组的值
while i < T.length && j < P.length do
  // 如果 j 已经指向模式串的起点了,就把 i 下移一位,j 置零;或者 T[i] 和 P[j] 匹配了,那么就把 i, j 一起下移一位
  if j == -1 || T[i] == P[j] then
    i ← i+1, j ← j+1
  // 如果不匹配,那就让 j 回溯到 next[j] 位置
  else
    j ← next[j]
if j >= P.length                  // 此时匹配成功
  return i-j                      // i-j 即为 T 与 P 匹配时 T 串开始的位置
else
  return -1                       // 否则返回不匹配,用 -1 表示

仍旧是一份 C++ 实现的模板:

int KMP()
{
  int i = 0, j = 0;
  int tlen = strlen(target), plen = strlen(pattern);
  while (i < tlen && j < plen)
  {
    if (j == -1 || target[i] == pattern[j]) {
      i++, j++;
    } else {
      j = next[j];
    }
  }
  
  if (j == plen) {
    return i - j;
  } else {
    return -1;
  }
}

解决实际问题

HDU2087

地址在这里:http://acm.hdu.edu.cn/showproblem.php?pid=2087

分析题目我们可以发现这就是一道简单的裸题字符串匹配,我们当然可以直接上 KMP 啦。下面是这题的 AC 代码:

#include <cstdio>
#include <cstring>

const int MAXN = 1e3 + 10;
int n[MAXN];
char target[MAXN], pattern[MAXN];

void init()
{
  memset(target, 0, sizeof target);
  memset(pattern, 0, sizeof pattern);
  memset(n, 0, sizeof n);
}

void calcNext()
{
  n[0] = -1;
  int i = 0, k = -1, len = strlen(pattern) - 1;
  while (i < len)
  {
    if (k == -1 || pattern[i] == pattern[k]) {
      n[++i] = ++k;
      if (pattern[i] == pattern[k]) {
        n[i] = n[k];
      }
    } else {
      k = n[k];
    }
  }
}

int KMP()
{
  int ans = 0, i = 0, j = 0;
  int tlen = strlen(target), plen = strlen(pattern);
  
  while (i < tlen)
  {
    if (j == -1 || target[i] == pattern[j]) {
      i++, j++;
    } else {
      j = n[j];
    }
    if (j == plen) {
      j = 0;
      ans++;
    }    
  }
  return ans;
}

int main()
{
  while (~scanf("%s", target))
  {
    if (target[0] == '#')
      break;
    scanf("%s", pattern);
    calcNext();
    int ans = KMP();
    printf("%d\n", ans);
    init();
  }
  return 0;
}

luogu 3375

这题也是个 KMP,可能有多个匹配,那么它要输出每个匹配的位置;以及这题的 next 数组不能优化,因为他还要输出= =我就是用了优化的 next 数组然后 WA 了好多次( 总之写起来大概像下面这样:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int MAXN = 1000050;
int next[MAXN], n[MAXN];
int len1, len2;
char target[MAXN], pattern[MAXN];

void calcNext()
{
    int j = 0, k = -1;
    next[0] = -1, n[0] = -1;
    while (j < len2)
    {
        if (k == -1 || pattern[j] == pattern[k]) {
            next[++j] = ++k;
            n[j] = k;
            if (pattern[j] == pattern[k]) {
                n[j] = n[k];
            }
        } else {
            k = next[k];
        }
    }
}

void kmp()
{
    int i = 0, j = 0;
    while (i < len1)
    {
        if (j == -1 || target[i] == pattern[j]) {
            i++, j++;
        } else {
            j = n[j];
        }
        
        if (j == len2) {
            printf("%d\n", i - j + 1);
            j = n[j];
        }
    }
}

int main()
{
    scanf("%s%s", target, pattern);
    len1 = strlen(target);
    len2 = strlen(pattern);
    calcNext();
    kmp();
    for (int i = 1; i <= len2; i++)
    {
        printf("%d ", next[i]);
    }    
    printf("\n");
    return 0;
}

至此,我们对 KMP 的算法就有了一个比较完整的了解了。