通过 PromptTemplate 天生干净的 SQL 查询语句并实行SQL查询语句

[复制链接]
发表于 昨天 06:56 | 显示全部楼层 |阅读模式
题目形貌

在利用 LangChain 和 Llama 模子天生 SQL 查询时,碰到了 sqlite3.OperationalError 错误。错误信息如下:
  1. OperationalError: (sqlite3.OperationalError) near "```sql
  2. SELECT Name
  3. FROM MediaType
  4. LIMIT 5;
  5. ```": syntax error
  6. [SQL: ```sql
  7. SELECT Name
  8. FROM MediaType
  9. LIMIT 5;
  10. ```]
复制代码
错误发生的缘故起因是天生的 SQL 查询包罗了不须要的 Markdown 代码块标志 ```,也就是在天生SQL语句的过程中,产生了其他的不干净文本,导致 SQL 语法错误。
终极办理方案

通过修改 PromptTemplate 来天生干净的 SQL 查询,确保天生的查询不包罗任何 Markdown 代码块标志或附加批评。以下是办理方案的具体步调和代码实现:
1. 初始化环境

起首,初始化所需的环境变量和模子:
  1. import getpass
  2. import os
  3. from langchain.chat_models import init_chat_model
  4. from langchain_core.prompts import PromptTemplate
  5. from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
  6. # 如果没有设置 GROQ_API_KEY,则提示用户输入
  7. if not os.environ.get("GROQ_API_KEY"):
  8.     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")
  9. # 初始化 Llama 模型,使用 Groq 后端
  10. llm = init_chat_model("llama-3.3-70b-versatile", model_provider="groq", temperature=0)
复制代码
2. 界说自界说提示模板

界说一个自界说的 PromptTemplate,用于天生干净的 SQL 查询:
  1. custom_prompt = PromptTemplate(
  2.     input_variables=["dialect", "input", "table_info", "top_k"],
  3.     template="""You are a SQL expert using {dialect}.
  4. Given the following table schema:
  5. {table_info}
  6. Generate a syntactically correct SQL query to answer the question: "{input}".
  7. Limit the results to at most {top_k} rows.
  8. Return only the SQL query without any additional commentary or Markdown formatting.
  9. """
  10. )
复制代码
3. 创建 SQL 查询链

创建一个 SQL 查询链,并利用自界说提示模板:
  1. write_query = create_sql_query_chain(llm, db, prompt=custom_prompt)
复制代码
4. 构造输入数据字典

构造输入数据字典,此中包罗方言、表结构、题目和行数限定:
  1. input_data = {
  2.     "dialect": db.dialect,                    # 数据库方言,如 "sqlite"
  3.     "table_info": db.get_table_info(),        # 表结构信息
  4.     "input": "What name of MediaType is?",    # 问题
  5.     "top_k": 5                                # 行数限制
  6. }
复制代码
5. 调用链天生并实行 SQL 查询

调用链天生 SQL 查询,确保天生的查询不包罗 Markdown 代码块标志,然后实行查询并打印结果:
  1. response = write_query.invoke(input_data)
  2. query = response["query"]
  3. # 执行 SQL 查询并打印结果
  4. execute_query = QuerySQLDataBaseTool(db=db)
  5. result = execute_query.invoke({"query": query})
  6. print(result)
复制代码
总结

通过修改 PromptTemplate 来天生 SQL 查询时,明确要求返回的 SQL 查询不包罗任何附加批评或 Markdown 格式,确保天生的 SQL 查询是干净的、可实行的。如答应以制止由多余的标志导致的 SQL 语法错误。
末了提供完备代码:
  1. import getpassimport osfrom langchain.chat_models import init_chat_modelfrom langchain_core.prompts import PromptTemplatefrom langchain_community.tools.sql_database.tool import QuerySQLDataBaseToolfrom dotenv import load_dotenvfrom pyprojroot import herefrom langchain.chains import create_sql_query_chainfrom langchain_community.agent_toolkits import create_sql_agentfrom langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkitfrom langchain_community.utilities import SQLDatabaseload_dotenv()# 如果没有设置 GROQ_API_KEY,则提示用户输入if not os.environ.get("GROQ_API_KEY"):    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter API key for Groq: ")    sqldb_directory = here("data/Chinook.db")db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")table_info = db.get_table_info(["Album"])  # 留意必要通报列表print(f"\n Original table info: {table_info}")   #  初始化 Llama 模子,利用 Groq 后端llm = init_chat_model("llama-3.3-70b-specdec", model_provider="groq", temperature=0)# 界说自界说提示模板,用于天生 SQL 查询custom_prompt = PromptTemplate(
  2.     input_variables=["dialect", "input", "table_info", "top_k"],
  3.     template="""You are a SQL expert using {dialect}.
  4. Given the following table schema:
  5. {table_info}
  6. Generate a syntactically correct SQL query to answer the question: "{input}".
  7. Limit the results to at most {top_k} rows.
  8. Return only the SQL query without any additional commentary or Markdown formatting.
  9. """
  10. )
  11. write_query  = create_sql_query_chain(llm, db,prompt=custom_prompt)# 构造输入数据字典,此中包罗方言、表结构、题目和行数限定input_data = {    "dialect": db.dialect,                    # 数据库方言,如 "sqlite"    "table_info": db.get_table_info(),          # 表结构信息    "question": "What name of MediaType is?",    "top_k": 5}# 调用链天生 SQL 查询,返回结果为一个字典,包罗键 "query"write_query_response = write_query.invoke(input_data)print('\n write_query result:',write_query_response)#实行SQL语句execute_query = QuerySQLDataBaseTool(db=db)execute_response = execute_query.invoke(write_query_response)print('\n execute_response result:',execute_response)#两个动作合起来搞成链chain = write_query | execute_queryresult_chain = chain.invoke(input_data)print('\n result_chain==',result_chain)
复制代码
输出:


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
继续阅读请点击广告

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
回复

使用道具 举报

×
登录参与点评抽奖,加入IT实名职场社区
去登录
快速回复 返回顶部 返回列表