[golang]在Gin框架中使用JWT鉴权

打印 上一主题 下一主题

主题 913|帖子 913|积分 2739

什么是JWT

JWT,全称 JSON Web Token,是一种开放标准(RFC 7519),用于安全地在双方之间通报信息。尤其适用于身份验证和授权场景。JWT 的设计允许信息在各方之间安全地、 compactly(紧凑地)传输,因为其自身包含了所有需要的认证信息,从而淘汰了需要查询数据库或会话存储的需求。
JWT主要由三部门组成,通过.毗连:

  • Header(头部):描述JWT的元数据,通常包括类型(通常是JWT)和使用的签名算法(如HS256、RS256等)。
  • Payload(载荷):包含声明(claims),即用户的相关信息。这些信息可以是公开的,也可以是私有的,但应避免放入敏感信息,因为该部门可以被解码查看。载荷中的声明可以验证,但不加密。
  • Signature(签名):用于验证JWT的完整性和来源。它是通过将Header和Payload分别进行Base64编码后,再与一个秘钥(secret)一起通过指定的算法(如HMAC SHA256)盘算得出的。
JWT的工作流程大致如下:

  • 认证阶段:用户向服务器提供凭据(如用户名和密码)。服务器验证凭据无误后,生成一个JWT,其中包含用户标识符和其他声明,并使用秘钥对其进行签名。
  • 使用阶段:客户端收到JWT后,可以在后续的每个请求中将其放在HTTP请求头中发送给服务器,以此证实自己的身份。
  • 验证阶段:服务器收到JWT后,会使用相同的秘钥验证JWT的签名,确保其未被窜改,并检查过期时间等其他声明,从而决定是否允许实行请求。
JWT的优势在于它的无状态性,服务器不需要存储会话信息,这减轻了服务器的压力,同时也方便了跨域认证。但需要留意的是,JWT的安全性依赖于秘钥的安全保管以及对JWT过期时间等的合理设置。
API设计

这里设计两个公共接口和一个受保护的接口。
API描述/api/login公开接口。用于用户登录/api/register公开接口。用于用户注册/api/admin/user保护接口,需要验证JWT开发预备

初始化项目目次并切换进入
  1. mkdir gin-jwt
  2. cd gin-jwt
复制代码
使用go mod初始化工程
  1. go mod init gin-jwt
复制代码
安装依赖
  1. go get -u github.com/gin-gonic/gin
  2. go get -u gorm.io/gorm
  3. go get -u gorm.io/driver/postgres
  4. go get -u github.com/golang-jwt/jwt/v5
  5. go get -u github.com/joho/godotenv
  6. go get -u golang.org/x/crypto
复制代码
创建第一个API

一开始我们可以在项目的根目次中创建文件main.go
  1. touch main.go
复制代码
添加以下内容
  1. package main
  2. import (
  3.         "net/http"
  4.         "github.com/gin-gonic/gin"
  5. )
  6. func main() {
  7.         r := gin.Default()
  8.         public := r.Group("/api")
  9.         {
  10.                 public.POST("/register", func(c *gin.Context) {
  11.                         c.JSON(http.StatusOK, gin.H{
  12.                                 "data": "test. register api",
  13.                         })
  14.                 })
  15.         }
  16.         r.Run("0.0.0.0:8000")
  17. }
复制代码
测试运行
  1. go run main.go
复制代码
客户端测试。正常的话会有以下输出
  1. $ curl -X POST http://127.0.0.1:8000/api/register
  2. {"data":"test. register api"}
复制代码
美满register接口

如今register接口已经预备好了,但一样平常来说我们会把接口业务逻辑放在单独的文件中,而不是和接口定义写在一块。
创建一个控制器的包目次,并添加文件
  1. mkdir controllers
  2. touch controllers/auth.go
复制代码
auth.go文件内容
  1. package controllers
  2. import (
  3.         "net/http"
  4.         "github.com/gin-gonic/gin"
  5. )
  6. func Register(c *gin.Context) {
  7.         c.JSON(http.StatusOK, gin.H{
  8.                 "data": "hello, this is register endpoint",
  9.         })
  10. }
