题解
2025-10-03 12:05:21
发布于:广东
5阅读
0回复
0点赞
#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
// Node of a segment tree that supports "range floor-sqrt" updates and range sums.
struct Node {
    // Sum of the elements covered by this node.
    long long sum = 0;
    // True when every covered element is <= 1. Such elements are fixed points
    // of floor(sqrt(x)) (0 -> 0, 1 -> 1), so range updates may skip the whole
    // subtree. The name is kept for compatibility; it does NOT mean "every
    // element equals 1" -- zeros count too, otherwise segments containing 0
    // would never be pruned.
    bool all_one = false;
};

// Segment tree over a sequence a[0..n-1] supporting
//   update(l, r): a[i] = floor(sqrt(a[i])) for every l <= i <= r;
//   query(l, r):  sum of a[l..r].
// Indices are 0-based and inclusive. Indices outside [0, n-1] are ignored and
// an empty range (l > r) is a no-op. Values <= 1 (including any negative
// values) are treated as fixed points and left unchanged by update().
//
// Every value reaches 0 or 1 after a small number of square roots (roughly
// log2(log2(V)) for V the largest value), and subtrees made only of such values
// are skipped, so the total update work is amortized over the whole input.
class SegmentTree {
private:
    std::vector<Node> tree;  // implicit binary tree; children of i are 2i+1 and 2i+2
    int n;                   // number of elements

    // Exact floor(sqrt(v)) for v >= 2; values below 2 are returned unchanged.
    // The double estimate can be off for large v (a long long does not fit in
    // a double's 53-bit mantissa), so it is corrected with integer arithmetic.
    static long long isqrt(long long v) {
        if (v < 2) {
            return v;
        }
        long long root = static_cast<long long>(std::sqrt(static_cast<double>(v)));
        // root > v / root  <=>  root * root > v, written without overflow.
        while (root > v / root) {
            --root;
        }
        // root + 1 <= v / (root + 1)  <=>  (root + 1)^2 <= v, without overflow.
        while (root + 1 <= v / (root + 1)) {
            ++root;
        }
        return root;
    }

    // Recompute an internal node's aggregates from its two children.
    void pull(int node) {
        const Node& left = tree[2 * node + 1];
        const Node& right = tree[2 * node + 2];
        tree[node].sum = left.sum + right.sum;
        tree[node].all_one = left.all_one && right.all_one;
    }

    // Fill the subtree rooted at `node`, which covers values[start..end].
    void build(const std::vector<long long>& values, int node, int start, int end) {
        if (start == end) {
            tree[node].sum = values[start];
            tree[node].all_one = (values[start] <= 1);
            return;
        }
        const int mid = start + (end - start) / 2;
        build(values, 2 * node + 1, start, mid);
        build(values, 2 * node + 2, mid + 1, end);
        pull(node);
    }

    // Apply floor(sqrt) to every element of [l, r] inside [start, end].
    void update_range(int node, int start, int end, int l, int r) {
        // Disjoint from the target range, or every element is already a fixed point.
        if (r < start || end < l || tree[node].all_one) {
            return;
        }
        if (start == end) {
            tree[node].sum = isqrt(tree[node].sum);
            tree[node].all_one = (tree[node].sum <= 1);
            return;
        }
        const int mid = start + (end - start) / 2;
        update_range(2 * node + 1, start, mid, l, r);
        update_range(2 * node + 2, mid + 1, end, l, r);
        pull(node);
    }

    // Sum of the elements of [l, r] inside [start, end].
    long long query_range(int node, int start, int end, int l, int r) const {
        if (r < start || end < l) {
            return 0;
        }
        if (l <= start && end <= r) {
            return tree[node].sum;
        }
        const int mid = start + (end - start) / 2;
        return query_range(2 * node + 1, start, mid, l, r) +
               query_range(2 * node + 2, mid + 1, end, l, r);
    }

public:
    // Builds the tree from nums (nums itself is not retained). An empty
    // vector is allowed; every operation on it is then a no-op returning 0.
    SegmentTree(const std::vector<long long>& nums)
        : tree(4 * nums.size()), n(static_cast<int>(nums.size())) {
        if (n > 0) {
            build(nums, 0, 0, n - 1);
        }
    }

    // a[i] = floor(sqrt(a[i])) for 0-based inclusive l <= i <= r.
    void update(int l, int r) {
        if (n == 0) {
            return;
        }
        update_range(0, 0, n - 1, l, r);
    }

    // Sum of a[l..r], 0-based inclusive; 0 for an empty or out-of-range span.
    long long query(int l, int r) const {
        if (n == 0) {
            return 0;
        }
        return query_range(0, 0, n - 1, l, r);
    }
};
// Reads n, then n initial values, then m operations "k l r" (1-based,
// inclusive; l and r may be given in either order):
//   k == 0 -> replace every a[i] with l <= i <= r by floor(sqrt(a[i]));
//   k != 0 -> print the sum of a[l..r] on its own line.
// Truncated or malformed input ends processing instead of acting on
// uninitialized values.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    if (!(std::cin >> n) || n < 0) {
        return 0;
    }
    std::vector<long long> nums(n);
    for (auto& value : nums) {
        if (!(std::cin >> value)) {
            return 0;
        }
    }
    SegmentTree st(nums);

    int m = 0;
    if (!(std::cin >> m)) {
        return 0;
    }
    while (m-- > 0) {  // a negative m is treated as zero operations
        int k = 0;
        int l = 0;
        int r = 0;
        if (!(std::cin >> k >> l >> r)) {
            break;
        }
        if (l > r) {
            std::swap(l, r);
        }
        --l;  // convert to 0-based indices
        --r;
        if (k == 0) {
            st.update(l, r);
        } else {
            std::cout << st.query(l, r) << '\n';
        }
    }
    return 0;
}
这里空空如也







有帮助,赞一个