划分区间|RMQ线段树/树状数组优化DP
2024-12-03 23:05:49
发布于:美国
T6 - 划分区间
题目链接跳转:点击跳转
一道线段树优化动态规划的题目,难度趋近于 CSP 提高组的题目和 USACO 铂金组的中等题。一眼可以看出题目是一个典型的动态规划问题,但奈何数据量太大了, 的复杂度肯定会 TLE。但无论如何都是 “车到山前必有路”,看到数据范围不用怕,先打一个暴力的动态规划再优化。
按照一位 OI 大神的说法:“所有的动态规划优化都是在基础的代码上等量代换”。
与打家劫舍等线性动态规划类似,对于本题而言,设状态的定义为 表示对 这个序列划分后可得到的最大贡献。通过暴力遍历 ,表示将 归位一组。另设 为区间 的贡献值。根据以上信息可以得到状态转移方程:
接下来就是关于 的计算了。设前缀和数组 表示从区间 的和,那么 区间的和可以被表示为 。根据不同的 ,则有以下三种情况:
- 当 时,证明该区间的和是正数,贡献为 。
- 当 时,该区间的和为零,贡献为 。
- 当 时,证明该区间的和是负数,贡献为 。
综上所述,可以写出一个暴力版本的动态规划代码:
#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
using namespace std;
int main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
int n;
cin >> n;
vector<int> A(n + 1);
vector<long long> S(n + 1, 0);
for (int i = 1; i <= n; i++) {
cin >> A[i];
S[i] = S[i - 1] + A[i];
}
vector<long long> dp(n + 1, LLONG_MIN);
dp[0] = 0;
for (int i = 1; i <= n; i++) {
for (int j = 0; j < i; j++) {
if (S[i] - S[j] > 0)
dp[i] = max(dp[i], dp[j] + (i - j));
if (S[i] - S[j] < 0)
dp[i] = max(dp[i], dp[j] - (i - j));
if (S[i] - S[j] == 0)
dp[i] = max(dp[i], dp[j]);
}
}
cout << dp[n] << endl;
return 0;
}
接下来考虑优化这个动态规划,注意到每一次寻找 都非常耗时,每一次都需要遍历一遍才能求出最大值。有没有一种方法可以快速求出某一个区间的最大值呢?答案就是线段树。线段树是一个非常好的快速求解区间最值问题的数据结构。
更多有关区间最值问题的学习请参考:[# 浅入线段树与区间最值问题](# 浅入线段树与区间最值问题)
综上,我们可以通过构建线段树来快速求得答案。简化三种情况可得:
if (S[i] - S[j] > 0)
dp[i] = max(dp[i], dp[j] - j + i);
if (S[i] - S[j] < 0)
dp[i] = max(dp[i], dp[j] + j - i));
if (S[i] - S[j] == 0)
dp[i] = max(dp[i], dp[j]);
因此我们构造三棵线段树,分别来维护这三个区间:
然而我们的线段树不能仅仅维护这个区间,因为这三个的最大值还被 的三种状态所限制着,因此,我们需要找的是满足 在特定条件下的最大值。这样就出现了另一个严重的问题, 的值可能非常的大,因此我们需要对前缀和数组离散化一下(坐标压缩:类似于权值线段树的写法)才可以防止内存超限。
这样子对于每次寻找最大值,都可以在 的情况下找到。本算法的总时间复杂度也控制在了 级别。
本题的 C++ 代码如下:
#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
#define int long long
using namespace std;
constexpr int INF = 0x7f7f7f7f7f7f7f7f;
const int MAX = 500005;
int n, A[MAX], sum[MAX];
int discretized[MAX];
struct SegmentTree {
int size; vector<int> tree;
SegmentTree(int n_) {
size = 1;
while (size < n_) size <<= 1;
tree.assign(2 * size, -INF);
}
void update(int pos, int value) {
pos += size - 1;
tree[pos] = max(tree[pos], value);
while (pos > 1) {
pos >>= 1;
tree[pos] = max(tree[2 * pos], tree[2 * pos + 1]);
}
}
int query(int l, int r) {
l += size - 1; r += size - 1;
int res = -INF;
while (l <= r) {
if (l % 2 == 1)
res = max(res, tree[l++]);
if (r % 2 == 0)
res = max(res, tree[r--]);
l >>= 1; r >>= 1;
}
return res;
}
};
int get_id(int x, int* arr, int m) {
return lower_bound(arr, arr + m, x) - arr + 1;
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> A[i];
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i - 1] + A[i];
for (int i = 0; i <= n; i++)
discretized[i] = sum[i];
sort(discretized, discretized + n + 1);
int m = unique(discretized, discretized + n + 1)
- discretized;
SegmentTree Tree1(m); // max(dp[j] - j)
SegmentTree Tree2(m); // max(dp[j])
SegmentTree Tree3(m); // max(dp[j] + j)
int idx_sum0 = get_id(sum[0], discretized, m);
Tree1.update(idx_sum0, 0);
Tree2.update(idx_sum0, 0);
Tree3.update(idx_sum0, 0);
int current = -INF;
for (int i = 1; i <= n; i++) {
int index = get_id(sum[i], discretized, m);
int p1 = -INF;
if (index > 1) {
int temp = Tree1.query(1, index - 1);
if (temp != -INF) {
p1 = temp + i;
}
}
int p2 = Tree2.query(index, index);
int p3 = -INF;
if (index < m) {
int temp = Tree3.query(index + 1, m);
if (temp != -INF) {
p3 = temp - i;
}
}
current = max(p1, max(p2, p3));
Tree1.update(index, current - i);
Tree2.update(index, current);
Tree3.update(index, current + i);
}
cout << current << endl;
}
本题的 Python 代码如下(由于 Python 常数过大,因此没有办法通过这道题所有的测试点,但是代码的正确性没有问题):
class SegmentTree:
def __init__(self, n):
self.size = 1
while self.size < n:
self.size *= 2
self.tree = [float('-inf')] * (2 * self.size)
def update(self, pos, value):
pos += self.size - 1
self.tree[pos] = max(self.tree[pos], value)
while pos > 1:
pos //= 2
self.tree[pos] = max(self.tree[2 * pos], self.tree[2 * pos + 1])
def query(self, l, r):
l += self.size - 1
r += self.size - 1
res = float('-inf')
while l <= r:
if l % 2 == 1:
res = max(res, self.tree[l])
l += 1
if r % 2 == 0:
res = max(res, self.tree[r])
r -= 1
l //= 2
r //= 2
return res
def main():
import sys
input = sys.stdin.read
data = input().split()
n = int(data[0])
A = list(map(int, data[1:n + 1]))
S = [0] * (n + 1)
for i in range(1, n + 1):
S[i] = S[i - 1] + A[i - 1]
aintS_arr = S[:]
aintS_arr.sort()
m = len(set(aintS_arr))
aintS_arr = sorted(set(aintS_arr))
def get_idx(x):
# Return the index in the compressed array
return aintS_arr.index(x) + 1
BIT1 = SegmentTree(m) # max(dp[j] - j)
BIT2 = SegmentTree(m) # max(dp[j])
BIT3 = SegmentTree(m) # max(dp[j] + j)
idx_S0 = get_idx(S[0])
BIT1.update(idx_S0, 0)
BIT2.update(idx_S0, 0)
BIT3.update(idx_S0, 0)
dp_i = float('-inf')
for i in range(1, n + 1):
Si = S[i]
idx_Si = get_idx(Si)
option1 = float('-inf')
if idx_Si > 1:
temp = BIT1.query(1, idx_Si - 1)
if temp != float('-inf'):
option1 = temp + i
option2 = BIT2.query(idx_Si, idx_Si)
option3 = float('-inf')
if idx_Si < m:
temp = BIT3.query(idx_Si + 1, m)
if temp != float('-inf'):
option3 = temp - i
dp_i = max(option1, option2, option3)
BIT1.update(idx_Si, dp_i - i)
BIT2.update(idx_Si, dp_i)
BIT3.update(idx_Si, dp_i + i)
print(dp_i)
if __name__ == "__main__":
main()
当然也可以用树状数组来写,速度可能会更快一点:
#include <iostream>
#include <vector>
#include <algorithm>
#include <climits>
#define int long long
using namespace std;
constexpr int INF = 0x7f7f7f7f7f7f7f7f;
const int MAX = 500005;
int n, A[MAX], sum[MAX];
int discretized[MAX];
// 获取离散化后的索引
int get_id(int x, int* arr, int m) {
return lower_bound(arr, arr + m, x) - arr + 1;
}
// 树状数组(BIT)实现前缀最大值
struct BIT {
int size;
vector<int> tree;
BIT(int n_) : size(n_), tree(n_ + 1, -INF) {}
void update_bit(int idx, int val) {
while(idx <= size){
if(val > tree[idx]) tree[idx] = val;
else break;
idx += idx & (-idx);
}
}
int query_bit(int idx){
int res = -INF;
while(idx > 0){
res = max(res, tree[idx]);
idx -= idx & (-idx);
}
return res;
}
};
signed main() {
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> A[i];
sum[0] = 0;
for (int i = 1; i <= n; i++)
sum[i] = sum[i - 1] + A[i];
for (int i = 0; i <= n; i++)
discretized[i] = sum[i];
sort(discretized, discretized + n + 1);
int m = unique(discretized, discretized + n + 1)
- discretized;
// 初始化三颗树状数组
BIT Tree1(m); // max(dp[j] - j)
BIT Tree3(m); // max(dp[j] + j)
// Tree2 作为单独的数组存储每个位置的最大值
vector<int> Tree2(m + 1, -INF); // max(dp[j])
int idx_sum0 = get_id(sum[0], discretized, m);
Tree1.update_bit(idx_sum0, 0);
Tree3.update_bit(m - idx_sum0 +1, 0);
Tree2[idx_sum0] = max(Tree2[idx_sum0], (int)0);
int current = -INF;
for (int i = 1; i <= n; i++) {
int Si = sum[i];
int idx_Si = get_id(Si, discretized, m);
int p1 = -INF;
if (idx_Si > 1) {
int temp = Tree1.query_bit(idx_Si -1);
if (temp != -INF) {
p1 = temp + i;
}
}
int p2 = Tree2[idx_Si];
int p3 = -INF;
if (idx_Si < m) {
int reversed_idx = m - (idx_Si +1) +1;
int temp = Tree3.query_bit(reversed_idx);
if (temp != -INF) {
p3 = temp - i;
}
}
current = max(p1, max(p2, p3));
// 更新 Tree1
Tree1.update_bit(idx_Si, current - i);
// 更新 Tree2
Tree2[idx_Si] = max(Tree2[idx_Si], current);
// 更新 Tree3
int reversed_update_idx = m - idx_Si +1;
Tree3.update_bit(reversed_update_idx, current + i);
}
cout << current << endl;
}
更新日志:
Dec. 3, 2024:优化了代码的变量名,T6 增加了树状数组解法。
全部评论 3
rmq可以用笛卡尔树解决吗
2024-12-04 来自 江苏
1笛卡尔树比较慢,当然用笛卡尔树也是可以的。经过测试,平均需要 800ms 通过所有的测试点。但使用树状数组/线段树的 RMQ 可以做到平均在 150ms 内通过所有的测试点。
2024-12-04 来自 加拿大
1可以给你提供一个笛卡尔树的代码:
struct TreapNode { int key; // sum[j] int priority; TreapNode* left; TreapNode* right; // 当前节点的值 int dp_j_minus_j; int dp_j; int dp_j_plus_j; // 子树中的最大值 int subtree_max_dp_j_minus_j; int subtree_max_dp_j; int subtree_max_dp_j_plus_j; TreapNode(int k, int dp_val, int index) : key(k), priority(rand()), left(nullptr), right(nullptr) { dp_j_minus_j = dp_val - index; dp_j = dp_val; dp_j_plus_j = dp_val + index; subtree_max_dp_j_minus_j = dp_j_minus_j; subtree_max_dp_j = dp_j; subtree_max_dp_j_plus_j = dp_j_plus_j; } // 更新子树的最大值 void update() { subtree_max_dp_j_minus_j = dp_j_minus_j; subtree_max_dp_j = dp_j; subtree_max_dp_j_plus_j = dp_j_plus_j; if (left) { subtree_max_dp_j_minus_j = max(subtree_max_dp_j_minus_j, left->subtree_max_dp_j_minus_j); subtree_max_dp_j = max(subtree_max_dp_j, left->subtree_max_dp_j); subtree_max_dp_j_plus_j = max(subtree_max_dp_j_plus_j, left->subtree_max_dp_j_plus_j); } if (right) { subtree_max_dp_j_minus_j = max(subtree_max_dp_j_minus_j, right->subtree_max_dp_j_minus_j); subtree_max_dp_j = max(subtree_max_dp_j, right->subtree_max_dp_j); subtree_max_dp_j_plus_j = max(subtree_max_dp_j_plus_j, right->subtree_max_dp_j_plus_j); } } };
2024-12-04 来自 加拿大
1struct Treap { TreapNode* root; Treap() : root(nullptr) {} // 分割 treap,使得左子树所有 key < key_val,右子树所有 key >= key_val pair<TreapNode*, TreapNode*> split(TreapNode* node, int key_val) { if (!node) return {nullptr, nullptr}; if (key_val > node->key) { auto split_res = split(node->right, key_val); node->right = split_res.first; node->update(); return {node, split_res.second}; } else { auto split_res = split(node->left, key_val); node->left = split_res.second; node->update(); return {split_res.first, node}; } } // 合并两个 treap,所有 keys in left <= keys in right TreapNode* merge(TreapNode* left, TreapNode* right) { if (!left || !right) return left ? left : right; if (left->priority > right->priority) { left->right = merge(left->right, right); left->update(); return left; } else { right->left = merge(left, right->left); right->update(); return right; } } // 插入一个新的节点 void insert(int key, int dp_val, int index) { TreapNode* new_node = new TreapNode(key, dp_val, index); if (!root) { root = new_node; return; } pair<TreapNode*, TreapNode*> split_res = split(root, key); // 处理相同 key 的情况,插入多个相同 key 节点 root = merge(merge(split_res.first, new_node), split_res.second); }
2024-12-04 来自 加拿大
1
6啊
2024-12-03 来自 广东
0比较恶心的 RMQ 问题
2024-12-03 来自 美国
0没学过🤣
2024-12-04 来自 广东
0
顶
2024-12-03 来自 加拿大
0
有帮助,赞一个