复制代码
更新main.go文件
  1. package main
  2. import (
  3.         "github.com/gin-gonic/gin"
  4.         "gin-jwt/controllers"
  5. )
  6. func main() {
  7.         r := gin.Default()
  8.         public := r.Group("/api")
  9.         {
  10.                 public.POST("/register", controllers.Register)
  11.         }
  12.         r.Run("0.0.0.0:8000")
  13. }
复制代码
重新运行测试
  1. go run main.go
复制代码
客户端测试
  1. $ curl -X POST http://127.0.0.1:8000/api/register
  2. {"data":"hello, this is register endpoint"}
复制代码
解析register的客户端请求

客户端请求register api需要携带用户名和密码的参数,服务端对此做解析。编辑文件controllers/auth.go
  1. package controllers
  2. import (
  3.         "net/http"
  4.         "github.com/gin-gonic/gin"
  5. )
  6. // /api/register的请求体
  7. type ReqRegister struct {
  8.         Username string `json:"username" binding:"required"`
  9.         Password string `json:"password" binding:"required"`
  10. }
  11. func Register(c *gin.Context) {
  12.         var req ReqRegister
  13.         if err := c.ShouldBindBodyWithJSON(&req); err != nil {
  14.                 c.JSON(http.StatusBadRequest, gin.H{
  15.                         "data": err.Error(),
  16.                 })
  17.                 return
  18.         }
  19.         c.JSON(http.StatusOK, gin.H{
  20.                 "data": req,
  21.         })
  22. }
复制代码
客户端请求测试
  1. $ curl -X POST http://127.0.0.1:8000/api/register -d '{"username": "zhangsan", "password": "123456"}' -H 'Content-Type=application/json'
  2. {"data":{"username":"zhangsan","password":"123456"}}
复制代码
毗连关系型数据库

一样平常会将数据保存到专门的数据库中,这里用PostgreSQL来存储数据。Postgres使用docker来安装。安装完postgres后,创建用户和数据库:
  1. create user ginjwt encrypted password 'ginjwt';
  2. create database ginjwt owner = ginjwt;
复制代码
创建目次models,这个目次将包含毗连数据库和数据模型的代码。
  1. mkdir models
复制代码
编辑文件models/setup.go
  1. package models
  2. import (
  3.         "fmt"
  4.         "log"
  5.         "os"
  6.         "github.com/joho/godotenv"
  7.         "gorm.io/driver/postgres"
  8.         "gorm.io/gorm"
  9. )
  10. var DB *gorm.DB
  11. func ConnectDatabase() {
  12.         err := godotenv.Load(".env")
  13.         if err != nil {
  14.                 log.Fatalf("Error loading .env file. %v\n", err)
  15.         }
  16.         // DbDriver := os.Getenv("DB_DRIVER")
  17.         DbHost := os.Getenv("DB_HOST")
  18.         DbPort := os.Getenv("DB_PORT")
  19.         DbUser := os.Getenv("DB_USER")
  20.         DbPass := os.Getenv("DB_PASS")
  21.         DbName := os.Getenv("DB_NAME")
  22.         dsn := fmt.Sprintf("host=%s port=%s user=%s dbname=%s sslmode=disable TimeZone=Asia/Shanghai password=%s", DbHost, DbPort, DbUser, DbName, DbPass)
  23.         DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
  24.         if err != nil {
  25.                 log.Fatalf("Connect to database failed, %v\n", err)
  26.         } else {
  27.                 log.Printf("Connect to database success, host: %s, port: %s, user: %s, dbname: %s\n", DbHost, DbPort, DbUser, DbName)
  28.         }
  29.         // 迁移数据表
  30.         DB.AutoMigrate(&User{})
  31. }
复制代码
新建并编辑环境设置文件.env
  1. DB_HOST=127.0.0.1
  2. DB_PORT=5432
  3. DB_USER=ginjwt
  4. DB_PASS=ginjwt
  5. DB_NAME=ginjwt
