1010 Weak Pair
题解:树状数组/线段树 +离散化+dfs 对于每个a[i],将k/a[i]也放进去离散,这样的话对于每个数就能知道k/a[i]之前的有多少个数,从根节点开始dfs,然后找向它的子节点,看能不能匹配,而会造成影响的只有兄弟节点,那么在每次访问完一个节点的子节点之后删除这个点的贡献
线段树版本:
1 #include <iostream> 2 #include <stdio.h> 3 #include <string.h> 4 #include <algorithm> 5 #include <stdlib.h> 6 #include <vector> 7 #include <queue> 8 #include <stack> 9 #include <string> 10 using namespace std; 11 const int maxn=3*1e5+10; 12 #define MS(a,b) memset(a,b,sizeof(a)) 13 long long a[maxn]; 14 long long b[maxn]; 15 long long k; 16 int n,m; 17 vector<int> G[maxn]; 18 int in[maxn]; 19 long long ans; 20 struct node 21 { 22 int l,r; 23 int mid(){return (l+r)>>1;} 24 long long sum; 25 }; 26 node tree[maxn<<2]; 27 void pushup(int rt) 28 { 29 tree[rt].sum=tree[rt<<1].sum+tree[rt<<1|1].sum; 30 } 31 void build(int l,int r,int rt) 32 { 33 tree[rt].l=l,tree[rt].r=r; 34 tree[rt].sum=0; 35 if(l==r) return ; 36 int mid=tree[rt].mid(); 37 build(l,mid,rt<<1); 38 build(mid+1,r,rt<<1|1); 39 } 40 void update(int index,int x,int L,int R,int rt) 41 { 42 if(L==index&&R==index) 43 { 44 tree[rt].sum+=x; 45 return ; 46 } 47 int mid=tree[rt].mid(); 48 if(index<=mid) update(index,x,L,mid,rt<<1); 49 else if(index>mid) update(index,x,mid+1,R,rt<<1|1); 50 pushup(rt); 51 } 52 long long query(int l,int r,int L,int R,int rt) 53 { 54 if(L>=l&&R<=r) return tree[rt].sum; 55 int mid=tree[rt].mid(); 56 long long ans=0; 57 if(l<=mid) ans+=query(l,r,L,mid,rt<<1); 58 if(r>mid) ans+=query(l,r,mid+1,R,rt<<1|1); 59 return ans; 60 } 61 void dfs(int x) 62 { 63 int l=lower_bound(b+1,b+1+m,k/a[x])-b; 64 int pos=lower_bound(b+1,b+1+m,a[x])-b; 65 ans+=query(1,l,1,m,1); 66 update(pos,1,1,m,1); 67 for(int i=0;i<G[x].size();i++) dfs(G[x][i]); 68 update(pos,-1,1,m,1); 69 } 70 void init() 71 { 72 for(int i=0;i<=n+1;i++) G[i].clear(); 73 build(1,m,1); 74 MS(in,0); 75 ans=0; 76 } 77 int main() 78 { 79 int T; 80 scanf("%d",&T); 81 while(T--) 82 { 83 //init(); 84 scanf("%d %lld",&n,&k); 85 int cnt=1; 86 for(int i=1;i<=n;i++)scanf("%lld",&a[i]); 87 for(int i=1;i<=n;i++) 88 { 89 b[cnt]=a[i]; 90 cnt++; 91 b[cnt]=k/a[i]; 92 cnt++; 93 } 94 m=cnt-1; 95 sort(b+1,b+1+m); 96 init(); 97 for(int i=1;i<=n-1;i++) 98 { 99 int x,y; 100 scanf("%d %d",&x,&y); 101 G[x].push_back(y); 102 in[y]++; 103 } 104 int root; 105 for(int i=1;i<=n;i++) 106 { 107 if(in[i]==0) 108 { 109 root=i; 110 break; 111 } 112 } 113 dfs(root); 114 printf("%lld\n",ans); 115 116 } 117 return 0; 118 }