愤怒的小鸟 (Review / NOIP2016 D2T3)

嘛。去年考场上遇到这题的时候一脸懵逼……现在回过头来看就好多了,但是有一些细节还是要做清楚。 (/ω\)

题目:https://www.luogu.org/problemnew/show/2831

题目就是呢,要构造最少条的抛物线,消灭掉所有的小猪~然后求最少的抛物线数量~观察题目我们可以发现——特殊指令 m 是没有什么卵用的= =

然后接下来我们分析题目:求构造的最少抛物线数量 → 构造抛物线直到把所有的猪都打掉 → 每个猪打一次就没了,而且抛物线的轨迹不会改变 → 对于一条抛物线,如果能打掉尽可能多的猪那么再好不过 → n <= 18 → ……

好的,看到数据范围大家应该有一些想法了。大部分人会想到搜索 + 剪枝,也有写记忆化搜索的……然而这两种都不太好写,但是能解是肯定的,这里就不多说这个了。

n 才 18,还有另一种做法呀~那就是状态压缩 DP!( >ω<) 表示对于能理解状态压缩的玩家们,把这题写成状态压缩比写成搜索题好写多了~而且效率也很高的说。如果不能理解状压的话想写就比较麻烦辣。

那么我们要做什么呢?先设计个 DP 状态和方程呗,显然我们应该dp[i]来表示击败的猪的二进制状态为 i 的时候,需要构造的抛物线的最小数目。二进制状态为 i 是什么概念呢,这里简单地说一下,比如说总共有 3 只猪,这时候你把第一只(标号为 0)和第二只(标号为 1)的猪打掉了,那么 i 就是011咯(打掉第几只猪,状态 i 的二进制下第几位就是 1),同理如果你把三只猪都打死了,那么 i 就是111.

所以,根据状态压缩 DP 的原理,我们先枚举合法的每一个状态,也就是 i = 1(1 << n) - 1。对于每一个状态,随便举出其中包含的一个点 k,那么 dp[i] 最坏的情况应该是:在没打掉第 k 只猪之前的最小值 + 1,我们把这个值作为初值赋给 dp[i]: dp[i] = dp[i & ~(1 << k)] + 1, 或者表示成dp[i] = dp[i - (1 << k)] + 1.

然后我们再枚举出当前状态包含的另一个点 t (t != k), 取dp[i](上文已经更新过了)和dp[j] + 1的最小值即可。因为我们知道,两点构成一条抛物线,所以我们可以找两个,这两个点一定可以勾出一条抛物线,然后从当前状态 i 中去掉这条抛物线经过的所有的小猪,这样就得到了状态 j (╯‵□′)╯︵┴─┴ :dp[i] = min(dp[i], dp[i - (i & fstate[i][j])] + 1), 这里 fstate[i][j] 的意思请往下看。

那么还有一个问题,就是我们怎么知道 k 和 t 构成的那条抛物线经过了哪些猪呢?这时候我们可以枚举任意两个点算出经过这两个点的形如y=ax^2+bx, a < 0的抛物线,再计算这条抛物线对其他点的影响,把它们整理成一个二进制状态;简言之,两点(当然,还有原点 0, 0)可以确定一条抛物线,我们用fstate[i][j]表示能干掉第 i 和第 j 只小猪的这条抛物线,最终能打掉的猪的状态,这样上面的 j 就可以推出来了。fstate数组的计算就是,选取两个点计算 a 和 b,然后枚举其他的点,如果aXk^2 + bXk = Yk成立的话,那么就把这个点 k 合并到fstate[i][j]里去:fstate[i][j] |= (1 << k).

好了。至此这道题已经做得差不多了,还有一个很关键的地方可能就是——抛物线怎么算 ( ̄▽ ̄)~这里主要要 care 一下精度的问题: const double eps = 1e-7;.

然后我们对于两个已知点 A(x1, y1), B(x2, y2),我们可以得到两个等式:

y1 = a * x1^2 + b * x1
y2 = a * x2^2 + b * x2

移项变形一下,我们可以得到:

b = (y1 - a * x1^2) / x1
b = (y2 - a * x2^2) / x2

消去参数 b,然后得到关于 a 的一个等式,整理之后我们可以得到:

a = (x2 * y1 - x1 * y2) / (x1 * x2 * (x1 - x2))

其中 x1, x2, x1-x2 均不为 0. 得到 a 之后把 a 带回上面一个含 b 的式子就可以了(o ° ω ° O ) 。

然后这道题就做完啦。

#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#define db double
using namespace std;
const int MAXN = 18;
const int INF = 1e9 + 7;
const db eps = 1e-7;
struct Point {
    db x, y;
}; 
Point p[MAXN + 5];
int dp[1 << MAXN];
int fstate[MAXN][MAXN];
int n, m, t;
int main()
{
    scanf("%d", &t);
    while (t--)
    {
        scanf("%d%d", &n, &m);
        for (int i = 0; i < n; i++)
        {
            scanf("%lf%lf", &p[i].x, &p[i].y);
        }
        
        memset(fstate, 0, sizeof fstate);
        
        for (int i = 0; i < n; i++)
        {
            for (int j = 0; j < n; j++)
            {
                if (i == j) {
                    continue;
                }
                
                db x1 = p[i].x, x2 = p[j].x, y1 = p[i].y, y2 = p[j].y;
                if (x1 < eps || x2 < eps || abs(x1 - x2) < eps) {
                    continue;
                }
                db a = (x2 * y1 - x1 * y2) / (x1 * x2 * (x1 - x2));
                if (a > -eps) {
                    continue;
                }

                db b = y1 / x1 - a * x1;
                int final = (1 << i) | (1 << j);
                for (int k = 0; k < n; k++)
                {
                    db xk = p[k].x, yk = p[k].y;
                    if (abs(a * xk * xk + b * xk - yk) < eps) {
                        final |= (1 << k);
                    }
                }
                fstate[i][j] = final;
            }
        }
    
        fill(dp, dp + (1 << n), INF);
        dp[0] = 0;
        for (int i = 1; i < (1 << n); i++)
        {
            int cur = 0;
            for (int j = 0; j < n; j++)
            {
                if (i & (1 << j)) {
                    cur = j;
                    break;
                }
            }
            
            // dp[i] = dp[i - (1 << cur)] + 1;
            dp[i] = dp[i & ~(1 << cur)] + 1;

            for (int j = 0; j < n; j++)
            {
               if (i & (1 << j) && j != cur) {
                    dp[i] = min(dp[i], dp[i & ~(i & fstate[cur][j])] + 1);
                    // dp[i] = min(dp[i], dp[i - (i & fstate[cur][j])] + 1);
               }
            }
        }
        
        printf("%d\n", dp[(1 << n) - 1]);
    }
    return 0;
}