复制代码
创建用户模型,编辑代码文件models/user.go
  1. package models
  2. import (
  3.         "html"
  4.         "strings"
  5.         "golang.org/x/crypto/bcrypt"
  6.         "gorm.io/gorm"
  7. )
  8. type User struct {
  9.         gorm.Model
  10.         Username string `gorm:"size:255;not null;unique" json:"username"`
  11.         Password string `gorm:"size:255;not null;" json:"password"`
  12. }
  13. func (u *User) SaveUser() (*User, error) {
  14.         err := DB.Create(&u).Error
  15.         if err != nil {
  16.                 return &User{}, err
  17.         }
  18.         return u, nil
  19. }
  20. // 使用gorm的hook在保存密码前对密码进行hash
  21. func (u *User) BeforeSave(tx *gorm.DB) error {
  22.         hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
  23.         if err != nil {
  24.                 return err
  25.         }
  26.         u.Password = string(hashedPassword)
  27.         u.Username = html.EscapeString(strings.TrimSpace(u.Username))
  28.         return nil
  29. }
复制代码
更新main.go
  1. package main
  2. import (
  3.         "github.com/gin-gonic/gin"
  4.         "gin-jwt/controllers"
  5.         "gin-jwt/models"
  6. )
  7. func init() {
  8.         models.ConnectDatabase()
  9. }
  10. func main() {
  11.         r := gin.Default()
  12.         public := r.Group("/api")
  13.         {
  14.                 public.POST("/register", controllers.Register)
  15.         }
  16.         r.Run("0.0.0.0:8000")
  17. }
复制代码
更新controllers/auth.go
  1. package controllers
  2. import (
  3.         "net/http"
  4.         "gin-jwt/models"
  5.         "github.com/gin-gonic/gin"
  6. )
  7. // /api/register的请求体
  8. type ReqRegister struct {
  9.         Username string `json:"username" binding:"required"`
  10.         Password string `json:"password" binding:"required"`
  11. }
  12. func Register(c *gin.Context) {
  13.         var req ReqRegister
  14.         if err := c.ShouldBindBodyWithJSON(&req); err != nil {
  15.                 c.JSON(http.StatusBadRequest, gin.H{
  16.                         "data": err.Error(),
  17.                 })
  18.                 return
  19.         }
  20.         u := models.User{
  21.                 Username: req.Username,
  22.                 Password: req.Password,
  23.         }
  24.         _, err := u.SaveUser()
  25.         if err != nil {
  26.                 c.JSON(http.StatusBadRequest, gin.H{
  27.                         "data": err.Error(),
  28.                 })
  29.                 return
  30.         }
  31.         c.JSON(http.StatusOK, gin.H{
  32.                 "message": "register success",
  33.                 "data":    req,
  34.         })
  35. }
复制代码
重新运行服务端后,客户端测试
  1. $ curl -X POST http://127.0.0.1:8000/api/register -d '{"username": "zhangsan", "password": "123456"}' -H 'Content-Type=application/json'
  2. {"data":{"username":"zhangsan","password":"123456"},"message":"register success"}
复制代码
添加login接口

登录接口实现的也非常简单,只需要提供用户名和密码参数。服务端接收到客户端的请求后到数据库中去匹配,确认用户是否存在和密码是否正确。假如验证通过则返回一个token,否则返回异常响应。
首先在main.go中注册API
  1. // xxx
  2. func main() {
  3.         // xxx
  4.         r := gin.Default()
  5.         public := r.Group("/api")
  6.         {
  7.                 public.POST("/register", controllers.Register)
  8.                 public.POST("/login", controllers.Login)
  9.         }
  10. }
复制代码
在auth.go中添加Login控制器函数
  1. // api/login 的请求体
  2. type ReqLogin struct {
  3.         Username string `json:"username" binding:"required"`
  4.         Password string `json:"password" binding:"required"`
  5. }
  6. func Login(c *gin.Context) {
  7.         var req ReqLogin
  8.         if err := c.ShouldBindBodyWithJSON(&req); err != nil {
  9.                 c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  10.                 return
  11.         }
  12.         u := models.User{
  13.                 Username: req.Username,
  14.                 Password: req.Password,
  15.         }
  16.         // 调用 models.LoginCheck 对用户名和密码进行验证
  17.         token, err := models.LoginCheck(u.Username, u.Password)
  18.         if err != nil {
  19.                 c.JSON(http.StatusBadRequest, gin.H{
  20.                         "error": "username or password is incorrect.",
  21.                 })
  22.                 return
  23.         }
  24.         c.JSON(http.StatusOK, gin.H{
  25.                 "token": token,
  26.         })
  27. }
