- from fastapi import FastAPI, HTTPException, Request
- from pydantic import BaseModel
- import ollama
- import mysql.connector
- from mysql.connector.cursor import MySQLCursor
- import json
- app = FastAPI()
- # 数据库连接配置
- DB_CONFIG = {
- "database": "web", # 您的数据库名,用于存储业务数据
- "user": "root", # 数据库用户名,需要有读写权限
- "password": "XXXXXX", # 数据库密码,建议使用强密码
- "host": "127.0.0.1", # 数据库主机地址,本地开发环境使用localhost
- "port": "3306" # MySQL 默认端口,可根据实际配置修改
- }
- # 数据库连接函数
- def get_db_connection():
- try:
- conn = mysql.connector.connect(**DB_CONFIG)
- return conn
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"数据库连接失败: {str(e)}")
- class SQLRequest(BaseModel):
- question: str
- def get_table_relationships():
- """动态获取表之间的关联关系"""
- conn = get_db_connection()
- cur = conn.cursor()
- try:
- # 获取当前数据库名
- cur.execute("SELECT DATABASE()")
- db_name = cur.fetchone()[0]
-
- # 获取外键关系
- cur.execute("""
- SELECT
- TABLE_NAME,
- COLUMN_NAME,
- REFERENCED_TABLE_NAME,
- REFERENCED_COLUMN_NAME
- FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
- WHERE TABLE_SCHEMA = %s
- AND REFERENCED_TABLE_NAME IS NOT NULL
- ORDER BY TABLE_NAME, COLUMN_NAME
- """, (db_name,))
-
- relationships = []
- for row in rows:
- table_name, column_name, ref_table, ref_column = row
- relationships.append(
- f"-- {table_name}.{column_name} can be joined with {ref_table}.{ref_column}"
- )
-
- return "\n".join(relationships) if relationships else "-- No foreign key relationships found"
-
- finally:
- cur.close()
- conn.close()
- def get_database_schema():
- """获取MySQL数据库表结构,以CREATE TABLE格式返回"""
- conn = get_db_connection()
- cur = conn.cursor()
- try:
- # 获取当前数据库名
- cur.execute("SELECT DATABASE()")
- db_name = cur.fetchone()[0]
-
- # 获取所有表的结构信息
- cur.execute("""
- SELECT
- t.TABLE_NAME,
- c.COLUMN_NAME,
- c.COLUMN_TYPE,
- c.IS_NULLABLE,
- c.COLUMN_KEY,
- c.COLUMN_COMMENT
- FROM INFORMATION_SCHEMA.TABLES t
- JOIN INFORMATION_SCHEMA.COLUMNS c
- ON t.TABLE_NAME = c.TABLE_NAME
- WHERE t.TABLE_SCHEMA = %s
- AND t.TABLE_TYPE = 'BASE TABLE'
- ORDER BY t.TABLE_NAME, c.ORDINAL_POSITION
- """, (db_name,))
-
- rows = cur.fetchall()
-
- schema = []
- current_table = None
- table_columns = []
-
- for row in rows:
- table_name, column_name, column_type, nullable, key, comment = row
-
- if current_table != table_name:
- if current_table is not None:
- schema.append(f"CREATE TABLE {current_table} (\n" +
- ",\n".join(table_columns) +
- "\n);\n")
- current_table = table_name
- table_columns = []
-
- # 构建列定义
- column_def = f" {column_name} {column_type.upper()}"
- if key == "PRI":
- column_def += " PRIMARY KEY"
- elif nullable == "NO":
- column_def += " NOT NULL"
-
- if comment:
- column_def += f" -- {comment}"
-
- table_columns.append(column_def)
-
- # 添加最后一个表
- if current_table is not None:
- schema.append(f"CREATE TABLE {current_table} (\n" +
- ",\n".join(table_columns) +
- "\n);\n")
-
- return "\n".join(schema)
- finally:
- cur.close()
- conn.close()
- def get_chinese_table_mapping():
- """动态生成表名的中文映射"""
- conn = get_db_connection()
- cur = conn.cursor()
- try:
- # 获取所有表的注释信息
- cur.execute("""
- SELECT
- t.TABLE_NAME,
- t.TABLE_COMMENT
- FROM information_schema.TABLES t
- WHERE t.TABLE_SCHEMA = DATABASE()
- ORDER BY t.TABLE_NAME
- """)
-
- mappings = []
- for table_name, table_comment in cur.fetchall():
- # 生成表的中文名称
- chinese_name = table_name
- if table_name.startswith('web_'):
- chinese_name = table_name.replace('web_', '').replace('_', '')
- if table_comment:
- chinese_name = table_comment.split('--')[0].strip()
- # 如果中文名称以"表"结尾,则去掉"表"字
- if chinese_name.endswith('表'):
- chinese_name = chinese_name[:-1]
-
- mappings.append(f' - "{chinese_name}" -> {table_name} table')
-
- return "\n".join(mappings)
- finally:
- cur.close()
- conn.close()
- @app.post("/query")
- async def query_database(request: Request):
- try:
- # 获取请求体数据并确保正确处理中文
- body = await request.body()
- try:
- request_data = json.loads(body.decode('utf-8'))
- except UnicodeDecodeError:
- request_data = json.loads(body.decode('gbk'))
-
- question = request_data.get('question')
- print(f"收到问题: {question}") # 调试日志
-
- if not question:
- raise HTTPException(status_code=400, detail="缺少 question 参数")
-
- # 获取数据库结构
- db_schema = get_database_schema()
- #print(f"数据库结构: {db_schema}") # 调试日志
-
- # 获取中文映射并打印
- chinese_mapping = get_chinese_table_mapping()
- #print(f"表映射关系:\n{chinese_mapping}") # 添加这行来打印映射
-
- # 修改 prompt 使用更严格的指导
- prompt = f"""
- ### Instructions:
- Convert Chinese question to MySQL query. Follow these rules strictly:
- 1. ONLY return a valid SELECT SQL query
- 2. Use EXACT table names from the mapping below
- 3. DO NOT use any table that's not in the mapping
- 4. For Chinese terms, use these exact mappings:
- {chinese_mapping}
- ### Examples:
- Question: 所有装修记录
- SQL: SELECT * FROM web_decoration ORDER BY id;
- Question: 查询装修
- SQL: SELECT * FROM web_decoration ORDER BY id;
- ### Database Schema:
- {db_schema}
- ### Question:
- {question}
- ### SQL Query:
- """
-
- # 获取 SQL 查询
- response = ollama.chat(model='sqlcoder:latest',
- messages=[{'role': 'user', 'content': prompt}])
- sql_query = response['message']['content'].strip()
-
- print(f"生成的SQL: {sql_query}") # 调试日志
-
- # 验证 SQL 查询
- if "装修" in question and "web_decoration" not in sql_query:
- sql_query = "SELECT * FROM web_decoration ORDER BY id"
-
- if not sql_query.upper().startswith('SELECT'):
- raise HTTPException(status_code=400, detail="无效的SQL查询格式")
-
- conn = get_db_connection()
- cur = conn.cursor(dictionary=True)
-
- try:
- cur.execute(sql_query)
- results = cur.fetchall()
- return {
- "sql_query": sql_query,
- "results": results
- }
- except mysql.connector.Error as e:
- raise HTTPException(status_code=400, detail=f"SQL 执行错误: {str(e)}")
- finally:
- cur.close()
- conn.close()
-
- except json.JSONDecodeError as e:
- raise HTTPException(status_code=400, detail=f"JSON 解析错误: {str(e)}")
- except Exception as e:
- raise HTTPException(status_code=500, detail=f"处理请求时生错误: {str(e)}")
- if __name__ == "__main__":
- import uvicorn
- port = 8666
- print(f"Starting server on port {port}")
- uvicorn.run(
- "main:app",
- host="0.0.0.0",
- port=port,
- reload=True
- )
复制代码 我来为您解释这段代码的紧张功能和布局:
紧张功能
这是一个使用 FastAPI 框架开发的 API 服务,紧张用于将中文问题转换为 SQL 查询语句并执行。
核心组件说明
- DB_CONFIG = {
- "database": "web",
- "user": "root",
- "password": "XXXXXX",
- "host": "127.0.0.1",
- "port": "3306"
- }
复制代码 提供了 MySQL 数据库的基本毗连设置。
- get_table_relationships(): 获取数据库表之间的外键关系
- get_database_schema(): 获取数据库表布局
- get_chinese_table_mapping(): 天生表名的中文映射关系
这个端点接收中文问题,紧张处理流程:
- 接收并解析用户的中文问题
- 获取数据库布局和表映射
- 使用 ollama 模子将中文转换为 SQL 查询
- 执行 SQL 查询并返回效果
- 智能转换功能
使用 ollama 的 sqlcoder 模子将中文问题转换为 SQL 查询,包含:
特点
- 支持中文输入处理
- 自动获取数据库布局
- 动态天生中文表名映射
- 美满的错误处理机制
- 支持热重载的开发模式
使用示例
可以通过 POST 请求访问 /query 端点:
- {
- "question": "查询所有装修记录"
- }
复制代码 服务会返回:
- {
- "sql_query": "SELECT * FROM web_decoration ORDER BY id",
- "results": [...]
- }
复制代码 安全特性
- 数据库毗连错误处理
- SQL 注入防护
- 请求体编码自适应(支持 UTF-8 和 GBK)
- 查询效果的安全封装
查看效果:
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |