fastapi 调用ollama之下的sqlcoder模式进行对话操作数据库

打印 上一主题 下一主题

主题 991|帖子 991|积分 2973

  1. from fastapi import FastAPI, HTTPException, Request
  2. from pydantic import BaseModel
  3. import ollama
  4. import mysql.connector
  5. from mysql.connector.cursor import MySQLCursor
  6. import json
  7. app = FastAPI()
  8. # 数据库连接配置
  9. DB_CONFIG = {
  10.     "database": "web",        # 您的数据库名,用于存储业务数据
  11.     "user": "root",          # 数据库用户名,需要有读写权限
  12.     "password": "XXXXXX",    # 数据库密码,建议使用强密码
  13.     "host": "127.0.0.1",    # 数据库主机地址,本地开发环境使用localhost
  14.     "port": "3306"          # MySQL 默认端口,可根据实际配置修改
  15. }
  16. # 数据库连接函数
  17. def get_db_connection():
  18.     try:
  19.         conn = mysql.connector.connect(**DB_CONFIG)
  20.         return conn
  21.     except Exception as e:
  22.         raise HTTPException(status_code=500, detail=f"数据库连接失败: {str(e)}")
  23. class SQLRequest(BaseModel):
  24.     question: str
  25. def get_table_relationships():
  26.     """动态获取表之间的关联关系"""
  27.     conn = get_db_connection()
  28.     cur = conn.cursor()
  29.     try:
  30.         # 获取当前数据库名
  31.         cur.execute("SELECT DATABASE()")
  32.         db_name = cur.fetchone()[0]
  33.         
  34.         # 获取外键关系
  35.         cur.execute("""
  36.             SELECT
  37.                 TABLE_NAME,
  38.                 COLUMN_NAME,
  39.                 REFERENCED_TABLE_NAME,
  40.                 REFERENCED_COLUMN_NAME
  41.             FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
  42.             WHERE TABLE_SCHEMA = %s
  43.                 AND REFERENCED_TABLE_NAME IS NOT NULL
  44.             ORDER BY TABLE_NAME, COLUMN_NAME
  45.         """, (db_name,))
  46.         
  47.         relationships = []
  48.         for row in rows:
  49.             table_name, column_name, ref_table, ref_column = row
  50.             relationships.append(
  51.                 f"-- {table_name}.{column_name} can be joined with {ref_table}.{ref_column}"
  52.             )
  53.         
  54.         return "\n".join(relationships) if relationships else "-- No foreign key relationships found"
  55.         
  56.     finally:
  57.         cur.close()
  58.         conn.close()
  59. def get_database_schema():
  60.     """获取MySQL数据库表结构,以CREATE TABLE格式返回"""
  61.     conn = get_db_connection()
  62.     cur = conn.cursor()
  63.     try:
  64.         # 获取当前数据库名
  65.         cur.execute("SELECT DATABASE()")
  66.         db_name = cur.fetchone()[0]
  67.         
  68.         # 获取所有表的结构信息
  69.         cur.execute("""
  70.             SELECT
  71.                 t.TABLE_NAME,
  72.                 c.COLUMN_NAME,
  73.                 c.COLUMN_TYPE,
  74.                 c.IS_NULLABLE,
  75.                 c.COLUMN_KEY,
  76.                 c.COLUMN_COMMENT
  77.             FROM INFORMATION_SCHEMA.TABLES t
  78.             JOIN INFORMATION_SCHEMA.COLUMNS c
  79.                 ON t.TABLE_NAME = c.TABLE_NAME
  80.             WHERE t.TABLE_SCHEMA = %s
  81.                 AND t.TABLE_TYPE = 'BASE TABLE'
  82.             ORDER BY t.TABLE_NAME, c.ORDINAL_POSITION
  83.         """, (db_name,))
  84.         
  85.         rows = cur.fetchall()
  86.         
  87.         schema = []
  88.         current_table = None
  89.         table_columns = []
  90.         
  91.         for row in rows:
  92.             table_name, column_name, column_type, nullable, key, comment = row
  93.             
  94.             if current_table != table_name:
  95.                 if current_table is not None:
  96.                     schema.append(f"CREATE TABLE {current_table} (\n" +
  97.                                 ",\n".join(table_columns) +
  98.                                 "\n);\n")
  99.                 current_table = table_name
  100.                 table_columns = []
  101.             
  102.             # 构建列定义
  103.             column_def = f"  {column_name} {column_type.upper()}"
  104.             if key == "PRI":
  105.                 column_def += " PRIMARY KEY"
  106.             elif nullable == "NO":
  107.                 column_def += " NOT NULL"
  108.                
  109.             if comment:
  110.                 column_def += f" -- {comment}"
  111.                
  112.             table_columns.append(column_def)
  113.         
  114.         # 添加最后一个表
  115.         if current_table is not None:
  116.             schema.append(f"CREATE TABLE {current_table} (\n" +
  117.                         ",\n".join(table_columns) +
  118.                         "\n);\n")
  119.             
  120.         return "\n".join(schema)
  121.     finally:
  122.         cur.close()
  123.         conn.close()
  124. def get_chinese_table_mapping():
  125.     """动态生成表名的中文映射"""
  126.     conn = get_db_connection()
  127.     cur = conn.cursor()
  128.     try:
  129.         # 获取所有表的注释信息
  130.         cur.execute("""
  131.             SELECT
  132.                 t.TABLE_NAME,
  133.                 t.TABLE_COMMENT
  134.             FROM information_schema.TABLES t
  135.             WHERE t.TABLE_SCHEMA = DATABASE()
  136.             ORDER BY t.TABLE_NAME
  137.         """)
  138.         
  139.         mappings = []
  140.         for table_name, table_comment in cur.fetchall():
  141.             # 生成表的中文名称
  142.             chinese_name = table_name
  143.             if table_name.startswith('web_'):
  144.                 chinese_name = table_name.replace('web_', '').replace('_', '')
  145.             if table_comment:
  146.                 chinese_name = table_comment.split('--')[0].strip()
  147.                 # 如果中文名称以"表"结尾,则去掉"表"字
  148.                 if chinese_name.endswith('表'):
  149.                     chinese_name = chinese_name[:-1]
  150.             
  151.             mappings.append(f'           - "{chinese_name}" -> {table_name} table')
  152.         
  153.         return "\n".join(mappings)
  154.     finally:
  155.         cur.close()
  156.         conn.close()
  157. @app.post("/query")
  158. async def query_database(request: Request):
  159.     try:
  160.         # 获取请求体数据并确保正确处理中文
  161.         body = await request.body()
  162.         try:
  163.             request_data = json.loads(body.decode('utf-8'))
  164.         except UnicodeDecodeError:
  165.             request_data = json.loads(body.decode('gbk'))
  166.         
  167.         question = request_data.get('question')
  168.         print(f"收到问题: {question}")  # 调试日志
  169.         
  170.         if not question:
  171.             raise HTTPException(status_code=400, detail="缺少 question 参数")
  172.             
  173.         # 获取数据库结构
  174.         db_schema = get_database_schema()
  175.         #print(f"数据库结构: {db_schema}")  # 调试日志
  176.         
  177.         # 获取中文映射并打印
  178.         chinese_mapping = get_chinese_table_mapping()
  179.         #print(f"表映射关系:\n{chinese_mapping}")  # 添加这行来打印映射
  180.         
  181.         # 修改 prompt 使用更严格的指导
  182.         prompt = f"""
  183.         ### Instructions:
  184.         Convert Chinese question to MySQL query. Follow these rules strictly:
  185.         1. ONLY return a valid SELECT SQL query
  186.         2. Use EXACT table names from the mapping below
  187.         3. DO NOT use any table that's not in the mapping
  188.         4. For Chinese terms, use these exact mappings:
  189. {chinese_mapping}
  190.         ### Examples:
  191.         Question: 所有装修记录
  192.         SQL: SELECT * FROM web_decoration ORDER BY id;
  193.         Question: 查询装修
  194.         SQL: SELECT * FROM web_decoration ORDER BY id;
  195.         ### Database Schema:
  196.         {db_schema}
  197.         ### Question:
  198.         {question}
  199.         ### SQL Query:
  200.         """
  201.         
  202.         # 获取 SQL 查询
  203.         response = ollama.chat(model='sqlcoder:latest',
  204.                              messages=[{'role': 'user', 'content': prompt}])
  205.         sql_query = response['message']['content'].strip()
  206.         
  207.         print(f"生成的SQL: {sql_query}")  # 调试日志
  208.         
  209.         # 验证 SQL 查询
  210.         if "装修" in question and "web_decoration" not in sql_query:
  211.             sql_query = "SELECT * FROM web_decoration ORDER BY id"
  212.             
  213.         if not sql_query.upper().startswith('SELECT'):
  214.             raise HTTPException(status_code=400, detail="无效的SQL查询格式")
  215.             
  216.         conn = get_db_connection()
  217.         cur = conn.cursor(dictionary=True)
  218.         
  219.         try:
  220.             cur.execute(sql_query)
  221.             results = cur.fetchall()
  222.             return {
  223.                 "sql_query": sql_query,
  224.                 "results": results
  225.             }
  226.         except mysql.connector.Error as e:
  227.             raise HTTPException(status_code=400, detail=f"SQL 执行错误: {str(e)}")
  228.         finally:
  229.             cur.close()
  230.             conn.close()
  231.             
  232.     except json.JSONDecodeError as e:
  233.         raise HTTPException(status_code=400, detail=f"JSON 解析错误: {str(e)}")
  234.     except Exception as e:
  235.         raise HTTPException(status_code=500, detail=f"处理请求时生错误: {str(e)}")
  236. if __name__ == "__main__":
  237.     import uvicorn
  238.     port = 8666
  239.     print(f"Starting server on port {port}")
  240.     uvicorn.run(
  241.         "main:app",
  242.         host="0.0.0.0",
  243.         port=port,
  244.         reload=True
  245.     )
复制代码
我来为您解释这段代码的紧张功能和布局:
紧张功能

这是一个使用 FastAPI 框架开发的 API 服务,紧张用于将中文问题转换为 SQL 查询语句并执行。
核心组件说明


  • 数据库设置与毗连
  1. DB_CONFIG = {
  2.     "database": "web",
  3.     "user": "root",
  4.     "password": "XXXXXX",
  5.     "host": "127.0.0.1",
  6.     "port": "3306"
  7. }
复制代码
提供了 MySQL 数据库的基本毗连设置。

  • 紧张工具函数


  • get_table_relationships(): 获取数据库表之间的外键关系
  • get_database_schema(): 获取数据库表布局
  • get_chinese_table_mapping(): 天生表名的中文映射关系

  • 核心 API 端点
  1. @app.post("/query")
复制代码
这个端点接收中文问题,紧张处理流程:


  • 接收并解析用户的中文问题
  • 获取数据库布局和表映射
  • 使用 ollama 模子将中文转换为 SQL 查询
  • 执行 SQL 查询并返回效果

  • 智能转换功能
    使用 ollama 的 sqlcoder 模子将中文问题转换为 SQL 查询,包含:


  • 严酷的表名映射
  • SQL 查询验证
  • 错误处理机制
特点


  • 支持中文输入处理
  • 自动获取数据库布局
  • 动态天生中文表名映射
  • 美满的错误处理机制
  • 支持热重载的开发模式
使用示例

可以通过 POST 请求访问 /query 端点:
  1. {
  2.     "question": "查询所有装修记录"
  3. }
复制代码
服务会返回:
  1. {
  2.     "sql_query": "SELECT * FROM web_decoration ORDER BY id",
  3.     "results": [...]
  4. }
复制代码
安全特性


  • 数据库毗连错误处理
  • SQL 注入防护
  • 请求体编码自适应(支持 UTF-8 和 GBK)
  • 查询效果的安全封装
查看效果:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

tsx81429

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表