Multi-iteration summarization of long texts with LangChain

Previously we covered RAG question answering with LangChain; if you're interested, you can review it here:
LangChain RAG Q&A over real-time content, based on the Hunyuan LLM
Today we'll look at how to build on that setup to summarize long texts.
Why do we need text summarization?

Meeting transcripts are usually long; being able to extract the key information saves a great deal of time.

Can't the model simply summarize it? Why single out "long text" as its own problem?

Most models limit the input length; if the meeting transcript exceeds that limit, it cannot be summarized in one pass.
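As a rough illustration of why chunking becomes necessary, you can estimate up front how many chunks a transcript will need. This is a character-based sketch only; real limits are counted in tokens and depend on the model's tokenizer:

```python
import math

def chunks_needed(text_len: int, context_limit: int, overlap: int = 200) -> int:
    """Rough estimate of how many overlapping chunks a text of
    `text_len` characters needs under a `context_limit`-character window."""
    if text_len <= context_limit:
        return 1
    # Each chunk after the first contributes only (limit - overlap) new characters
    step = context_limit - overlap
    return 1 + math.ceil((text_len - context_limit) / step)

# A 100k-character transcript with a 4k window and 200-character overlap:
print(chunks_needed(100_000, 4_000))  # → 27
```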
Approaches

LangChain offers several options to choose from: https://python.langchain.com/v0.1/docs/use_cases/summarization/

  • stuff: feed the entire text into the model in one call; this may still exceed the model's input limit
  • map_reduce: split the text into n chunks, summarize each chunk independently, then summarize the combined summaries; this solves the length problem, but may break cross-chunk context and degrade the final result
  • refine: like map_reduce, the text is split into n chunks, but they are processed in a loop: summarize the first chunk, then refine that summary together with the second chunk, and so on; this preserves the original semantics better
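The refine strategy above can be sketched in a few lines of plain Python; `summarize` here is a stand-in for a real LLM call:

```python
from typing import Callable, List

def refine_summarize(chunks: List[str], summarize: Callable[[str], str]) -> str:
    """Fold a running summary through every chunk, refine-style."""
    # First pass: summarize the first chunk on its own.
    summary = summarize(chunks[0])
    # Refine passes: feed the previous summary together with the next chunk.
    for chunk in chunks[1:]:
        summary = summarize(f"Existing summary:\n{summary}\n\nNew context:\n{chunk}")
    return summary

# Toy "model" that just truncates its input, to show the data flow:
result = refine_summarize(["part one ...", "part two ..."], lambda t: t[:60])
```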
Difficulties


  • Implementing the code
  • Streaming the response
  • Detecting the final round (in a streaming setup, every round emits its own summary, so how do we tell which round is the last one and return only that to the frontend?)
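The third difficulty boils down to tagging each round's output so the consumer knows which one is final. Stripped of any framework, the idea is simply:

```python
from typing import Iterator, List, Tuple

def tag_last(rounds: List[str]) -> Iterator[Tuple[bool, str]]:
    """Yield (is_last, summary) so a streaming consumer can tell
    which round's summary is the final one to forward."""
    total = len(rounds)
    for i, summary in enumerate(rounds):
        yield (i == total - 1, summary)

# Only the final round's summary goes to the frontend:
final = [s for is_last, s in tag_last(["draft 1", "draft 2", "final"]) if is_last]
print(final)  # → ['final']
```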
Implementation

Parts of LangChain are implemented rather monolithically, which makes them awkward to extend, so the source code has been modified in a few places.
1. Create a document loader chain for fetching the transcript text
```python
from typing import Dict, Optional

from langchain.chains.combine_documents.base import AnalyzeDocumentChain
from langchain_community.document_loaders import WebBaseLoader
from langchain_core.callbacks import CallbackManagerForChainRun

from modules.BasinessException import BusinessException
from modules.resultCodeEnum import ResultCodeEnum
from service.SubtitleService import SubtitleService
from utils import constants
from utils.logger import logger


class download_summarize_chain(AnalyzeDocumentChain):
    def _call(
            self,
            inputs: Dict[str, str],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        docs = self.get_docs(inputs, run_manager)
        # Other keys are assumed to be needed for LLM prediction
        other_keys: Dict = {k: v for k, v in inputs.items() if k != self.input_key}
        other_keys[self.combine_docs_chain.input_key] = docs
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        return self.combine_docs_chain(
            other_keys, return_only_outputs=True, callbacks=_run_manager.get_child()
        )

    def get_docs(self, inputs, run_manager):
        file_download_url = str(inputs[constants.TRANSCRIPTION_FILE_URL])
        if file_download_url is not None and file_download_url.startswith("http"):
            # Download the transcript file from the given URL
            loader = WebBaseLoader(file_download_url, None, False)
            document = loader.load()[0].page_content
            if len(document) <= 0:
                logger.error(f"file not exists:{file_download_url}")
                raise BusinessException.new_instance_with_rce(400, ResultCodeEnum.EMPTY_CONTENT)
        else:
            # Fetch the subtitles by enterprise id and meeting id
            enterprise_id: str = run_manager.metadata.get(constants.ENTERPRISE_ID)
            meeting_id: str = run_manager.metadata.get(constants.MEETING_ID)
            logger.info(f"process task with llm:{enterprise_id}-{meeting_id}")
            document = SubtitleService().fetch_subtitles(enterprise_id=enterprise_id, meeting_id=meeting_id)
        # Split the document into chunks and pass them to the CombineDocumentsChain
        docs = self.text_splitter.create_documents([document])
        logger.info(f"number of splitting doc parts:{len(docs)}")
        return docs
```
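The `text_splitter` used above comes from `model_instance.get_text_splitter()`, a project-specific helper. To illustrate what such a splitter does, here is a minimal fixed-size splitter with overlap; a sketch only, not LangChain's actual implementation:

```python
from typing import List

def split_text(text: str, chunk_size: int = 1000, overlap: int = 100) -> List[str]:
    """Split `text` into fixed-size chunks, each overlapping the previous
    one by `overlap` characters so context is not cut abruptly at the seam."""
    if chunk_size <= overlap:
        raise ValueError("chunk_size must exceed overlap")
    chunks, start = [], 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size - overlap
    return chunks

parts = split_text("x" * 2500, chunk_size=1000, overlap=100)
print(len(parts))  # → 3
```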
3. Loading the summarize chain
```python
"""Load summarizing chains."""
from typing import Any, Mapping, Optional, Protocol

from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
from langchain.chains.combine_documents.reduce import ReduceDocumentsChain
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.llm import LLMChain
from langchain.chains.summarize import map_reduce_prompt, refine_prompts, stuff_prompt
from langchain_core.callbacks import Callbacks
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate

from adapters.langchain.chains.refine import RefineDocumentsChain


class LoadingCallable(Protocol):
    """Interface for loading the combine documents chain."""

    def __call__(
            self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
        """Callable to load the combine documents chain."""


def _load_stuff_chain(
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = stuff_prompt.PROMPT,
        document_variable_name: str = "text",
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> StuffDocumentsChain:
    llm_chain = LLMChain(llm=llm, prompt=prompt, verbose=verbose)  # type: ignore[arg-type]
    # TODO: document prompt
    return StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name=document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )


def _load_map_reduce_chain(
        llm: BaseLanguageModel,
        map_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_prompt: BasePromptTemplate = map_reduce_prompt.PROMPT,
        combine_document_variable_name: str = "text",
        map_reduce_document_variable_name: str = "text",
        collapse_prompt: Optional[BasePromptTemplate] = None,
        reduce_llm: Optional[BaseLanguageModel] = None,
        collapse_llm: Optional[BaseLanguageModel] = None,
        verbose: Optional[bool] = None,
        token_max: int = 3000,
        callbacks: Callbacks = None,
        *,
        collapse_max_retries: Optional[int] = None,
        **kwargs: Any,
) -> MapReduceDocumentsChain:
    map_chain = LLMChain(
        llm=llm,
        prompt=map_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    _reduce_llm = reduce_llm or llm
    reduce_chain = LLMChain(
        llm=_reduce_llm,
        prompt=combine_prompt,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,  # type: ignore[arg-type]
    )
    # TODO: document prompt
    combine_documents_chain = StuffDocumentsChain(
        llm_chain=reduce_chain,
        document_variable_name=combine_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
    )
    if collapse_prompt is None:
        collapse_chain = None
        if collapse_llm is not None:
            raise ValueError(
                "collapse_llm provided, but collapse_prompt was not: please "
                "provide one or stop providing collapse_llm."
            )
    else:
        _collapse_llm = collapse_llm or llm
        collapse_chain = StuffDocumentsChain(
            llm_chain=LLMChain(
                llm=_collapse_llm,
                prompt=collapse_prompt,
                verbose=verbose,  # type: ignore[arg-type]
                callbacks=callbacks,
            ),
            document_variable_name=combine_document_variable_name,
        )
    reduce_documents_chain = ReduceDocumentsChain(
        combine_documents_chain=combine_documents_chain,
        collapse_documents_chain=collapse_chain,
        token_max=token_max,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        collapse_max_retries=collapse_max_retries,
    )
    return MapReduceDocumentsChain(
        llm_chain=map_chain,
        reduce_documents_chain=reduce_documents_chain,
        document_variable_name=map_reduce_document_variable_name,
        verbose=verbose,  # type: ignore[arg-type]
        callbacks=callbacks,
        **kwargs,
    )


def _load_refine_chain(
        llm: BaseLanguageModel,
        question_prompt: BasePromptTemplate = refine_prompts.PROMPT,
        refine_prompt: BasePromptTemplate = refine_prompts.REFINE_PROMPT,
        document_variable_name: str = "text",
        initial_response_name: str = "existing_answer",
        refine_llm: Optional[BaseLanguageModel] = None,
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> RefineDocumentsChain:
    initial_chain = LLMChain(llm=llm, prompt=question_prompt, verbose=verbose)  # type: ignore[arg-type]
    _refine_llm = refine_llm or llm
    refine_chain = LLMChain(llm=_refine_llm, prompt=refine_prompt, verbose=verbose)  # type: ignore[arg-type]
    return RefineDocumentsChain(
        initial_llm_chain=initial_chain,
        refine_llm_chain=refine_chain,
        document_variable_name=document_variable_name,
        initial_response_name=initial_response_name,
        verbose=verbose,  # type: ignore[arg-type]
        **kwargs,
    )


def load_summarize_chain(
        llm: BaseLanguageModel,
        chain_type: str = "stuff",
        verbose: Optional[bool] = None,
        **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load summarizing chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that this
            applies to all chains that make up the final chain.

    Returns:
        A chain to use for summarizing.
    """
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
        "map_reduce": _load_map_reduce_chain,
        "refine": _load_refine_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](llm, verbose=verbose, **kwargs)
```
4. A custom RefineDocumentsChain that marks the final iteration
```python
"""Combine documents by doing a first pass and then refining on more documents."""
from __future__ import annotations

from typing import Any, Dict, List, Tuple

from langchain.chains.combine_documents.base import (
    BaseCombineDocumentsChain,
)
from langchain.chains.llm import LLMChain
from langchain_core.callbacks import Callbacks, dispatch_custom_event
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.prompts.prompt import PromptTemplate
from pydantic import ConfigDict, Field, model_validator

from utils.logger import logger


def _get_default_document_prompt() -> PromptTemplate:
    return PromptTemplate(input_variables=["page_content"], template="{page_content}")


class RefineDocumentsChain(BaseCombineDocumentsChain):
    """Combine documents by doing a first pass and then refining on more documents.

    This algorithm first calls `initial_llm_chain` on the first document, passing
    that first document in with the variable name `document_variable_name`, and
    produces a new variable with the variable name `initial_response_name`.

    Then, it loops over every remaining document. This is called the "refine" step.
    It calls `refine_llm_chain`,
    passing in that document with the variable name `document_variable_name`
    as well as the previous response with the variable name `initial_response_name`.

    Example:
        .. code-block:: python

            from langchain.chains import RefineDocumentsChain, LLMChain
            from langchain_core.prompts import PromptTemplate
            from langchain_community.llms import OpenAI

            # This controls how each document will be formatted. Specifically,
            # it will be passed to `format_document` - see that function for more
            # details.
            document_prompt = PromptTemplate(
                input_variables=["page_content"],
                template="{page_content}"
            )
            document_variable_name = "context"
            llm = OpenAI()
            # The prompt here should take as an input variable the
            # `document_variable_name`
            prompt = PromptTemplate.from_template(
                "Summarize this content: {context}"
            )
            initial_llm_chain = LLMChain(llm=llm, prompt=prompt)
            initial_response_name = "prev_response"
            # The prompt here should take as an input variable the
            # `document_variable_name` as well as `initial_response_name`
            prompt_refine = PromptTemplate.from_template(
                "Here's your first summary: {prev_response}. "
                "Now add to it based on the following context: {context}"
            )
            refine_llm_chain = LLMChain(llm=llm, prompt=prompt_refine)
            chain = RefineDocumentsChain(
                initial_llm_chain=initial_llm_chain,
                refine_llm_chain=refine_llm_chain,
                document_prompt=document_prompt,
                document_variable_name=document_variable_name,
                initial_response_name=initial_response_name,
            )
    """

    initial_llm_chain: LLMChain
    """LLM chain to use on initial document."""
    refine_llm_chain: LLMChain
    """LLM chain to use when refining."""
    document_variable_name: str
    """The variable name in the initial_llm_chain to put the documents in.
    If only one variable in the initial_llm_chain, this need not be provided."""
    initial_response_name: str
    """The variable name to format the initial response in when refining."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document, gets passed to `format_document`."""
    return_intermediate_steps: bool = False
    """Return the results of the refine steps in the output."""

    @property
    def output_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        _output_keys = super().output_keys
        if self.return_intermediate_steps:
            _output_keys = _output_keys + ["intermediate_steps"]
        return _output_keys

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def get_return_intermediate_steps(cls, values: Dict) -> Any:
        """For backwards compatibility."""
        if "return_refine_steps" in values:
            values["return_intermediate_steps"] = values["return_refine_steps"]
            del values["return_refine_steps"]
        return values

    @model_validator(mode="before")
    @classmethod
    def get_default_document_variable_name(cls, values: Dict) -> Any:
        """Get default document variable name, if not provided."""
        if "initial_llm_chain" not in values:
            raise ValueError("initial_llm_chain must be provided")
        llm_chain_variables = values["initial_llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain input_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    def combine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine by mapping first chain over all, then stuffing into final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        # Tell the streaming consumer that the final round has NOT started yet
        dispatch_custom_event("last_doc_mark", {"chunk": False})
        doc_length = len(docs)
        if doc_length == 1:
            # Only one chunk: the first round is also the final round
            dispatch_custom_event("last_doc_mark", {"chunk": True})
        logger.info(f"refine_docs index:1/{doc_length} of {kwargs}")
        res = self.initial_llm_chain.predict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for index, doc in enumerate(docs[1:], start=1):
            logger.info(f"refine_docs index:{index + 1}/{doc_length} of {kwargs}")
            if index == doc_length - 1:
                # Entering the final refine round: mark it for the consumer
                dispatch_custom_event("last_doc_mark", {"chunk": True})
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = self.refine_llm_chain.predict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        logger.info(f"refine_docs finished of {kwargs}, result:{res}")
        return self._construct_result(refine_steps, res)

    async def acombine_docs(
            self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Async combine by mapping a first chain over all, then stuffing
         into a final chain.

        Args:
            docs: List of documents to combine
            callbacks: Callbacks to be passed through
            **kwargs: additional parameters to be passed to LLM calls (like other
                input variables besides the documents)

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._construct_initial_inputs(docs, **kwargs)
        res = await self.initial_llm_chain.apredict(callbacks=callbacks, **inputs)
        refine_steps = [res]
        for doc in docs[1:]:
            base_inputs = self._construct_refine_inputs(doc, res)
            inputs = {**base_inputs, **kwargs}
            res = await self.refine_llm_chain.apredict(callbacks=callbacks, **inputs)
            refine_steps.append(res)
        return self._construct_result(refine_steps, res)

    def _construct_result(self, refine_steps: List[str], res: str) -> Tuple[str, dict]:
        if self.return_intermediate_steps:
            extra_return_dict = {"intermediate_steps": refine_steps}
        else:
            extra_return_dict = {}
        return res, extra_return_dict

    def _construct_refine_inputs(self, doc: Document, res: str) -> Dict[str, Any]:
        return {
            self.document_variable_name: format_document(doc, self.document_prompt),
            self.initial_response_name: res,
        }

    def _construct_initial_inputs(
            self, docs: List[Document], **kwargs: Any
    ) -> Dict[str, Any]:
        base_info = {"page_content": docs[0].page_content}
        base_info.update(docs[0].metadata)
        document_info = {k: base_info[k] for k in self.document_prompt.input_variables}
        base_inputs: dict = {
            self.document_variable_name: self.document_prompt.format(**document_info)
        }
        inputs = {**base_inputs, **kwargs}
        return inputs

    @property
    def _chain_type(self) -> str:
        return "refine_documents_chain"
```
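The `dispatch_custom_event("last_doc_mark", ...)` calls above are what let the streaming consumer distinguish intermediate rounds from the final one; in LangChain these surface as `on_custom_event` entries when streaming with `astream_events`. Reduced to plain Python (the event shapes below are assumptions for illustration, not LangChain's actual event schema), the consumer-side gating is just:

```python
from typing import Dict, Iterable, Iterator

def gate_final_round(events: Iterable[Dict]) -> Iterator[str]:
    """Forward model tokens only once the final-round marker has been seen.

    Each event is a dict like {"type": "last_doc_mark", "chunk": True}
    or {"type": "token", "text": "..."} (shapes assumed for illustration).
    """
    in_final_round = False
    for event in events:
        if event["type"] == "last_doc_mark":
            in_final_round = event["chunk"]
        elif event["type"] == "token" and in_final_round:
            yield event["text"]

events = [
    {"type": "last_doc_mark", "chunk": False},
    {"type": "token", "text": "intermediate "},   # dropped
    {"type": "last_doc_mark", "chunk": True},
    {"type": "token", "text": "final "},
    {"type": "token", "text": "summary"},
]
print("".join(gate_final_round(events)))  # → final summary
```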
5. Invoking the chain
```python
def process(tool: BaseTool, prompt_type: QuestionTypeEnum, input_dict: dict,
            run_manager: Optional[CallbackManagerForToolRun] = None):
    # Get the model instance
    model_type = ModelTypeEnum.from_string(run_manager.metadata.get(constants.MODEL_TYPE))
    model_instance = ModelAdapter.get_model_instance(model_type)
    # Look up the prompt templates for this model and question type
    prompt_template = PromptSynchronizer.get_prompt_template(model_type=model_type, questionType=prompt_type)
    prompt_map = json.loads(prompt_template)
    refine_prompt = PromptTemplate.from_template(prompt_map["refine_template"], template_format="f-string")
    question_prompt = PromptTemplate.from_template(prompt_map["prompt"])
    logger.info(f"invoke tool input_dicts:{input_dict}")
    combine_docs_chain = load_summarize_chain(llm=model_instance,
                                              chain_type="refine",
                                              question_prompt=question_prompt,
                                              refine_prompt=refine_prompt,
                                              return_intermediate_steps=True,
                                              input_key="text",
                                              output_key="existing_answer",
                                              verbose=True)
    res = (tool.pre_handler
           | download_summarize_chain(combine_docs_chain=combine_docs_chain,
                                      text_splitter=model_instance.get_text_splitter(),
                                      verbose=True,
                                      input_key="input_document")
           | tool.post_handler).invoke(input={"input_document": "", **input_dict}, config=RunnableConfig())
    return res
```
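The `question_prompt`/`refine_prompt` pair loaded above follows LangChain's refine convention: the initial prompt takes only `text`, while the refine prompt takes both `existing_answer` and `text` (matching the `input_key` and `output_key` passed to `load_summarize_chain`). The real templates come from `PromptSynchronizer`; hypothetical stand-ins might look like:

```python
# Hypothetical templates for illustration; the production ones are fetched
# from PromptSynchronizer. Variable names match LangChain's refine defaults.
question_template = "Summarize the following meeting transcript:\n{text}"
refine_template = (
    "Here is the existing summary:\n{existing_answer}\n\n"
    "Refine it with this additional transcript:\n{text}"
)

# The refine chain fills `text` on the first round, then both variables after:
first = question_template.format(text="chunk one ...")
later = refine_template.format(existing_answer="draft summary", text="chunk two ...")
```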
This article was published via OpenWrite, a multi-channel blog publishing platform.
