冬雨财经 发表于 2024-3-15 12:49:08

Rust实现线段树和懒标记

参考各家代码,用Rust实现了线段树和懒标记。
由于使用了泛型,很多操作都要用闭包自定义实现。
看代码。
// 线段树定义
pub struct SegmentTree<T: Clone>
{
    pub data: Vec<T>,
    tree: Vec<Option<T>>,
    marker: Vec<T>,                               //懒标记。
    query_op: Box<dyn Fn(T, T) -> T>, //查询时,对所有查询元素做的操作。比如加法,就是求区间的所有元素的和。
    marker_marker_op: Box<dyn Fn(T, T) -> T>, //marker加到marker上时,对marker的操作。通常我们要marker += p; 来更新标记,但是泛型实现不了,并且考虑到有些用户有别的需求,所以用闭包包装。
    marker_t_op: Box<dyn Fn(T, T) -> T>, //marker应用到T时,对T的操作。考虑到有些用户有别的需求,所以用闭包包装。
    marker_mul_usize: Box<dyn Fn(T, usize) -> T>, //marker乘usize的方法。这个没法通过要求满足Mul trait自动实现。由于使用了泛型,连乘法都要交给闭包实现。。。
}

impl<T: Clone + Default + Copy + PartialEq> SegmentTree<T> {
    pub fn new(
      data: Vec<T>,
      query_op: Box<dyn Fn(T, T) -> T>,
      marker_marker_op: Box<dyn Fn(T, T) -> T>,
      marker_t_op: Box<dyn Fn(T, T) -> T>,
      marker_mul_usize: Box<dyn Fn(T, usize) -> T>,
    ) -> Self {
      let data_len = data.len();
      let mut tr = Self {
            data,
            marker: vec!, //四倍原数据大小
            tree: vec!,         //四倍原数据大小
            query_op,
            marker_marker_op,
            marker_t_op,
            marker_mul_usize,
      };
      tr.build();
      tr
    }

    #
    pub fn get(&self, index: usize) -> Option<&T> {
      self.data.get(index)
    }

    #
    pub fn len(&self) -> usize {
      self.data.len()
    }

    #
    fn left_child(index: usize) -> usize {
      2 * index + 1
    }

    #
    fn right_child(index: usize) -> usize {
      2 * index + 2
    }

    #
    fn build(&mut self) {
      self.build_segment_tree(0, 0, self.data.len() - 1);
    }

    // 递归Build
    fn build_segment_tree(&mut self, tree_index: usize, left: usize, right: usize) {
      if left == right {
            self.tree = Some(self.data);
            return;
      }
      let left_tree_index = Self::left_child(tree_index);
      let right_tree_index = Self::right_child(tree_index);
      let mid = (right - left) / 2 + left;
      self.build_segment_tree(left_tree_index, left, mid);
      self.build_segment_tree(right_tree_index, mid + 1, right);
      // 左右子树数据处理方式
      if let Some(l) = self.tree {
            if let Some(r) = self.tree {
                self.tree = Some((self.query_op)(l, r))
            }
      }
    }

    // 返回对线段树的全部元素做query_op操作的结果
    #
    pub fn query_all(&mut self) -> T {
      self.recursion_query(0, self.data.len() - 1, 0, 0, self.data.len() - 1)
    }