复制代码
LoginCheck方法在models/user.go文件中实现
  1. package models
  2. import (
  3.         "gin-jwt/utils/token"
  4.         "html"
  5.         "strings"
  6.         "golang.org/x/crypto/bcrypt"
  7.         "gorm.io/gorm"
  8. )
  9. func VerifyPassword(password, hashedPassword string) error {
  10.         return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
  11. }
  12. func LoginCheck(username, password string) (string, error) {
  13.         var err error
  14.         u := User{}
  15.         err = DB.Model(User{}).Where("username = ?", username).Take(&u).Error
  16.         if err != nil {
  17.                 return "", err
  18.         }
  19.         err = VerifyPassword(password, u.Password)
  20.         if err != nil && err == bcrypt.ErrMismatchedHashAndPassword {
  21.                 return "", err
  22.         }
  23.         token, err := token.GenerateToken(u.ID)
  24.         if err != nil {
  25.                 return "", err
  26.         }
  27.         return token, nil
  28. }
复制代码
这里将token相关的函数放到了单独的模块中,新增相关目次并编辑文件
  1. mkdir -p utils/token
  2. touch utils/token/token.go
复制代码
以下代码为token.go的内容,包含的几个函数在反面会用到
  1. package token
  2. import (
  3.         "fmt"
  4.         "os"
  5.         "strconv"
  6.         "strings"
  7.         "time"
  8.         "github.com/gin-gonic/gin"
  9.         "github.com/golang-jwt/jwt/v5"
  10. )
  11. func GenerateToken(user_id uint) (string, error) {
  12.         token_lifespan, err := strconv.Atoi(os.Getenv("TOKEN_HOUR_LIFESPAN"))
  13.         if err != nil {
  14.                 return "", err
  15.         }
  16.         claims := jwt.MapClaims{}
  17.         claims["authorized"] = true
  18.         claims["user_id"] = user_id
  19.         claims["exp"] = time.Now().Add(time.Hour * time.Duration(token_lifespan)).Unix()
  20.         token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  21.         return token.SignedString([]byte(os.Getenv("API_SECRET")))
  22. }
  23. func TokenValid(c *gin.Context) error {
  24.         tokenString := ExtractToken(c)
  25.         fmt.Println(tokenString)
  26.         _, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
  27.                 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  28.                         return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  29.                 }
  30.                 return []byte(os.Getenv("API_SECRET")), nil
  31.         })
  32.         if err != nil {
  33.                 return err
  34.         }
  35.         return nil
  36. }
  37. // 从请求头中获取token
  38. func ExtractToken(c *gin.Context) string {
  39.         bearerToken := c.GetHeader("Authorization")
  40.         if len(strings.Split(bearerToken, " ")) == 2 {
  41.                 return strings.Split(bearerToken, " ")[1]
  42.         }
  43.         return ""
  44. }
  45. // 从jwt中解析出user_id
  46. func ExtractTokenID(c *gin.Context) (uint, error) {
  47.         tokenString := ExtractToken(c)
  48.         token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
  49.                 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  50.                         return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  51.                 }
  52.                 return []byte(os.Getenv("API_SECRET")), nil
  53.         })
  54.         if err != nil {
  55.                 return 0, err
  56.         }
  57.         claims, ok := token.Claims.(jwt.MapClaims)
  58.         // 如果jwt有效,将user_id转换为浮点数字符串,然后再转换为 uint32
  59.         if ok && token.Valid {
  60.                 uid, err := strconv.ParseUint(fmt.Sprintf("%.0f", claims["user_id"]), 10, 32)
  61.                 if err != nil {
  62.                         return 0, err
  63.                 }
  64.                 return uint(uid), nil
  65.         }
  66.         return 0, nil
  67. }
复制代码
在.env文件中添加两个环境变量的设置。TOKEN_HOUR_LIFESPAN设置token的过期时长,API_SECRET是jwt的密钥。
  1. TOKEN_HOUR_LIFESPAN=1
  2. API_SECRET="wP3-sN6&gG4-lV8>gJ9)"
