八数码和十五数码问题是搜索算法中比较经典的问题。这个问题涉及的方面比较广,而且解答的方法也比较多。最近因为在一次 team contest 中遇到了相关的题目,之前一直没有好好钻研一下这类问题,最近又很寂寞,因此就在这星期找了一个时间,以八数码问题为载体,研究了该问题涉及的几个经典算法。
什么是八数码问题?
让我们先对八数码问题的模型背景有一个初步的了解:在 3×3 的棋盘上,摆有八个棋子,每个棋子上标有 1 至 8 的某一数字。棋盘中留有一个空格,空格用 0 来表示。空格周围的棋子可以移到空格中。对于十五数码,则是在 4×4 的棋盘上摆棋子,数字变为 1~15,其它不变。我们将各个数字的不同排列位置称为状态。例如:
1 | 2 | 3 |
---|---|---|
4 | 0 | 6 |
7 | 5 | 8 |
对于这类数码问题,比较常见的两个问题模型是这样的:
- 给定两个状态 $s$ 和 $t$,问能否经过有限次的移动变换,使初状态 $s$ 变为末状态 $t$.
- 给定两个状态 $s$ 和 $t$,问最少需要将 $s$ 经过几次变换,使初状态 $s$ 变为末状态 $t$.
接下来讲讲这两个问题的解决策略。
状态之间是否可达
要解决这个问题,我们需要先知道,每次进行移动操作之后,有什么东西变了,有什么东西没变。以八数码为例,直接给出结论:显然,变的是 0 和与其交换的数的位置;而不变的是两个局面下,逆序对个数的奇偶性。当然,一般的人是很难想到八数码问题和逆序对奇偶性能扯上关系的,但是一旦点出来之后,证明的思路就是清晰的。
首先,前文说到 0 是空格,因此我们在考虑状态的逆序对问题的时候,就不需要考虑 0 对逆序对数量的影响。也就是说我们只需要考虑:把整个 $3 \times 3$ 矩阵去掉 0 之后写成一个序列,这个序列逆序对的数量,以及两个状态的逆序对奇偶性是否相同。如果两个状态的奇偶性相同,则这两个状态的奇偶性相互可达;否则相互不可达。
其次,我们再考虑移动数字对逆序对奇偶性的影响。由于 0 并不参与逆序对的统计,因此将 0 左右移动,写成的序列并不变,并不影响逆序对的数量。而将 0 上下移动的时候,相当于有一个数字被后移或前移了 2 位。如 $a, b, c, d, e, f, g, h$ 八个数中,将 $c$ 后移 2 位,得到序列 $a, b, d, e, c, f, g, h$. 显然,逆序对可能发生改变的部分只有 $d, e, c$ 三个数字。根据线性代数的知识,我们可以知道,两个相邻数字进行一次交换,逆序对的奇偶性改变;而上述操作可以视为 $c$ 分别于 $d, e$ 对调一次,逆序对的奇偶性改变了两次,和原来相比相当于没有改变。
因此,用归纳法我们可以证明:只要两个状态的序列逆序对奇偶性相同,他们就一定互相可达;否则一定互相不可达,因为交换 0 并不影响逆序对的奇偶性。至于怎么移动才能可达,这就不在这个问题的考虑范围了。
再说说十五数码。十五数码中左右移动的情况仍然不改变逆序对的奇偶性,但是上下移动呢?此时相当于将元素前移或后移 3 位,逆序对的奇偶性一定会改变;但是如果再次上下交换,相当于移动 6 次,奇偶性又与原来的状态一样了。因此对于偶数数码(这里特指 $n$ 为偶数,也即 $n^2-1$ 数码问题)的问题,仅仅判断逆序对的奇偶性并不能确定答案。此时还需要加上初状态和末状态中空格间的行差,即判断:末状态逆序对数量+初末状态空格的行数差 与 初状态逆序对数量 的奇偶性是否相同。这样无论是奇数次还是偶数次的上下移动,加上行差之后的奇偶性都不变。
这个问题也能推广到 $n \times m$ 的数码问题中。
说到求逆序对,在八数码和十五数码问题中,事实上用 $O(n^2)$ 的方法效率也不会太低,因为 $n$ 比较小。然而在遇到 $n$ 规模比较大的问题就比较麻烦。此时最好的办法还是用对序列归并排序的方式来求逆序对。例如:POJ2893, 这是一个 $n \times m$ 的棋盘下的问题。根据我们的分析,我们可以对列数 $m$ 进行分类讨论: 如果 $m$ 为奇数,则直接求状态逆序对数是否为偶数(因为本题中初状态的逆序对数量为 0);如果 $m$ 是偶数,则找到 0 所在的行数 $t$,判断 $ans + n - t$ 是否为偶数即可。至于归并排序求逆序对的代码,我觉得紫书里那个写的很简洁,因此可以直接抄来。
#include <cstdio>
const int MAXN = 1000000;
int a[MAXN], b[MAXN], c[MAXN];
int n, m, ans = 0;
void merge_sort(int x, int y)
{
if (y - x > 1) {
int m = x + (y - x) / 2;
int p = x, q = m, i = x;
merge_sort(x, m);
merge_sort(m, y);
while (p < m || q < y) {
if (q >= y || (p < m && a[p] <= a[q]))
b[i++] = a[p++];
else {
b[i++]= a[q++];
ans += m - p;
}
}
for (i = x; i < y; i++)
a[i] = b[i];
}
}
int main()
{
while (scanf("%d%d", &n, &m) != EOF && n != 0) {
int tmp, cnt = 0, line = -1;
ans = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
scanf("%d", &tmp);
if (tmp == 0) {
line = i + 1;
continue;
}
a[cnt++] = tmp;
}
}
merge_sort(0, cnt);
if (m % 2 == 0) {
if ((ans + n - line) % 2 == 0) {
printf("YES\n");
} else {
printf("NO\n");
}
} else {
if (ans % 2 == 0) {
printf("YES\n");
} else {
printf("NO\n");
}
}
}
return 0;
}
两个状态间的最少移动次数
接下来讨论的第二个问题就是一个经典的搜索问题了。这里为了叙述的方便,我们只讨论八数码问题,并且假定已知两个状态是相互可达的。如果题目未知相互可达,则只需要加上 (1) 的代码判断即可。并且给出一道例题:https://vijos.org/p/1360 。
首先我们可以很容易确定这应该是一个 BFS 的问题——因为要求的是最小的移动次数,那么我们从起点搜索,每次拓展将 0 与上下左右四个数交换后得到的新状态,直到搜到结果为止。由 BFS 的性质我们容易知道,我们搜到结果的时候的移动次数一定是最小次数。这道题的关键在于,搜索的策略。
上面朴素的搜索方式有什么弊端呢?一个最大的问题就是它会造成重复的搜索。设有一个状态 $s$,将 0 左移得到新的状态 $s_1$,而 $s_1$ 又可以将 0 右移得到状态 $s_2$, 但显然有 $s$ = $s_2$. 这样的话就造成了很多不必要的搜索。浪费时间。另外,我们还可以通过使用不同的搜索方式,来改进搜索的效率。
首先考虑重复搜索的避免,为了方便状态的表示,我们还是将 $3 \times 3$ 方格里的 9 个数写成一个整型数的序列,方便我们后面的操作。这里有两个状态判重策略:一个是用 STL 中的 map 判重,定义 map<int, int> mp
来判断某个序列是否已搜索过或在队列中,从而避免重复拓展队列中的状态;第二个是通过 Hash + 链表的方式,找一个 Hash 函数,将状态整数映射到一个节点数不超过 $9!=362880$ (由于没有重复数字出现,因此有效的状态数量不超过 9 的全排列数 $9!$)的链表中。
先上一个最简单的 STL map 的做法。需要注意,当我们把状态写成一个序列的时候,要注意处理边界的问题。例如当 0 位于 (2, 1) 的时候,它不能左移;位于 (1, 3) 的时候,它不能上移,也不能右移……所以这个细节需要特别处理一下。程序如下:
#include <cstdio>
#include <cstring>
#include <queue>
#include <map>
#include <algorithm>
using std::queue;
using std::map;
using std::swap;
struct s {
char a[10];
int step;
int getState() {
int res = 0;
sscanf(a, "%d", &res);
return res;
}
};
queue<s> q;
map<int, int> mp;
int main()
{
s init;
int fin = 123804765;
scanf("%s", init.a);
init.step = 0;
mp[init.getState()]++;
q.push(init);
while (!q.empty()) {
s cur = q.front();
q.pop();
// 到达最终状态,由 BFS 性质可知此时一定最小
if (cur.getState() == fin) {
printf("%d", cur.step);
break;
}
cur.step++;
int pos = -1;
for (int i = 0; i < 9; i++)
if (cur.a[i] == '0') {
pos = i;
break;
}
int state;
// 将 0 与上下左右四个数交换得到四个新状态,注意判断合法性
if (pos % 3 != 0) {
swap(cur.a[pos], cur.a[pos - 1]);
state = cur.getState();
if (mp.count(state) == 0) {
q.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos - 1]);
}
if (pos % 3 != 2) {
swap(cur.a[pos], cur.a[pos + 1]);
state = cur.getState();
if (mp.count(state) == 0) {
q.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos + 1]);
}
if (pos > 2) {
swap(cur.a[pos], cur.a[pos - 3]);
state = cur.getState();
if (mp.count(state) == 0) {
q.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos - 3]);
}
if (pos < 6) {
swap(cur.a[pos], cur.a[pos + 3]);
state = cur.getState();
if (mp.count(state) == 0) {
q.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos + 3]);
}
}
return 0;
}
上面的代码在 Vijos 上跑了最多 60 多 ms. 当然,用 STL map 的工业做法显然不是我们今天的重点。让我们来尝试一些新的东西吧,比如用哈希表判重。所谓哈希表,就是 Hash + 链表的简称,它的实现思路是这样的:
- 确定映射后的最大状态数 MAXN;
- 找到一个质数 p < MAXN (这样做是为了减少 Hash 结果相同的状态过多导致查找时影响效率);
- 根据 1, 2 确定哈希函数 $h(x)$,确保 $h(x)$ 是一个 0~MAXN 的结果;
- 建立哈希链表(有些像图论中的前向星)
head[MAXN], next[MAXN], state[MAXN]
,用于存储映射后每个结果对应原先的状态集合;例如原状态为 $s_1, s_2$,通过 $h(x)$ 映射后是相同的哈希值$h$,那么我们可以通过哈希链表,从i = head[h]
开始找,每次沿着next[i]
找哈希值为 $h$ 的下一个节点,就可以知道对应的状态是否已访问过;
写 Hash 表要求我们对链表的基本操作比较熟悉。至于哈希函数的寻找,因为我们这里的状态比较少,所以直接应用最简单的哈希函数:$h(x) = x \% p$ 即可。在其它的题目中,我们也可以选取例如这样的哈希函数:$h(a[]) = (\Sigma_{i=0}^{len(a)-1}a[i] \% p + (\Pi _{i=0}^{len(a)-1} a[i]) \% p)$… 哈希函数的选取是既有技巧又玄学的,选取的好对程序的效率有很大的影响。按照这个思想我们来实现它:
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <queue>
using std::queue;
using std::swap;
const int MAXN = 362881;
const int MOD = 362867;
struct s{
char a[9];
int step;
int getState() {
int res = 0;
sscanf(a, "%d", &res);
return res;
}
};
int head[MAXN], next[MAXN], state[MAXN], cnt = 1;
queue<s> q;
int getHash(int x)
{
return x % MOD;
}
bool insert(int st)
{
int val = getHash(st);
for (int i = head[val]; i; i = next[i]) {
if (state[i] == st) {
return true;
}
}
state[cnt] = st;
next[cnt] = head[val];
head[val] = cnt++;
return false;
}
int main()
{
memset(head, 0, sizeof head);
int finalState = 123804765;
s init;
init.step = 0;
scanf("%s", init.a);
insert(init.getState());
q.push(init);
while (!q.empty()) {
s cur = q.front();
q.pop();
if (cur.getState() == finalState) {
printf("%d", cur.step);
return 0;
}
cur.step++;
int pos = -1;
for (int i = 0; i < 9; i++)
if (cur.a[i] == '0') {
pos = i;
break;
}
if (pos % 3 != 0) {
swap(cur.a[pos], cur.a[pos - 1]);
if (!insert(cur.getState())) {
q.push(cur);
}
swap(cur.a[pos], cur.a[pos - 1]);
}
if (pos % 3 != 2) {
swap(cur.a[pos], cur.a[pos + 1]);
if (!insert(cur.getState())) {
q.push(cur);
}
swap(cur.a[pos], cur.a[pos + 1]);
}
if (pos > 2) {
swap(cur.a[pos], cur.a[pos - 3]);
if (!insert(cur.getState())) {
q.push(cur);
}
swap(cur.a[pos], cur.a[pos - 3]);
}
if (pos < 6) {
swap(cur.a[pos], cur.a[pos + 3]);
if (!insert(cur.getState())) {
q.push(cur);
}
swap(cur.a[pos], cur.a[pos + 3]);
}
}
return 0;
}
这个代码相对 STL map 的版本效率有些许提高,但仍然还是跑了 50 多 ms。有没有什么更好的优化策略?既然判重的问题我们已经解决了,接下来我们就考虑从搜索的算法上下手着手优化。我们还在一直用着最朴素的 BFS 呢,现在是时候优化一下它了。
双向广搜 (DBFS)
我们用的最朴素的 BFS,是从搜索起点开始,一步一步由当前状态拓展出新的状态,直到拓展出目标状态为止,只需要一个 BFS 对列。而双向广搜的原理是,建立两个队列 q1, q2, 分别存储由起点和终点拓展得到的状态的队列。每一次选择其中一个队列拓展状态,直到有一个状态分别被两个队列拓展过,该状态就是我们的目标状态。具体的操作是这样的:
- 建立两个队列 q1, q2。初始时将起始状态 $s$ 放入 q1, 末状态 $t$ 放入 q2; 使用 map 或其它工具来记录某一状态是否被某一搜索队列拓展过;
- 每次选择 q1, q2 内节点较少的那个拓展状态;取决于题目场景的不同,两个队列拓展状态的策略可能相同,也可能相反;
- 每搜索一个状态,就判断当前状态是否在另一队列中出现过。如果是,则由此节点在起始状态和末状态之间“建立起了一个通路”,说明答案已经被找到,综合两个队列的结果即为所求。
具体到这道题中,我们的操作是这样的:初状态放入 q1, 末状态丢进 q2,并分别用 map 容器 mp1 和 mp2 记录 q1 和 q2 搜索过的节点(状态);如果某个状态在另一个队列中出现过(也即在对应的 map 容器中计数不为 0),说明找到了答案,并且最小移动次数即为双向搜索的步数之和。至于记录步数,我们只需要将 mp1, mp2 的 value 设为“搜索到某状态所需的步数”即可。
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
using std::queue;
using std::map;
using std::swap;
struct s{
char a[9];
int step;
int getState() {
int res = 0;
sscanf(a, "%d", &res);
return res;
}
};
queue<s> q1, q2;
s init, fin;
map<int, int> mp1, mp2;
int finState;
void extend(int x)
{
s cur = x ? q2.front() : q1.front();
x ? q2.pop() : q1.pop();
int state = cur.getState();
if (x && mp1.count(state)) {
printf("%d", mp1[state] + cur.step);
exit(0);
} else if (!x && mp2.count(state)) {
printf("%d", mp2[state] + cur.step);
exit(0);
}
cur.step++;
int pos = -1;
for (int i = 0; i < 9; i++)
if (cur.a[i] == '0') {
pos = i;
break;
}
if (pos % 3 != 0) {
swap(cur.a[pos], cur.a[pos - 1]);
state = cur.getState();
if (!(x ? mp2.count(state) : mp1.count(state))) {
x ? q2.push(cur) : q1.push(cur);
if (x) {
mp2[state] = cur.step;
} else {
mp1[state] = cur.step;
}
}
swap(cur.a[pos], cur.a[pos - 1]);
}
if (pos % 3 != 2) {
swap(cur.a[pos], cur.a[pos + 1]);
state = cur.getState();
if (!(x ? mp2.count(state) : mp1.count(state))) {
x ? q2.push(cur) : q1.push(cur);
if (x) {
mp2[state] = cur.step;
} else {
mp1[state] = cur.step;
}
}
swap(cur.a[pos], cur.a[pos + 1]);
}
if (pos > 2) {
swap(cur.a[pos], cur.a[pos - 3]);
state = cur.getState();
if (!(x ? mp2.count(state) : mp1.count(state))) {
x ? q2.push(cur) : q1.push(cur);
if (x) {
mp2[state] = cur.step;
} else {
mp1[state] = cur.step;
}
}
swap(cur.a[pos], cur.a[pos - 3]);
}
if (pos < 6) {
swap(cur.a[pos], cur.a[pos + 3]);
state = cur.getState();
if (!(x ? mp2.count(state) : mp1.count(state))) {
x ? q2.push(cur) : q1.push(cur);
if (x) {
mp2[state] = cur.step;
} else {
mp1[state] = cur.step;
}
}
swap(cur.a[pos], cur.a[pos + 3]);
}
}
int main()
{
init.step = 0;
fin.step = 0;
scanf("%s", init.a);
sprintf(fin.a, "%s", "123804765");
finState = fin.getState();
q1.push(init);
mp1[init.getState()] = 0;
q2.push(fin);
mp2[fin.getState()] = 0;
while (!q1.empty() || !q2.empty()) {
if (q1.size() < q2.size()) {
extend(0);
} else {
extend(1);
}
}
return 0;
}
这样的优化看起来思想很简单,但是——在 Vijos 上这个程序的峰值运行时间高达 3ms! 从 50+ms 到个位数的飞跃,简直让人懵逼了有没有!那么话说回来,为什么双向广搜的效率比普通的 BFS 高了这么多倍呢?看下面这张示意图:
上面的部分是朴素 BFS 的搜索树拓展情况,我们可以发现它的搜索树呈现一个三角形,步数越多,搜索树越深;而结果可能在搜索树的某一层的一个点,对比起来,这一层和前几层的搜索规模可能显得很大;但是看下面 DBFS 的情况,从 origin 和 target 同时画两个三角形,当两个三角形相交的时候就说明找到了答案。并且这个搜索的范围相比单向 BFS 还更小了(黄色的部分,就是不需要搜索的节点),因此效率就有了显著的提升。
启发式搜索 (A*)
最后的最后,我们再来讲最后一种方法——启发式搜索。鉴于这种搜索方式的学问很大,并且还能牵出 IDA* 等算法,因此在后面我会再写另一篇文章来详细探寻一下这种算法。
顾名思义,启发式搜索,有启发才有搜索,它是利用问题拥有的启发信息来引导搜索,达到减少搜索范围、降低问题复杂度的目的。这种方法通过一个估价策略指导搜索向最有希望的方向前进,降低了复杂性。然而,启发式搜索是有一些玄学的,之所以这么说是因为它很容易出错,极可能因为估价函数选的不好而得到错误的解或非最佳的解,甚至可能还会反向增加复杂度。
说到估价函数,它一般是这样的:$f(x) = g(x) + h(x)$,其中 $g(x)$ 表示的是从初始节点到节点 $x$ 付出的实际代价,$h(x)$ 为从节点 $x$ 到目标节点的最优路径的估计代价。搜索的时候我们很容易得到 $g(x)$,因此我们这个算法的“启发性”就主要体现在 $h(x)$ 中。正确地选取 $h(x)$ 是解决问题的关键,而 $h(x)$ 的具体定义又随着问题的不同而不同。简略地说,启发式搜索的步骤大概像这样:
- 创建一个按照 $f(x)$ 有序(大到小or小到大)排列的队列,这里我们可以用 priority_queue 来实现
- 将初始节点 $s$ 放入该队列中
- 取出优先队列中位于顶部(一般是 $f(x)$ 最小或最大)的状态,判断它是否为目标状态,如果是则直接退出,搜索成功
- 否则,拓展当前节点,得到新的节点 $s_1, s_2, …, s_n$,对每个节点使用估价函数进行评估 $f(s_1), f(s_2)…f(s_n)$,将其作为对应节点的权值加入优先队列中
- 转到 3 继续搜索
在这道题中,我们的估价函数选取策略是这样的:$g(x)$ 表示当前的搜索步数,$h(x)$ 表示当前状态与目标状态对应位置不同数的个数,如状态 $123456780$ 和 $123450786$, $h(x) = 2$. 我们希望 $f(x)$ 越小越好,因为这样有利于我们搜索到答案。如果这里方向搞错的话,那么会得到很大的结果,显然就不是我们想要的答案了。按照这个思路,我们事先的代码如下:
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
using std::priority_queue;
using std::map;
using std::swap;
struct s {
char a[9];
int step, ecost;
int getState() {
int res = 0;
sscanf(a, "%d", &res);
return res;
}
void evaluate(s fin) {
ecost = 0;
for (int i = 0; i < 9; i++)
if (a[i] != fin.a[i])
ecost++;
}
bool operator < (const s a) const {
return step + ecost > a.step + a.ecost;
}
};
int finState;
map<int, int> mp;
priority_queue<s> pq;
int main()
{
s fin, init;
init.step = 0;
fin.step = 0;
scanf("%s", init.a);
sprintf(fin.a, "%s", "123804765");
finState = fin.getState();
init.evaluate(fin);
init.step = 0;
pq.push(init);
mp[init.getState()]++;
while (!pq.empty()) {
s cur = pq.top();
pq.pop();
if (cur.getState() == finState) {
printf("%d", cur.step);
return 0;
}
cur.step++;
int pos = -1;
for (int i = 0; i < 9; i++)
if (cur.a[i] == '0') {
pos = i;
break;
}
int state;
if (pos % 3 != 0) {
swap(cur.a[pos], cur.a[pos - 1]);
state = cur.getState();
if (mp.count(state) == 0) {
cur.evaluate(fin);
pq.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos - 1]);
}
if (pos % 3 != 2) {
swap(cur.a[pos], cur.a[pos + 1]);
state = cur.getState();
if (mp.count(state) == 0) {
cur.evaluate(fin);
pq.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos + 1]);
}
if (pos > 2) {
swap(cur.a[pos], cur.a[pos - 3]);
state = cur.getState();
if (mp.count(state) == 0) {
cur.evaluate(fin);
pq.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos - 3]);
}
if (pos < 6) {
swap(cur.a[pos], cur.a[pos + 3]);
state = cur.getState();
if (mp.count(state) == 0) {
cur.evaluate(fin);
pq.push(cur);
mp[state]++;
}
swap(cur.a[pos], cur.a[pos + 3]);
}
}
return 0;
}
这道题中 A* 算法的耗时也是个位数的。
总结
通过这样一道看似很简单的题目就可以发掘出很多衍生的问题和优化的策略,其实在刷算法题的过程中也有很多其他的题目也像这道题一样,虽然很简洁,但是又包含了很多知识点,这样的问题就很值得钻研。