    // 返回对线段树的范围全部元素做query_op操作的结果
    pub fn query(&mut self, l: usize, r: usize) -> Result<T, &'static str> {
      if l > self.data.len() || r > self.data.len() || l > r {
            return Err("索引错误");
      }
      if l == r {
            return Ok(self.data);
      }
      Ok(self.recursion_query(l, r, 0, 0, self.data.len() - 1))
    }

    // 在index表示的范围中查询值
    fn recursion_query(
      &mut self,
      l: usize,
      r: usize,
      index: usize,
      current_left: usize,
      current_right: usize,
    ) -> T {
      if l > current_right || r < current_left {
            return T::default();
      }
      if l == current_left && r == current_right {
            if let Some(d) = self.tree {
                if l == r {
                  self.data = d;
                }
                return d;
            }
            return T::default();
      }
      self.push_down(index, current_right - current_left + 1);
      let mid = current_left + (current_right - current_left) / 2;
      if l >= mid + 1 {
            return self.recursion_query(l, r, Self::right_child(index), mid + 1, current_right);
      } else if r <= mid {
            return self.recursion_query(l, r, Self::left_child(index), current_left, mid);
      }
      let l_res = self.recursion_query(l, mid, Self::left_child(index), current_left, mid);
      let r_res =
            self.recursion_query(mid + 1, r, Self::right_child(index), mid + 1, current_right);
      (self.query_op)(l_res, r_res)
    }

    // 更新index为val
    pub fn set(&mut self, index: usize, val: T) -> Result<(), &'static str> {
      if index >= self.data.len() {
            return Err("索引超过线段树长度");
      }
      // 更新数据
      self.data = val;
      // 递归更新树
      self.recursion_set(0, 0, self.data.len() - 1, index, val);
      Ok(())
    }

    // 递归更新树
    fn recursion_set(&mut self, index_tree: usize, l: usize, r: usize, index: usize, val: T) {
      if l == r {
            self.tree = Some(val);
            return;
      }
      let mid = l + (r - l) / 2;
      let left_child = Self::left_child(index_tree);
      let right_child = Self::right_child(index_tree);
      if index >= mid + 1 {
            self.recursion_set(right_child, mid + 1, r, index, val);
      } else {
            self.recursion_set(left_child, l, mid, index, val);
      }
      // 左右子树数据求和
      if let Some(l_d) = self.tree {
            if let Some(r_d) = self.tree {
                self.tree = Some((self.query_op)(l_d, r_d));
            }
      }
    }

    // 应用所有懒标记到data数组上
    #
    pub fn apply_marker_all(&mut self) {
      self.apply_marker_lr(0, self.data.len() - 1);
    }

    // 应用懒标记到数据范围
    #
    pub fn apply_marker_lr(&mut self, l: usize, r: usize) {
      self.apply_marker(l, r, 0, 0, self.data.len() - 1);
    }

    fn apply_marker(
      &mut self,
      l: usize,
      r: usize,
      index: usize,
      current_l: usize,
      current_r: usize,
    ) {
      if current_l > r || current_r < l || r >= self.data.len() {
            return; // 区间无交集
      } else {
            // 与目标区间有交集,但不包含于其中
            if current_l == current_r {
                if let Some(d) = self.tree {
                  self.data = d;
                }
                return;
            }
            let mid = (current_l + current_r) / 2;
            self.push_down(index, current_r - current_l + 1);
            self.apply_marker(l, r, Self::left_child(index), current_l, mid); // 递归地往下寻找
            self.apply_marker(l, r, Self::right_child(index), mid + 1, current_r);
            self.tree = Some((self.query_op)(
                self.tree.unwrap(),
                self.tree.unwrap(),
            ));
            // 根据子节点更新当前节点的值
      }
    }
    #
    pub fn update_interval(&mut self, l: usize, r: usize, delta: T) {
      self.update(l, r, delta, 0, 0, self.data.len() - 1);
    }

    // 传递marker到下级
    fn push_down(&mut self, index: usize, len: usize) {
      self.marker =
            (self.marker_marker_op)(self.marker, self.marker); // 标记向下传递
      self.marker =
            (self.marker_marker_op)(self.marker, self.marker);
      if self.tree.is_some() {
            self.tree = Some((self.marker_t_op)(
                (self.marker_mul_usize)(self.marker, len - (len / 2)),
                self.tree.unwrap(),
            ));
      }
      if self.tree.is_some() {
            self.tree = Some((self.marker_t_op)(
                (self.marker_mul_usize)(self.marker, len / 2),
                self.tree.unwrap(),
            ));
      }
      self.marker = T::default(); // 清除标记
    }

    fn update(
      &mut self,
      l: usize,
      r: usize,
      delta: T,
      index: usize,
      current_l: usize,
      current_r: usize,
    ) {
      if current_l > r || current_r < l {
            return; // 区间无交集
      } else if current_l >= l && current_r <= r {
            // 当前节点对应的区间包含在目标区间中
            if self.tree.is_some() {
                // 更新当前区间的值
                self.tree = Some((self.query_op)(
                  self.tree.unwrap(),
                  (self.marker_mul_usize)(delta, current_r - current_l + 1),
                ));
            }
            // 如果不是叶子节点
            if current_r > current_l {
                // 给当前区间打上标记
                self.marker = (self.marker_marker_op)(delta, self.marker);
            }
      } else {
            // 与目标区间有交集,但不包含于其中
            let mid = (current_l + current_r) / 2;
            self.push_down(index, current_r - current_l + 1);
            self.update(l, r, delta, Self::left_child(index), current_l, mid); // 递归地往下寻找
            self.update(l, r, delta, Self::right_child(index), mid + 1, current_r);
            self.tree = Some((self.query_op)(
                self.tree.unwrap(),
                self.tree.unwrap(),
            )); // 根据子节点更新当前节点的值
      }
    }
}

fn main() {
    let mut tr: SegmentTree<i32> = SegmentTree::new(
      vec!,
      Box::new(|a, b| a + b),
      Box::new(|a, b| a + b),
      Box::new(|a, b| a + b),
      Box::new(|a, b| a * (b as i32)),
    );
    let _ = tr.set(1, 2); //点更新,即把data设为2
    tr.update_interval(0, 2, -1); //区间更新,即每个元素减1
    tr.update_interval(1, 3, 2); //区间更新,即每个元素加2
    tr.apply_marker_all(); //应用全部marker到data数组
    println!("{}", tr.query_all()); //输出19,即全部元素的和
    println!("{:?}", tr.data); //输出
} 做一道题验证一下这个线段树的正确性,直接看我写的1589. 所有排列中的最大和题解即可(虽然这道题用差分数组最快,但是作为线段树验证还是很方便的)。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
页: [1]
查看完整版本: Rust实现线段树和懒标记