复制代码
测试,这里改用python代码进行测试
  1. import requests
  2. import json
  3. headers = {
  4.     "Content-Type": "application/json",
  5. }
  6. resp = requests.get("http://127.0.0.1:8000/api/admin/user", headers=headers)
  7. def register(username: str, password: str):
  8.     req_body = {
  9.         "username": username,
  10.         "password": password,
  11.     }
  12.     resp = requests.post("http://127.0.0.1:8000/api/register", data=json.dumps(req_body), headers=headers)
  13.     print(resp.text)
  14. def login(username: str, password: str):
  15.     req_body = {
  16.         "username": username,
  17.         "password": password,
  18.     }
  19.     resp = requests.post("http://127.0.0.1:8000/api/login", data=json.dumps(req_body), headers=headers)
  20.     print(resp.text)
  21.     if resp.status_code == 200:
  22.         return resp.json()["token"]
  23.     else:
  24.         return ""
  25. if __name__ == "__main__":
  26.     username = "lisi"
  27.     password = "123456"
  28.     register(username, password)
  29.     token = login(username, password)
  30.         print(token)
复制代码
创建JWT认证中间件

创建中间件目次和代码文件
  1. mkdir middlewares
  2. touch middlewares/middlewares.go
复制代码
内容如下
  1. package middlewares
  2. import (
  3.         "gin-jwt/utils/token"
  4.         "net/http"
  5.         "github.com/gin-gonic/gin"
  6. )
  7. func JwtAuthMiddleware() gin.HandlerFunc {
  8.         return func(c *gin.Context) {
  9.                 err := token.TokenValid(c)
  10.                 if err != nil {
  11.                         c.String(http.StatusUnauthorized, err.Error())
  12.                         c.Abort()
  13.                         return
  14.                 }
  15.                 c.Next()
  16.         }
  17. }
复制代码
在main.go文件中注册路由的时候使用中间件
  1. func main() {
  2.         models.ConnectDatabase()
  3.         r := gin.Default()
  4.         public := r.Group("/api")
  5.         {
  6.                 public.POST("/register", controllers.Register)
  7.                 public.POST("/login", controllers.Login)
  8.         }
  9.         protected := r.Group("/api/admin")
  10.         {
  11.                 protected.Use(middlewares.JwtAuthMiddleware())
  12.                 protected.GET("/user", func(c *gin.Context) {
  13.                         c.JSON(http.StatusOK, gin.H{
  14.                                 "status":  "success",
  15.                                 "message": "authorized",
  16.                         })
  17.                 })
  18.         }
  19.         r.Run("0.0.0.0:8000")
  20. }
复制代码
在controllers/auth.go文件中实现CurrentUser
  1. func CurrentUser(c *gin.Context) {
  2.         // 从token中解析出user_id
  3.         user_id, err := token.ExtractTokenID(c)
  4.         if err != nil {
  5.                 c.JSON(http.StatusBadRequest, gin.H{
  6.                         "error": err.Error(),
  7.                 })
  8.                 return
  9.         }
  10.         // 根据user_id从数据库查询数据
  11.         u, err := models.GetUserByID(user_id)
  12.         if err != nil {
  13.                 c.JSON(http.StatusBadRequest, gin.H{
  14.                         "error": err.Error(),
  15.                 })
  16.                 return
  17.         }
  18.         c.JSON(http.StatusOK, gin.H{
  19.                 "message": "success",
  20.                 "data": u,
  21.         })
  22. }
复制代码
在models/user.go文件中实现GetUserByID
  1. // 返回前将用户密码置空
  2. func (u *User) PrepareGive() {
  3.         u.Password = ""
  4. }
  5. func GetUserByID(uid uint) (User, error) {
  6.         var u User
  7.         if err := DB.First(&u, uid).Error; err != nil {
  8.                 return u, errors.New("user not found")
  9.         }
  10.         u.PrepareGive()
  11.         return u, nil
  12. }
复制代码
至此,一个简单的gin-jwt应用就完成了。
客户端测试python脚本

