「HDU-4616」Game (树形DP)

「HDU-4616」Game
树形DP,有限制的权值最大链问题

题意

给一棵有$n$个节点的树,每个节点有其$val$值和$trap$值(1表示该点有陷阱,0为无陷阱),最多可掉进陷阱$c$次。经过树上某一点时将取得该点的$val$值,踩到第$c$个陷阱后马上停止,且不能走已经走过的点。求最大能获得的$val$值之和。

解法

树形dp。

每个节点的状态可用dp[u][k][flag]表示,flag=1时,表示在以$u$为根的子树上,从一个有陷阱的起点经过$k$个陷阱走到$u$的最大值;若flag=0,则表示起点没有陷阱。

考虑最优解共走过$k$个陷阱,如果k!=c,那么起点和终点都可以为没有陷阱的点,如果k=c,那么起点和终点至少有一点为有陷阱点。

状态转移方程:

  1. 当$u$点的起点有陷阱时,在$0<k≤c$上,对$u$的子节点$v$有:

    $$dp[u][k+1][1]=max(dp[u][k+1][1],dp[v][k+1][1]+val[u])$$

  2. 当$u$点起点没有陷阱时,在$0≤k≤c$上,对$u$的子节点$v$有:

    $$dp[u][k][0]=max(dp[u][k][0],dp[v][k][0]+val[u])$$

    对于情况1,由于不存在总陷阱数为0且起点为陷阱的情况,故k!=0.

以上两式可合写为:

$$dp[u][k+trap[u]][flag]=max(dp[u][k+trap[u]][flag],dp[v][k][flag]+val[u])$$

对于最优解$ans$:

最优解可以看成两条链拼在一起,枚举两条链的陷阱个数并求和更新最优解。

更新最优解时需要注意:
  1. 当两条链的陷阱总数j+k=c时,两条链的起点不可能同时为0;
  2. 不存在总陷阱数为0且起点陷阱值为1的情况;

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;

const int maxn=5e4+10;
vector<int> tree[maxn];

int N,c;
int val[maxn],trap[maxn];
int dp[maxn][5][3];
int ans;

void dfs(int u,int fa)
{
dp[u][trap[u]][trap[u]]=val[u];
for(int i=0;i<tree[u].size();i++)
{
int v=tree[u][i];
if(v==fa) continue;
dfs(v,u);
for(int j=0;j<=c;j++)
{
for(int k=0;j+k<=c;k++)
{
ans=max(ans,dp[u][j][1]+dp[v][k][1]);
if(j+k<c) ans=max(ans,dp[u][j][0]+dp[v][k][0]);
if(k) ans=max(ans,dp[u][j][0]+dp[v][k][1]);
if(j) ans=max(ans,dp[u][j][1]+dp[v][k][0]);
}
}
for(int k=0;k+trap[u]<=c;k++)
{
dp[u][k+trap[u]][0]=max(dp[u][k+trap[u]][0],dp[v][k][0]+val[u]);
if(k) dp[u][k+trap[u]][1]=max(dp[u][k+trap[u]][1],dp[v][k][1]+val[u]);
}
}
}

int main()
{
int t,u,v;
scanf("%d",&t);
while(t--)
{
scanf("%d%d",&N,&c);
for(int i=0;i<N;i++)
tree[i].clear();
memset(dp,0,sizeof(dp));
for(int i=0;i<N;i++)
scanf("%d%d",&val[i],&trap[i]);
for(int i=0;i<N-1;i++)
{
scanf("%d%d",&u,&v);
tree[u].push_back(v);
tree[v].push_back(u);
}
ans=0;
dfs(0,-1);
printf("%d\n",ans);
}
}