嘛。去年考场上遇到这题的时候一脸懵逼……现在回过头来看就好多了,但是有一些细节还是要做清楚。 (/ω\)
题目:https://www.luogu.org/problemnew/show/2831
题目就是呢,要构造最少条的抛物线,消灭掉所有的小猪~然后求最少的抛物线数量~观察题目我们可以发现——特殊指令 m 是没有什么卵用的= =
然后接下来我们分析题目:求构造的最少抛物线数量 → 构造抛物线直到把所有的猪都打掉 → 每个猪打一次就没了,而且抛物线的轨迹不会改变 → 对于一条抛物线,如果能打掉尽可能多的猪那么再好不过 → n <= 18 → ……
好的,看到数据范围大家应该有一些想法了。大部分人会想到搜索 + 剪枝,也有写记忆化搜索的……然而这两种都不太好写,但是能解是肯定的,这里就不多说这个了。
n 才 18,还有另一种做法呀~那就是状态压缩 DP!( >ω<) 表示对于能理解状态压缩的玩家们,把这题写成状态压缩比写成搜索题好写多了~而且效率也很高的说。如果不能理解状压的话想写就比较麻烦辣。
那么我们要做什么呢?先设计个 DP 状态和方程呗,显然我们应该用dp[i]
来表示击败的猪的二进制状态为 i 的时候,需要构造的抛物线的最小数目。二进制状态为 i 是什么概念呢,这里简单地说一下,比如说总共有 3 只猪,这时候你把第一只(标号为 0)和第二只(标号为 1)的猪打掉了,那么 i 就是011
咯(打掉第几只猪,状态 i 的二进制下第几位就是 1),同理如果你把三只猪都打死了,那么 i 就是111
.
所以,根据状态压缩 DP 的原理,我们先枚举合法的每一个状态,也就是 i = 1
到 (1 << n) - 1
。对于每一个状态,随便举出其中包含的一个点 k,那么 dp[i]
最坏的情况应该是:在没打掉第 k 只猪之前的最小值 + 1,我们把这个值作为初值赋给 dp[i]
: dp[i] = dp[i & ~(1 << k)] + 1
, 或者表示成dp[i] = dp[i - (1 << k)] + 1
.
然后我们再枚举出当前状态包含的另一个点 t (t != k), 取dp[i]
(上文已经更新过了)和dp[j] + 1
的最小值即可。因为我们知道,两点构成一条抛物线,所以我们可以找两个,这两个点一定可以勾出一条抛物线,然后从当前状态 i 中去掉这条抛物线经过的所有的小猪,这样就得到了状态 j (╯‵□′)╯︵┴─┴ :dp[i] = min(dp[i], dp[i - (i & fstate[i][j])] + 1)
, 这里 fstate[i][j]
的意思请往下看。
那么还有一个问题,就是我们怎么知道 k 和 t 构成的那条抛物线经过了哪些猪呢?这时候我们可以枚举任意两个点算出经过这两个点的形如y=ax^2+bx, a < 0
的抛物线,再计算这条抛物线对其他点的影响,把它们整理成一个二进制状态;简言之,两点(当然,还有原点 0, 0)可以确定一条抛物线,我们用fstate[i][j]
表示能干掉第 i 和第 j 只小猪的这条抛物线,最终能打掉的猪的状态,这样上面的 j 就可以推出来了。fstate
数组的计算就是,选取两个点计算 a 和 b,然后枚举其他的点,如果aXk^2 + bXk = Yk
成立的话,那么就把这个点 k 合并到fstate[i][j]
里去:fstate[i][j] |= (1 << k)
.
好了。至此这道题已经做得差不多了,还有一个很关键的地方可能就是——抛物线怎么算 ( ̄▽ ̄)~这里主要要 care 一下精度的问题: const double eps = 1e-7;
.
然后我们对于两个已知点 A(x1, y1), B(x2, y2),我们可以得到两个等式:
y1 = a * x1^2 + b * x1
y2 = a * x2^2 + b * x2
移项变形一下,我们可以得到:
b = (y1 - a * x1^2) / x1
b = (y2 - a * x2^2) / x2
消去参数 b,然后得到关于 a 的一个等式,整理之后我们可以得到:
a = (x2 * y1 - x1 * y2) / (x1 * x2 * (x1 - x2))
其中 x1, x2, x1-x2 均不为 0. 得到 a 之后把 a 带回上面一个含 b 的式子就可以了(o ° ω ° O ) 。
然后这道题就做完啦。
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define db double
using namespace std;
const int MAXN = 18;
const int INF = 1e9 + 7;
const db eps = 1e-7;
struct Point {
db x, y;
};
Point p[MAXN + 5];
int dp[1 << MAXN];
int fstate[MAXN][MAXN];
int n, m, t;
int main()
{
scanf("%d", &t);
while (t--)
{
scanf("%d%d", &n, &m);
for (int i = 0; i < n; i++)
{
scanf("%lf%lf", &p[i].x, &p[i].y);
}
memset(fstate, 0, sizeof fstate);
for (int i = 0; i < n; i++)
{
for (int j = 0; j < n; j++)
{
if (i == j) {
continue;
}
db x1 = p[i].x, x2 = p[j].x, y1 = p[i].y, y2 = p[j].y;
if (x1 < eps || x2 < eps || abs(x1 - x2) < eps) {
continue;
}
db a = (x2 * y1 - x1 * y2) / (x1 * x2 * (x1 - x2));
if (a > -eps) {
continue;
}
db b = y1 / x1 - a * x1;
int final = (1 << i) | (1 << j);
for (int k = 0; k < n; k++)
{
db xk = p[k].x, yk = p[k].y;
if (abs(a * xk * xk + b * xk - yk) < eps) {
final |= (1 << k);
}
}
fstate[i][j] = final;
}
}
fill(dp, dp + (1 << n), INF);
dp[0] = 0;
for (int i = 1; i < (1 << n); i++)
{
int cur = 0;
for (int j = 0; j < n; j++)
{
if (i & (1 << j)) {
cur = j;
break;
}
}
// dp[i] = dp[i - (1 << cur)] + 1;
dp[i] = dp[i & ~(1 << cur)] + 1;
for (int j = 0; j < n; j++)
{
if (i & (1 << j) && j != cur) {
dp[i] = min(dp[i], dp[i & ~(i & fstate[cur][j])] + 1);
// dp[i] = min(dp[i], dp[i - (i & fstate[cur][j])] + 1);
}
}
}
printf("%d\n", dp[(1 << n) - 1]);
}
return 0;
}