深度学习系列75:sql大模型工具vanna

[复制链接]
发表于 2025-1-24 01:05:50 | 显示全部楼层 |阅读模式
1. 概述

vanna是一个可以将天然语言转为sql的工具。简朴的demo如下:
  1. !pip install vanna
  2. import vanna
  3. from vanna.remote import VannaDefault
  4. vn = VannaDefault(model='chinook', api_key=vanna.get_api_key('my-email@example.com'))
  5. vn.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')
  6. vn.ask("What are the top 10 albums by sales?")
复制代码
执行下面的代码运行图形界面
  1. from vanna.flask import VannaFlaskApp
  2. VannaFlaskApp(vn).run()
复制代码
2. 配置

数据库可以是任何数据库,好比mysql如下:
  1. import pandas as pd
  2. import psycopg2
  3. def run_sql(sql):
  4.     conn = psycopg2.connect(
  5.         host="localhost",
  6.         database="my_database",
  7.         user="my_user",
  8.         password="my_password"
  9.     )
  10.     return pd.read_sql(sql, conn)
  11. vn.run_sql = run_sql
  12. vn.run_sql_is_set = True
复制代码
向量数据库稍微贫苦一些,目前支持的包括:
参考代码如下:
  1. from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
  2. class MyVanna(ChromaDB_VectorStore):
  3.     def __init__(self, config=None):
  4.         ChromaDB_VectorStore.__init__(self, config=config)
  5. vn = MyVanna(config={'path': '/path/to/chromadb'})
复制代码
3. 训练

训练数据可以是:DDL、documentation、sql以及Question-SQL Pairs
  1. vn.train(ddl="CREATE TABLE my_table (id INT, name TEXT)")
  2. vn.train(documentation="Our business defines XYZ as ABC")
  3. vn.train(sql="SELECT col1, col2, col3 FROM my_table")
复制代码
可以设置auto_train = True
4. 询问

  1. vn.ask("What are the top 10 customers by sales?")
复制代码
它包罗下列几个函数:
  1. vn.generate_sql
  2. vn.run_sql
  3. vn.generate_plotly_code
  4. vn.get_plotly_figure
复制代码
visualize=False
5. 启用服务

参考https://github.com/vanna-ai/vanna-flask,将LLM、embedding、vectorStore都改造成自己的代码。
首先是LLM,改造框架为:
  1. from vanna.base import VannaBase
  2. class MyLLM(VannaBase):
  3.     def __init__(self,config=None):
  4.         VannaBase.__init__(self, config=config)
  5.         ...
  6.    def system_message(self, message: str) -> any:
  7.         return {"role": "system", "content": message}
  8.     def user_message(self, message: str) -> any:
  9.         return {"role": "user", "content": message}
  10.     def assistant_message(self, message: str) -> any:
  11.         return {"role": "assistant", "content": message}
  12.     def submit_prompt(self, prompt, **kwargs) -> str:
  13.             ...
复制代码
然后是embedding,必要定义encode_documents和encode_queries两个函数,比方:
  1. class BgeM3:
  2.     def __init__(self, url):
  3.         self.url = url
  4.     def encode_documents(self, docs):
  5.         ....
  6.     def encode_queries(self, queries):
  7.         ....
复制代码
接下来是vectorStore,我们使用milvus,它会主动调用config中的embedding_function,我们把它定义成上面的BegM3即可:
  1. class MyVanna(Milvus_VectorStore, QwenLLM):
  2.     def __init__(self, config=None):
  3.         Milvus_VectorStore.__init__(self, config=config)
  4.         QwenLLM.__init__(self, config=config)
  5. vn = MyVanna(config={'milvus_client': MilvusClient(...),'embedding_function':BgeM3(...)})
复制代码
然后定义毗连的数据库,可以换成恣意的其他数据库:
  1. def run_sql(sql: str) -> pd.DataFrame:
  2.     cnx = mysql.connector.connect(...)
  3.     cursor = cnx.cursor()
  4.     cursor.execute(sql)
  5.     result = cursor.fetchall()
  6.     columns = cursor.column_names
  7.     df = pd.DataFrame(result, columns=columns)
  8.     return df
  9.    
  10. vn.run_sql = run_sql
  11. vn.run_sql_is_set = True
复制代码
接着执行python app.py即可启用服务,访问localhost:5000可以打开页面:

同时也可以调用接口:
  1. import requests
  2. response = requests.get(url+'/api/v0/get_training_data',headers={'Content-Type':'application/json'})
  3. response.json()
复制代码
所有可用的接口清单可以参考app.py。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
回复

使用道具 举报

登录后关闭弹窗

登录参与点评抽奖  加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表