使用 Go 语言实现二叉搜索树

小小小幸运  金牌会员 | 2023-8-1 19:23:44 | 来自手机 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 869|帖子 869|积分 2607

原文链接: 使用 Go 语言实现二叉搜索树
二叉树是一种常见并且非常重要的数据结构,在很多项目中都能看到二叉树的身影。
它有很多变种,比如红黑树,常被用作 std::map 和 std::set 的底层实现;B 树和 B+ 树,广泛应用于数据库系统中。
本文要介绍的二叉搜索树用的也很多,比如在开源项目 go-zero 中,就被用来做路由管理。
这篇文章也算是一篇前导文章,介绍一些必备知识,下一篇再来介绍具体在 go-zero 中的应用。
二叉搜索树的特点

最重要的就是它的有序性,在二叉搜索树中,每个节点的值都大于其左子树中的所有节点的值,并且小于其右子树中的所有节点的值。

这意味着通过二叉搜索树可以快速实现对数据的查找和插入。
Go 语言实现

本文主要实现了以下几种方法:

  • Insert(t):插入一个节点
  • Search(t):判断节点是否在树中
  • InOrderTraverse():中序遍历
  • PreOrderTraverse():前序遍历
  • PostOrderTraverse():后序遍历
  • Min():返回最小值
  • Max():返回最大值
  • Remove(t):删除一个节点
  • String():打印一个树形结构
下面分别来介绍,首先定义一个节点:
  1. type Node struct {
  2.     key   int
  3.     value Item
  4.     left  *Node //left
  5.     right *Node //right
  6. }
复制代码
定义树的结构体,其中包含了锁,是线程安全的:
  1. type ItemBinarySearchTree struct {
  2.     root *Node
  3.     lock sync.RWMutex
  4. }
复制代码
插入操作:
  1. func (bst *ItemBinarySearchTree) Insert(key int, value Item) {
  2.     bst.lock.Lock()
  3.     defer bst.lock.Unlock()
  4.     n := &Node{key, value, nil, nil}
  5.     if bst.root == nil {
  6.         bst.root = n
  7.     } else {
  8.         insertNode(bst.root, n)
  9.     }
  10. }
  11. // internal function to find the correct place for a node in a tree
  12. func insertNode(node, newNode *Node) {
  13.     if newNode.key < node.key {
  14.         if node.left == nil {
  15.             node.left = newNode
  16.         } else {
  17.             insertNode(node.left, newNode)
  18.         }
  19.     } else {
  20.         if node.right == nil {
  21.             node.right = newNode
  22.         } else {
  23.             insertNode(node.right, newNode)
  24.         }
  25.     }
  26. }
复制代码
在插入时,需要判断插入节点和当前节点的大小关系,保证搜索树的有序性。
中序遍历:
  1. func (bst *ItemBinarySearchTree) InOrderTraverse(f func(Item)) {
  2.     bst.lock.RLock()
  3.     defer bst.lock.RUnlock()
  4.     inOrderTraverse(bst.root, f)
  5. }
  6. // internal recursive function to traverse in order
  7. func inOrderTraverse(n *Node, f func(Item)) {
  8.     if n != nil {
  9.         inOrderTraverse(n.left, f)
  10.         f(n.value)
  11.         inOrderTraverse(n.right, f)
  12.     }
  13. }
复制代码
前序遍历:
  1. func (bst *ItemBinarySearchTree) PreOrderTraverse(f func(Item)) {
  2.     bst.lock.Lock()
  3.     defer bst.lock.Unlock()
  4.     preOrderTraverse(bst.root, f)
  5. }
  6. // internal recursive function to traverse pre order
  7. func preOrderTraverse(n *Node, f func(Item)) {
  8.     if n != nil {
  9.         f(n.value)
  10.         preOrderTraverse(n.left, f)
  11.         preOrderTraverse(n.right, f)
  12.     }
  13. }
复制代码
后序遍历:
  1. func (bst *ItemBinarySearchTree) PostOrderTraverse(f func(Item)) {
  2.     bst.lock.Lock()
  3.     defer bst.lock.Unlock()
  4.     postOrderTraverse(bst.root, f)
  5. }
  6. // internal recursive function to traverse post order
  7. func postOrderTraverse(n *Node, f func(Item)) {
  8.     if n != nil {
  9.         postOrderTraverse(n.left, f)
  10.         postOrderTraverse(n.right, f)
  11.         f(n.value)
  12.     }
  13. }
