Rust实现线段树和懒标记

打印 上一主题 下一主题

主题 863|帖子 863|积分 2589

参考各家代码,用Rust实现了线段树和懒标记。
由于使用了泛型,很多操作都要用闭包自定义实现。
看代码。
  1. // 线段树定义
  2. pub struct SegmentTree<T: Clone>
  3. {
  4.     pub data: Vec<T>,
  5.     tree: Vec<Option<T>>,
  6.     marker: Vec<T>,                               //懒标记。
  7.     query_op: Box<dyn Fn(T, T) -> T>, //查询时,对所有查询元素做的操作。比如加法,就是求区间的所有元素的和。
  8.     marker_marker_op: Box<dyn Fn(T, T) -> T>, //marker加到marker上时,对marker的操作。通常我们要marker[i] += p; 来更新标记,但是泛型实现不了,并且考虑到有些用户有别的需求,所以用闭包包装。
  9.     marker_t_op: Box<dyn Fn(T, T) -> T>, //marker应用到T时,对T的操作。考虑到有些用户有别的需求,所以用闭包包装。
  10.     marker_mul_usize: Box<dyn Fn(T, usize) -> T>, //marker乘usize的方法。这个没法通过要求满足Mul trait自动实现。由于使用了泛型,连乘法都要交给闭包实现。。。
  11. }
  12. impl<T: Clone + Default + Copy + PartialEq> SegmentTree<T> {
  13.     pub fn new(
  14.         data: Vec<T>,
  15.         query_op: Box<dyn Fn(T, T) -> T>,
  16.         marker_marker_op: Box<dyn Fn(T, T) -> T>,
  17.         marker_t_op: Box<dyn Fn(T, T) -> T>,
  18.         marker_mul_usize: Box<dyn Fn(T, usize) -> T>,
  19.     ) -> Self {
  20.         let data_len = data.len();
  21.         let mut tr = Self {
  22.             data,
  23.             marker: vec![T::default(); 4 * data_len], //四倍原数据大小
  24.             tree: vec![None; 4 * data_len],           //四倍原数据大小
  25.             query_op,
  26.             marker_marker_op,
  27.             marker_t_op,
  28.             marker_mul_usize,
  29.         };
  30.         tr.build();
  31.         tr
  32.     }
  33.     #[inline]
  34.     pub fn get(&self, index: usize) -> Option<&T> {
  35.         self.data.get(index)
  36.     }
  37.     #[inline]
  38.     pub fn len(&self) -> usize {
  39.         self.data.len()
  40.     }
  41.     #[inline]
  42.     fn left_child(index: usize) -> usize {
  43.         2 * index + 1
  44.     }
  45.     #[inline]
  46.     fn right_child(index: usize) -> usize {
  47.         2 * index + 2
  48.     }
  49.     #[inline]
  50.     fn build(&mut self) {
  51.         self.build_segment_tree(0, 0, self.data.len() - 1);
  52.     }
  53.     // 递归Build
  54.     fn build_segment_tree(&mut self, tree_index: usize, left: usize, right: usize) {
  55.         if left == right {
  56.             self.tree[tree_index] = Some(self.data[left]);
  57.             return;
  58.         }
  59.         let left_tree_index = Self::left_child(tree_index);
  60.         let right_tree_index = Self::right_child(tree_index);
  61.         let mid = (right - left) / 2 + left;
  62.         self.build_segment_tree(left_tree_index, left, mid);
  63.         self.build_segment_tree(right_tree_index, mid + 1, right);
  64.         // 左右子树数据处理方式
  65.         if let Some(l) = self.tree[left_tree_index] {
  66.             if let Some(r) = self.tree[right_tree_index] {
  67.                 self.tree[tree_index] = Some((self.query_op)(l, r))
  68.             }
  69.         }
  70.     }
  71.     // 返回对线段树的全部元素做query_op操作的结果
  72.     #[inline]
  73.     pub fn query_all(&mut self) -> T {
  74.         self.recursion_query(0, self.data.len() - 1, 0, 0, self.data.len() - 1)
  75.     }
  76.     // 返回对线段树的[l..r]范围全部元素做query_op操作的结果
  77.     pub fn query(&mut self, l: usize, r: usize) -> Result<T, &'static str> {
  78.         if l > self.data.len() || r > self.data.len() || l > r {
  79.             return Err("索引错误");
  80.         }
  81.         if l == r {
  82.             return Ok(self.data[l]);
  83.         }
  84.         Ok(self.recursion_query(l, r, 0, 0, self.data.len() - 1))
  85.     }
  86.     // 在index表示的[current_left,current_right]范围中查询[l..r]值
  87.     fn recursion_query(
  88.         &mut self,
  89.         l: usize,
  90.         r: usize,
  91.         index: usize,
  92.         current_left: usize,
  93.         current_right: usize,
  94.     ) -> T {
  95.         if l > current_right || r < current_left {
  96.             return T::default();
  97.         }
  98.         if l == current_left && r == current_right {
  99.             if let Some(d) = self.tree[index] {
  100.                 if l == r {
  101.                     self.data[l] = d;
  102.                 }
  103.                 return d;
  104.             }
  105.             return T::default();
  106.         }
  107.         self.push_down(index, current_right - current_left + 1);
  108.         let mid = current_left + (current_right - current_left) / 2;
  109.         if l >= mid + 1 {
  110.             return self.recursion_query(l, r, Self::right_child(index), mid + 1, current_right);
  111.         } else if r <= mid {
  112.             return self.recursion_query(l, r, Self::left_child(index), current_left, mid);
  113.         }
  114.         let l_res = self.recursion_query(l, mid, Self::left_child(index), current_left, mid);
  115.         let r_res =
  116.             self.recursion_query(mid + 1, r, Self::right_child(index), mid + 1, current_right);
  117.         (self.query_op)(l_res, r_res)
  118.     }
  119.     // 更新index为val
  120.     pub fn set(&mut self, index: usize, val: T) -> Result<(), &'static str> {
  121.         if index >= self.data.len() {
  122.             return Err("索引超过线段树长度");
  123.         }
  124.         // 更新数据
  125.         self.data[index] = val;
  126.         // 递归更新树
  127.         self.recursion_set(0, 0, self.data.len() - 1, index, val);
  128.         Ok(())
  129.     }
  130.     // 递归更新树
  131.     fn recursion_set(&mut self, index_tree: usize, l: usize, r: usize, index: usize, val: T) {
  132.         if l == r {
  133.             self.tree[index_tree] = Some(val);
  134.             return;
  135.         }
  136.         let mid = l + (r - l) / 2;
  137.         let left_child = Self::left_child(index_tree);
  138.         let right_child = Self::right_child(index_tree);
  139.         if index >= mid + 1 {
  140.             self.recursion_set(right_child, mid + 1, r, index, val);
  141.         } else {
  142.             self.recursion_set(left_child, l, mid, index, val);
  143.         }
  144.         // 左右子树数据求和
  145.         if let Some(l_d) = self.tree[left_child] {
  146.             if let Some(r_d) = self.tree[right_child] {
  147.                 self.tree[index_tree] = Some((self.query_op)(l_d, r_d));
  148.             }
  149.         }
  150.     }
  151.     // 应用所有懒标记到data数组上
  152.     #[inline]
  153.     pub fn apply_marker_all(&mut self) {
  154.         self.apply_marker_lr(0, self.data.len() - 1);
  155.     }
  156.     // 应用懒标记到[l:r]数据范围
  157.     #[inline]
  158.     pub fn apply_marker_lr(&mut self, l: usize, r: usize) {
  159.         self.apply_marker(l, r, 0, 0, self.data.len() - 1);
  160.     }
  161.     fn apply_marker(
  162.         &mut self,
  163.         l: usize,
  164.         r: usize,
  165.         index: usize,
  166.         current_l: usize,
  167.         current_r: usize,
  168.     ) {
  169.         if current_l > r || current_r < l || r >= self.data.len() {
  170.             return; // 区间无交集
  171.         } else {
  172.             // 与目标区间有交集,但不包含于其中
  173.             if current_l == current_r {
  174.                 if let Some(d) = self.tree[index] {
  175.                     self.data[current_l] = d;
  176.                 }
  177.                 return;
  178.             }
  179.             let mid = (current_l + current_r) / 2;
  180.             self.push_down(index, current_r - current_l + 1);
  181.             self.apply_marker(l, r, Self::left_child(index), current_l, mid); // 递归地往下寻找
  182.             self.apply_marker(l, r, Self::right_child(index), mid + 1, current_r);
  183.             self.tree[index] = Some((self.query_op)(
  184.                 self.tree[Self::left_child(index)].unwrap(),
  185.                 self.tree[Self::right_child(index)].unwrap(),
  186.             ));
  187.             // 根据子节点更新当前节点的值
  188.         }
  189.     }
  190.     #[inline]
  191.     pub fn update_interval(&mut self, l: usize, r: usize, delta: T) {
  192.         self.update(l, r, delta, 0, 0, self.data.len() - 1);
  193.     }
  194.     // 传递marker到下级
  195.     fn push_down(&mut self, index: usize, len: usize) {
  196.         self.marker[Self::left_child(index)] =
  197.             (self.marker_marker_op)(self.marker[index], self.marker[Self::left_child(index)]); // 标记向下传递
  198.         self.marker[Self::right_child(index)] =
  199.             (self.marker_marker_op)(self.marker[index], self.marker[Self::right_child(index)]);
  200.         if self.tree[Self::left_child(index)].is_some() {
  201.             self.tree[Self::left_child(index)] = Some((self.marker_t_op)(
  202.                 (self.marker_mul_usize)(self.marker[index], len - (len / 2)),
  203.                 self.tree[Self::left_child(index)].unwrap(),
  204.             ));
  205.         }
  206.         if self.tree[Self::right_child(index)].is_some() {
  207.             self.tree[Self::right_child(index)] = Some((self.marker_t_op)(
  208.                 (self.marker_mul_usize)(self.marker[index], len / 2),
  209.                 self.tree[Self::right_child(index)].unwrap(),
  210.             ));
  211.         }
  212.         self.marker[index] = T::default(); // 清除标记
  213.     }
  214.     fn update(
  215.         &mut self,
  216.         l: usize,
  217.         r: usize,
  218.         delta: T,
  219.         index: usize,
  220.         current_l: usize,
  221.         current_r: usize,
  222.     ) {
  223.         if current_l > r || current_r < l {
  224.             return; // 区间无交集
  225.         } else if current_l >= l && current_r <= r {
  226.             // 当前节点对应的区间包含在目标区间中
  227.             if self.tree[index].is_some() {
  228.                 // 更新当前区间的值
  229.                 self.tree[index] = Some((self.query_op)(
  230.                     self.tree[index].unwrap(),
  231.                     (self.marker_mul_usize)(delta, current_r - current_l + 1),
  232.                 ));
  233.             }
  234.             // 如果不是叶子节点
  235.             if current_r > current_l {
  236.                 // 给当前区间打上标记
  237.                 self.marker[index] = (self.marker_marker_op)(delta, self.marker[index]);
  238.             }
  239.         } else {
  240.             // 与目标区间有交集,但不包含于其中
  241.             let mid = (current_l + current_r) / 2;
  242.             self.push_down(index, current_r - current_l + 1);
  243.             self.update(l, r, delta, Self::left_child(index), current_l, mid); // 递归地往下寻找
  244.             self.update(l, r, delta, Self::right_child(index), mid + 1, current_r);
  245.             self.tree[index] = Some((self.query_op)(
  246.                 self.tree[Self::left_child(index)].unwrap(),
  247.                 self.tree[Self::right_child(index)].unwrap(),
  248.             )); // 根据子节点更新当前节点的值
  249.         }
  250.     }
  251. }
  252. fn main() {
  253.     let mut tr: SegmentTree<i32> = SegmentTree::new(
  254.         vec![1, 3, 4, 0, 0, 4, 5, 0],
  255.         Box::new(|a, b| a + b),
  256.         Box::new(|a, b| a + b),
  257.         Box::new(|a, b| a + b),
  258.         Box::new(|a, b| a * (b as i32)),
  259.     );
  260.     let _ = tr.set(1, 2); //点更新,即把data[1]设为2
  261.     tr.update_interval(0, 2, -1); //区间更新,即[0:2]每个元素减1
  262.     tr.update_interval(1, 3, 2); //区间更新,即[1:3]每个元素加2
  263.     tr.apply_marker_all(); //应用全部marker到data数组
  264.     println!("{}", tr.query_all()); //输出19,即全部元素的和
  265.     println!("{:?}", tr.data); //输出[0, 3, 5, 2, 0, 4, 5, 0]
  266. }
复制代码
做一道题验证一下这个线段树的正确性,直接看我写的1589. 所有排列中的最大和题解即可(虽然这道题用差分数组最快,但是作为线段树验证还是很方便的)。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

冬雨财经

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表