服务端的三个接口这里用python脚本来测试
  1. import requests
  2. import json
  3. headers = {
  4.     # "Authorization": f"Bearer {token}",
  5.     "Content-Type": "application/json",
  6. }
  7. resp = requests.get("http://127.0.0.1:8000/api/admin/user", headers=headers)
  8. def register(username: str, password: str):
  9.     req_body = {
  10.         "username": username,
  11.         "password": password,
  12.     }
  13.     resp = requests.post("http://127.0.0.1:8000/api/register", data=json.dumps(req_body), headers=headers)
  14.     print(resp.text)
  15. def login(username: str, password: str):
  16.     req_body = {
  17.         "username": username,
  18.         "password": password,
  19.     }
  20.     resp = requests.post("http://127.0.0.1:8000/api/login", data=json.dumps(req_body), headers=headers)
  21.     print(resp.text)
  22.     if resp.status_code == 200:
  23.         return resp.json()["token"]
  24.     else:
  25.         return ""
  26. def test_protect_api(token: str):
  27.     global headers
  28.     headers["Authorization"] = f"Bearer {token}"
  29.     resp = requests.get("http://127.0.0.1:8000/api/admin/user", headers=headers)
  30.     print(resp.text)
  31. if __name__ == "__main__":
  32.     username = "lisi"
  33.     password = "123456"
  34.     register(username, password)
  35.     token = login(username, password)
  36.     test_protect_api(token)
复制代码
运行脚本效果
  1. {"message":"register success"}
  2. {"token":"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhdXRob3JpemVkIjp0cnVlLCJleHAiOjE3MTk5NDA0NjAsInVzZXJfaWQiOjZ9.qkzn0Ot9hAb54l3RFbGUohHJ9oezGia5x_oXppbD2jQ"}
  3. {"data":{"ID":6,"CreatedAt":"2024-07-03T00:14:20.187725+08:00","UpdatedAt":"2024-07-03T00:14:20.187725+08:00","DeletedAt":null,"username":"wangwu","password":""},"message":"success"}
复制代码
完整示例代码

目次结构
  1. ├── client.py  # 客户端测试脚本
  2. ├── controllers  # 控制器相关包
  3. │   └── auth.go  # 控制器方法实现
  4. ├── gin-jwt.bin  # 编译的二进制文件
  5. ├── go.mod  # go 项目文件
  6. ├── go.sum  # go 项目文件
  7. ├── main.go  # 程序入口文件
  8. ├── middlewares  # 中间件相关包
  9. │   └── middlewares.go  # 中间件代码文件
  10. ├── models  # 存储层相关包
  11. │   ├── setup.go  # 配置数据库连接
  12. │   └── user.go  # user模块相关数据交互的代码文件
  13. ├── README.md  # git repo的描述文件
  14. └── utils  # 工具类包
  15.     └── token  # token相关工具类包
  16.         └── token.go  # token工具的代码文件
复制代码
main.go
  1. package main
  2. import (
  3.         "log"
  4.         "github.com/gin-gonic/gin"
  5.         "gin-jwt/controllers"
  6.         "gin-jwt/middlewares"
  7.         "gin-jwt/models"
  8.         "github.com/joho/godotenv"
  9. )
  10. func init() {
  11.         err := godotenv.Load(".env")
  12.         if err != nil {
  13.                 log.Fatalf("Error loading .env file. %v\n", err)
  14.         }
  15. }
  16. func main() {
  17.         models.ConnectDatabase()
  18.         r := gin.Default()
  19.         public := r.Group("/api")
  20.         {
  21.                 public.POST("/register", controllers.Register)
  22.                 public.POST("/login", controllers.Login)
  23.         }
  24.         protected := r.Group("/api/admin")
  25.         {
  26.                 protected.Use(middlewares.JwtAuthMiddleware()) // 在路由组中使用中间件
  27.                 protected.GET("/user", controllers.CurrentUser)
  28.         }
  29.         r.Run("0.0.0.0:8000")
  30. }