复制代码
返回最小值:
  1. func (bst *ItemBinarySearchTree) Min() *Item {
  2.     bst.lock.RLock()
  3.     defer bst.lock.RUnlock()
  4.     n := bst.root
  5.     if n == nil {
  6.         return nil
  7.     }
  8.     for {
  9.         if n.left == nil {
  10.             return &n.value
  11.         }
  12.         n = n.left
  13.     }
  14. }
复制代码
由于树的有序性,想要得到最小值,一直向左查找就可以了。
返回最大值:
  1. func (bst *ItemBinarySearchTree) Max() *Item {
  2.     bst.lock.RLock()
  3.     defer bst.lock.RUnlock()
  4.     n := bst.root
  5.     if n == nil {
  6.         return nil
  7.     }
  8.     for {
  9.         if n.right == nil {
  10.             return &n.value
  11.         }
  12.         n = n.right
  13.     }
  14. }
复制代码
查找节点是否存在:
  1. func (bst *ItemBinarySearchTree) Search(key int) bool {
  2.     bst.lock.RLock()
  3.     defer bst.lock.RUnlock()
  4.     return search(bst.root, key)
  5. }
  6. // internal recursive function to search an item in the tree
  7. func search(n *Node, key int) bool {
  8.     if n == nil {
  9.         return false
  10.     }
  11.     if key < n.key {
  12.         return search(n.left, key)
  13.     }
  14.     if key > n.key {
  15.         return search(n.right, key)
  16.     }
  17.     return true
  18. }
复制代码
删除节点:
  1. func (bst *ItemBinarySearchTree) Remove(key int) {
  2.     bst.lock.Lock()
  3.     defer bst.lock.Unlock()
  4.     remove(bst.root, key)
  5. }
  6. // internal recursive function to remove an item
  7. func remove(node *Node, key int) *Node {
  8.     if node == nil {
  9.         return nil
  10.     }
  11.     if key < node.key {
  12.         node.left = remove(node.left, key)
  13.         return node
  14.     }
  15.     if key > node.key {
  16.         node.right = remove(node.right, key)
  17.         return node
  18.     }
  19.     // key == node.key
  20.     if node.left == nil && node.right == nil {
  21.         node = nil
  22.         return nil
  23.     }
  24.     if node.left == nil {
  25.         node = node.right
  26.         return node
  27.     }
  28.     if node.right == nil {
  29.         node = node.left
  30.         return node
  31.     }
  32.     leftmostrightside := node.right
  33.     for {
  34.         //find smallest value on the right side
  35.         if leftmostrightside != nil && leftmostrightside.left != nil {
  36.             leftmostrightside = leftmostrightside.left
  37.         } else {
  38.             break
  39.         }
  40.     }
  41.     node.key, node.value = leftmostrightside.key, leftmostrightside.value
  42.     node.right = remove(node.right, node.key)
  43.     return node
  44. }
复制代码
删除操作会复杂一些,分三种情况来考虑:

  • 如果要删除的节点没有子节点,只需要直接将父节点中,指向要删除的节点指针置为 nil 即可
  • 如果删除的节点只有一个子节点,只需要更新父节点中,指向要删除节点的指针,让它指向删除节点的子节点即可
  • 如果删除的节点有两个子节点,我们需要找到这个节点右子树中的最小节点,把它替换到要删除的节点上。然后再删除这个最小节点,因为最小节点肯定没有左子节点,所以可以应用第二种情况删除这个最小节点即可
最后是一个打印树形结构的方法,在实际项目中其实并没有实际作用:
  1. func (bst *ItemBinarySearchTree) String() {
  2.     bst.lock.Lock()
  3.     defer bst.lock.Unlock()
  4.     fmt.Println("------------------------------------------------")
  5.     stringify(bst.root, 0)
  6.     fmt.Println("------------------------------------------------")
  7. }
  8. // internal recursive function to print a tree
  9. func stringify(n *Node, level int) {
  10.     if n != nil {
  11.         format := ""
  12.         for i := 0; i < level; i++ {
  13.             format += "       "
  14.         }
  15.         format += "---[ "
  16.         level++
  17.         stringify(n.left, level)
  18.         fmt.Printf(format+"%d\n", n.key)
  19.         stringify(n.right, level)
  20.     }
  21. }
