本文讲解的是golang.org/x/sync这个包中的errgroup
1、errgroup 的基础介绍
学习过 Go 的朋友都知道 Go 实现并发编程是比较容易的事情,只需要使用go关键字就可以开启一个 goroutine。那对于并发场景中,如何实现goroutine的协调控制呢?常见的一种方式是使用sync.WaitGroup 来进行协调控制。
使用过sync.WaitGroup 的朋友知道,sync.WaitGroup 虽然可以实现协调控制,但是不能传递错误,那该如何解决呢?聪明的你可能马上想到使用 chan 或者是 context来传递错误,确实是可以的。那接下来,我们一起看看官方是怎么实现上面的需求的呢?
1.1 errgroup的安装
安装命令:- go get golang.org/x/sync
- //下面的案例是基于v0.1.0 演示的
- go get golang.org/x/sync@v0.1.0
复制代码 1.2 errgroup的基础例子
这里我们需要请求3个url来获取数据,假设请求url2时报错,url3耗时比较久,需要等一秒。- package main
- import (
- "errors"
- "fmt"
- "golang.org/x/sync/errgroup"
- "strings"
- "time"
- )
- func main() {
- queryUrls := map[string]string{
- "url1": "http://localhost/url1",
- "url2": "http://localhost/url2",
- "url3": "http://localhost/url3",
- }
- var eg errgroup.Group
- var results []string
- for _, url := range queryUrls {
- url := url
- eg.Go(func() error {
- result, err := query(url)
- if err != nil {
- return err
- }
- results = append(results, fmt.Sprintf("url:%s -- ret: %v", url, result))
- return nil
- })
- }
-
- // group 的wait方法,等待上面的 eg.Go 的协程执行完成,并且可以接受错误
- err := eg.Wait()
- if err != nil {
- fmt.Println("eg.Wait error:", err)
- return
- }
- for k, v := range results {
- fmt.Printf("%v ---> %v\n", k, v)
- }
- }
- func query(url string) (ret string, err error) {
- // 假设这里是发送请求,获取数据
- if strings.Contains(url, "url2") {
- // 假设请求 url2 时出现错误
- fmt.Printf("请求 %s 中....\n", url)
- return "", errors.New("请求超时")
- } else if strings.Contains(url, "url3") {
- // 假设 请求 url3 需要1秒
- time.Sleep(time.Second*1)
- }
- fmt.Printf("请求 %s 中....\n", url)
- return "success", nil
- }
复制代码 执行结果:- 请求 http://localhost/url2 中....
- 请求 http://localhost/url1 中....
- 请求 http://localhost/url3 中....
- eg.Wait error: 请求超时
复制代码 果然,当其中一个goroutine出现错误时,会把goroutine中的错误传递出来。
我们自己运行一下上面的代码就会发现这样一个问题,请求 url2 出错了,但是依旧在请求 url3 。因为我们需要聚合 url1、url2、url3 的结果,所以当其中一个出现问题时,我们是可以做一个优化的,就是当其中一个出现错误时,取消还在执行的任务,直接返回结果,不用等待任务执行结果。
那应该如何做呢?
这里假设 url1 执行1秒,url2 执行报错,url3执行3秒。所以当url2报错后,就不用等url3执行结束就可以返回了。
[code]package mainimport ( "context" "errors" "fmt" "golang.org/x/sync/errgroup" "strings" "time")func main() { queryUrls := map[string]string{ "url1": "http://localhost/url1", "url2": "http://localhost/url2", "url3": "http://localhost/url3", } var results []string ctx, cancel := context.WithCancel(context.Background()) eg, errCtx := errgroup.WithContext(ctx) for _, url := range queryUrls { url := url eg.Go(func() error { result, err := query(errCtx, url) if err != nil { //其实这里不用手动取消,看完源码就知道为啥了 cancel() return err } results = append(results, fmt.Sprintf("url:%s -- ret: %v", url, result)) return nil }) } err := eg.Wait() if err != nil { fmt.Println("eg.Wait error:", err) return } for k, v := range results { fmt.Printf("%v ---> %v\n", k, v) }}func query(errCtx context.Context, url string) (ret string, err error) { fmt.Printf("请求 %s 开始....\n", url) // 假设这里是发送请求,获取数据 if strings.Contains(url, "url2") { // 假设请求 url2 时出现错误 time.Sleep(time.Second*2) return "", errors.New("请求出错") } else if strings.Contains(url, "url3") { // 假设 请求 url3 需要1秒 select { case |