复制代码
controllers


  • auth.go
  1. package controllersimport (        "net/http"        "gin-jwt/models"        "gin-jwt/utils/token"        "github.com/gin-gonic/gin")// /api/register的请求体type ReqRegister struct {        Username string `json:"username" binding:"required"`        Password string `json:"password" binding:"required"`}// api/login 的请求体
  2. type ReqLogin struct {
  3.         Username string `json:"username" binding:"required"`
  4.         Password string `json:"password" binding:"required"`
  5. }
  6. func Login(c *gin.Context) {
  7.         var req ReqLogin
  8.         if err := c.ShouldBindBodyWithJSON(&req); err != nil {
  9.                 c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
  10.                 return
  11.         }
  12.         u := models.User{
  13.                 Username: req.Username,
  14.                 Password: req.Password,
  15.         }
  16.         // 调用 models.LoginCheck 对用户名和密码进行验证
  17.         token, err := models.LoginCheck(u.Username, u.Password)
  18.         if err != nil {
  19.                 c.JSON(http.StatusBadRequest, gin.H{
  20.                         "error": "username or password is incorrect.",
  21.                 })
  22.                 return
  23.         }
  24.         c.JSON(http.StatusOK, gin.H{
  25.                 "token": token,
  26.         })
  27. }func Register(c *gin.Context) {        var req ReqRegister        if err := c.ShouldBindBodyWithJSON(&req); err != nil {                c.JSON(http.StatusBadRequest, gin.H{                        "data": err.Error(),                })                return        }        u := models.User{                Username: req.Username,                Password: req.Password,        }        _, err := u.SaveUser()        if err != nil {                c.JSON(http.StatusBadRequest, gin.H{                        "data": err.Error(),                })                return        }        c.JSON(http.StatusOK, gin.H{                "message": "register success",        })}func CurrentUser(c *gin.Context) {
  28.         // 从token中解析出user_id
  29.         user_id, err := token.ExtractTokenID(c)
  30.         if err != nil {
  31.                 c.JSON(http.StatusBadRequest, gin.H{
  32.                         "error": err.Error(),
  33.                 })
  34.                 return
  35.         }
  36.         // 根据user_id从数据库查询数据
  37.         u, err := models.GetUserByID(user_id)
  38.         if err != nil {
  39.                 c.JSON(http.StatusBadRequest, gin.H{
  40.                         "error": err.Error(),
  41.                 })
  42.                 return
  43.         }
  44.         c.JSON(http.StatusOK, gin.H{
  45.                 "message": "success",
  46.                 "data": u,
  47.         })
  48. }
复制代码
models


  • setup.go
  1. package models
  2. import (
  3.         "fmt"
  4.         "log"
  5.         "os"
  6.         "gorm.io/driver/postgres"
  7.         "gorm.io/gorm"
  8. )
  9. var DB *gorm.DB
  10. func ConnectDatabase() {
  11.         var err error
  12.         DbHost := os.Getenv("DB_HOST")
  13.         DbPort := os.Getenv("DB_PORT")
  14.         DbUser := os.Getenv("DB_USER")
  15.         DbPass := os.Getenv("DB_PASS")
  16.         DbName := os.Getenv("DB_NAME")
  17.         dsn := fmt.Sprintf("host=%s port=%s user=%s dbname=%s sslmode=disable TimeZone=Asia/Shanghai password=%s", DbHost, DbPort, DbUser, DbName, DbPass)
  18.         DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
  19.         if err != nil {
  20.                 log.Fatalf("Connect to database failed, %v\n", err)
  21.         } else {
  22.                 log.Printf("Connect to database success, host: %s, port: %s, user: %s, dbname: %s\n", DbHost, DbPort, DbUser, DbName)
  23.         }
  24.         // 迁移数据表
  25.         DB.AutoMigrate(&User{})
  26. }
复制代码

  • user.go
  1. package models
  2. import (
  3.         "errors"
  4.         "gin-jwt/utils/token"
  5.         "html"
  6.         "strings"
  7.         "golang.org/x/crypto/bcrypt"
  8.         "gorm.io/gorm"
  9. )
  10. type User struct {
  11.         gorm.Model
  12.         Username string `gorm:"size:255;not null;unique" json:"username"`
  13.         Password string `gorm:"size:255;not null;" json:"password"`
  14. }
  15. func (u *User) SaveUser() (*User, error) {
  16.         err := DB.Create(&u).Error
  17.         if err != nil {
  18.                 return &User{}, err
  19.         }
  20.         return u, nil
  21. }
  22. // 使用gorm的hook在保存密码前对密码进行hash
  23. func (u *User) BeforeSave(tx *gorm.DB) error {
  24.         hashedPassword, err := bcrypt.GenerateFromPassword([]byte(u.Password), bcrypt.DefaultCost)
  25.         if err != nil {
  26.                 return err
  27.         }
  28.         u.Password = string(hashedPassword)
  29.         u.Username = html.EscapeString(strings.TrimSpace(u.Username))
  30.         return nil
  31. }
  32. // 返回前将用户密码置空
  33. func (u *User) PrepareGive() {
  34.         u.Password = ""
  35. }
  36. // 对哈希加密的密码进行比对校验
  37. func VerifyPassword(password, hashedPassword string) error {
  38.         return bcrypt.CompareHashAndPassword([]byte(hashedPassword), []byte(password))
  39. }
  40. func LoginCheck(username, password string) (string, error) {
  41.         var err error
  42.         u := User{}
  43.         err = DB.Model(User{}).Where("username = ?", username).Take(&u).Error
  44.         if err != nil {
  45.                 return "", err
  46.         }
  47.         err = VerifyPassword(password, u.Password)
  48.         if err != nil && err == bcrypt.ErrMismatchedHashAndPassword {
  49.                 return "", err
  50.         }
  51.         token, err := token.GenerateToken(u.ID)
  52.         if err != nil {
  53.                 return "", err
  54.         }
  55.         return token, nil
  56. }
  57. func GetUserByID(uid uint) (User, error) {
  58.         var u User
  59.         if err := DB.First(&u, uid).Error; err != nil {
  60.                 return u, errors.New("user not found")
  61.         }
  62.         u.PrepareGive()
  63.         return u, nil
  64. }