复制代码
单元测试

下面是一段测试代码:
  1. func fillTree(bst *ItemBinarySearchTree) {
  2.     bst.Insert(8, "8")
  3.     bst.Insert(4, "4")
  4.     bst.Insert(10, "10")
  5.     bst.Insert(2, "2")
  6.     bst.Insert(6, "6")
  7.     bst.Insert(1, "1")
  8.     bst.Insert(3, "3")
  9.     bst.Insert(5, "5")
  10.     bst.Insert(7, "7")
  11.     bst.Insert(9, "9")
  12. }
  13. func TestInsert(t *testing.T) {
  14.     fillTree(&bst)
  15.     bst.String()
  16.     bst.Insert(11, "11")
  17.     bst.String()
  18. }
  19. // isSameSlice returns true if the 2 slices are identical
  20. func isSameSlice(a, b []string) bool {
  21.     if a == nil && b == nil {
  22.         return true
  23.     }
  24.     if a == nil || b == nil {
  25.         return false
  26.     }
  27.     if len(a) != len(b) {
  28.         return false
  29.     }
  30.     for i := range a {
  31.         if a[i] != b[i] {
  32.             return false
  33.         }
  34.     }
  35.     return true
  36. }
  37. func TestInOrderTraverse(t *testing.T) {
  38.     var result []string
  39.     bst.InOrderTraverse(func(i Item) {
  40.         result = append(result, fmt.Sprintf("%s", i))
  41.     })
  42.     if !isSameSlice(result, []string{"1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11"}) {
  43.         t.Errorf("Traversal order incorrect, got %v", result)
  44.     }
  45. }
  46. func TestPreOrderTraverse(t *testing.T) {
  47.     var result []string
  48.     bst.PreOrderTraverse(func(i Item) {
  49.         result = append(result, fmt.Sprintf("%s", i))
  50.     })
  51.     if !isSameSlice(result, []string{"8", "4", "2", "1", "3", "6", "5", "7", "10", "9", "11"}) {
  52.         t.Errorf("Traversal order incorrect, got %v instead of %v", result, []string{"8", "4", "2", "1", "3", "6", "5", "7", "10", "9", "11"})
  53.     }
  54. }
  55. func TestPostOrderTraverse(t *testing.T) {
  56.     var result []string
  57.     bst.PostOrderTraverse(func(i Item) {
  58.         result = append(result, fmt.Sprintf("%s", i))
  59.     })
  60.     if !isSameSlice(result, []string{"1", "3", "2", "5", "7", "6", "4", "9", "11", "10", "8"}) {
  61.         t.Errorf("Traversal order incorrect, got %v instead of %v", result, []string{"1", "3", "2", "5", "7", "6", "4", "9", "11", "10", "8"})
  62.     }
  63. }
  64. func TestMin(t *testing.T) {
  65.     if fmt.Sprintf("%s", *bst.Min()) != "1" {
  66.         t.Errorf("min should be 1")
  67.     }
  68. }
  69. func TestMax(t *testing.T) {
  70.     if fmt.Sprintf("%s", *bst.Max()) != "11" {
  71.         t.Errorf("max should be 11")
  72.     }
  73. }
  74. func TestSearch(t *testing.T) {
  75.     if !bst.Search(1) || !bst.Search(8) || !bst.Search(11) {
  76.         t.Errorf("search not working")
  77.     }
  78. }
  79. func TestRemove(t *testing.T) {
  80.     bst.Remove(1)
  81.     if fmt.Sprintf("%s", *bst.Min()) != "2" {
  82.         t.Errorf("min should be 2")
  83.     }
  84. }
复制代码
上文中的全部源码都是经过测试的,可以直接运行,并且已经上传到了 GitHub,需要的同学可以自取。
以上就是本文的全部内容,如果觉得还不错的话欢迎点赞转发关注,感谢支持。
源码地址:
推荐阅读:
参考文章:

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

小小小幸运

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表