复制代码
utils


  • token/token.go
  1. package token
  2. import (
  3.         "fmt"
  4.         "os"
  5.         "strconv"
  6.         "strings"
  7.         "time"
  8.         "github.com/gin-gonic/gin"
  9.         "github.com/golang-jwt/jwt/v5"
  10. )
  11. func GenerateToken(user_id uint) (string, error) {
  12.         token_lifespan, err := strconv.Atoi(os.Getenv("TOKEN_HOUR_LIFESPAN"))
  13.         if err != nil {
  14.                 return "", err
  15.         }
  16.         claims := jwt.MapClaims{}
  17.         claims["authorized"] = true
  18.         claims["user_id"] = user_id
  19.         claims["exp"] = time.Now().Add(time.Hour * time.Duration(token_lifespan)).Unix()
  20.         token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
  21.         return token.SignedString([]byte(os.Getenv("API_SECRET")))
  22. }
  23. func TokenValid(c *gin.Context) error {
  24.         tokenString := ExtractToken(c)
  25.         fmt.Println(tokenString)
  26.         _, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
  27.                 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  28.                         return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  29.                 }
  30.                 return []byte(os.Getenv("API_SECRET")), nil
  31.         })
  32.         if err != nil {
  33.                 return err
  34.         }
  35.         return nil
  36. }
  37. // 从请求头中获取token
  38. func ExtractToken(c *gin.Context) string {
  39.         bearerToken := c.GetHeader("Authorization")
  40.         if len(strings.Split(bearerToken, " ")) == 2 {
  41.                 return strings.Split(bearerToken, " ")[1]
  42.         }
  43.         return ""
  44. }
  45. // 从jwt中解析出user_id
  46. func ExtractTokenID(c *gin.Context) (uint, error) {
  47.         tokenString := ExtractToken(c)
  48.         token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
  49.                 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
  50.                         return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
  51.                 }
  52.                 return []byte(os.Getenv("API_SECRET")), nil
  53.         })
  54.         if err != nil {
  55.                 return 0, err
  56.         }
  57.         claims, ok := token.Claims.(jwt.MapClaims)
  58.         // 如果jwt有效,将user_id转换为浮点数字符串,然后再转换为 uint32
  59.         if ok && token.Valid {
  60.                 uid, err := strconv.ParseUint(fmt.Sprintf("%.0f", claims["user_id"]), 10, 32)
  61.                 if err != nil {
  62.                         return 0, err
  63.                 }
  64.                 return uint(uid), nil
  65.         }
  66.         return 0, nil
  67. }
复制代码
middlewares


  • middlewares.go
  1. package middlewares
  2. import (
  3.         "gin-jwt/utils/token"
  4.         "net/http"
  5.         "github.com/gin-gonic/gin"
  6. )
  7. func JwtAuthMiddleware() gin.HandlerFunc {
  8.         return func(c *gin.Context) {
  9.                 err := token.TokenValid(c)
  10.                 if err != nil {
  11.                         c.String(http.StatusUnauthorized, err.Error())
  12.                         c.Abort()
  13.                         return
  14.                 }
  15.                 c.Next()
  16.         }
  17. }
复制代码
参考


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

卖不甜枣

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表