【A fairly complete langchain script for building a vector database】


- Supports both FAISS and Chroma vector stores (faiss-cpu can merge a new index into an existing one)
- Skips files that have already been embedded (MD5 hash check)
- Supports many file formats
- Supports Hugging Face embedding models
- Improved chunk-splitting strategy
- Loads files in parallel with a process pool

   Adapted from langchain-chatchat, with some added features and an improved splitter strategy. The full script is below.
import os
import glob
import hashlib
from typing import List
from functools import partial
from tqdm import tqdm
from multiprocessing import Pool
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.docstore.document import Document

CONFIG = {
    "doc_source": "./docs",  # documents to vectorize
    "embedding_model": "hugging_models/text2vec-base-chinese_config",  # embedding model
    "db_source": "./db",  # vector database directory
    "db_type": "faiss",  # "faiss" or "chroma"
    "chunk_size": 200,  # characters per chunk
    "chunk_overlap": 20,  # overlap between adjacent chunks
    "k": 3,  # number of documents to retrieve per query
    "merge_rows": 5,  # number of table rows merged into one document
    "hash_file_path": "hash_file.txt",  # record of already-embedded file hashes
}

# Directory and embedding settings
source_directory = CONFIG["doc_source"]
embeddings_model_name = CONFIG["embedding_model"]
chunk_size = CONFIG["chunk_size"]
chunk_overlap = CONFIG["chunk_overlap"]
output_dir = CONFIG["db_source"]
k = CONFIG["k"]
merge_rows = CONFIG["merge_rows"]
hash_file_path = CONFIG["hash_file_path"]
db_type = CONFIG["db_type"]


# Custom document loader
class MyElmLoader(UnstructuredEmailLoader):
    def load(self) -> List[Document]:
        """Wrapper adding a fallback for .eml files without html content"""
        try:
            try:
                doc = UnstructuredEmailLoader.load(self)
            except ValueError as e:
                if "text/html content not found in email" in str(e):
                    # Try plain text
                    self.unstructured_kwargs["content_source"] = "text/plain"
                    doc = UnstructuredEmailLoader.load(self)
                else:
                    raise
        except Exception as e:
            # Add file_path to exception message
            raise type(e)(f"{self.file_path}: {e}") from e
        return doc


# Map file extensions to document loaders and their arguments
# For GBK/GB18030 text files, change the TextLoader encoding accordingly
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".eml": (MyElmLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".xls": (UnstructuredExcelLoader, {}),
    ".xlsx": (UnstructuredExcelLoader, {}),
}


def read_hash_file(path="hash_file.txt"):
    hash_file_list = []
    if os.path.exists(path):
        with open(path, "r") as f:
            hash_file_list = [i.strip() for i in f.readlines()]
    return hash_file_list


def save_hash_file(hash_list, path="hash_file.txt"):
    with open(path, "w") as f:
        f.write("\n".join(hash_list))


def get_hash_from_file(path):
    with open(path, "rb") as f:
        readable_hash = hashlib.md5(f.read()).hexdigest()
    return readable_hash


def load_single_document(
    file_path: str, splitter: TextSplitter, merge_rows: int
) -> List[Document]:
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext in LOADER_MAPPING:
        loader_class, loader_args = LOADER_MAPPING[ext]
        loader = loader_class(file_path, **loader_args)
        docs = loader.load()
        # Handle tabular and non-tabular files differently
        if not file_path.endswith((".xlsx", ".xls", ".csv")):
            # Merge all page_content of one file into a single document
            tmp = [i.page_content for i in docs]
            merged = Document(
                page_content="".join(tmp).strip(),
                metadata={"source": docs[0].metadata["source"], "pages": len(tmp)},
            )
            # Split into chunks
            docs = splitter.split_documents([merged])
        else:
            # Tabular data: merge every merge_rows rows into one document
            merge_n = len(docs) // merge_rows + bool(len(docs) % merge_rows)
            _docs = []
            for i in range(merge_n):
                tmp = "\n\n".join(
                    [
                        d.page_content
                        for d in docs[i * merge_rows : (i + 1) * merge_rows]
                    ]
                )
                _docs.append(
                    Document(
                        page_content=tmp,
                        metadata={"source": docs[0].metadata["source"]},
                    )
                )
            docs = _docs
        return docs
    raise ValueError(f"Unsupported file extension '{ext}'")


def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
    """
    Loads all documents from the source documents directory, ignoring specified files
    """
    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    filtered_files = [
        file_path for file_path in all_files if file_path not in ignored_files
    ]
    # Skip files whose hash was recorded in a previous run, and record the new ones
    hash_file_list = read_hash_file(hash_file_path)
    tmp = []
    for file in filtered_files:
        file_hash = get_hash_from_file(file)
        if file_hash not in hash_file_list:
            tmp.append(file)
            hash_file_list.append(file_hash)
    filtered_files = tmp
    save_hash_file(hash_file_list, hash_file_path)
    # splitter
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    load_document = partial(
        load_single_document, splitter=splitter, merge_rows=merge_rows
    )
    # Load files in parallel
    with Pool(processes=os.cpu_count()) as pool:
        results = []
        with tqdm(
            total=len(filtered_files), desc="Loading new documents", ncols=80
        ) as pbar:
            for i, docs in enumerate(
                pool.imap_unordered(load_document, filtered_files)
            ):
                results.extend(docs)
                pbar.update()
    return results


def process_documents(ignored_files: List[str] = []) -> List[Document]:
    """
    Load documents and split them into chunks
    """
    print(f"Loading documents from {source_directory}")
    documents = load_documents(source_directory, ignored_files)
    if not documents:
        print("No new documents to load")
        exit(0)
    print(
        f"Loaded and split {len(documents)} chunks of text from {source_directory} "
        f"(max. {chunk_size} characters each)"
    )
    return documents


def main():
    # Create embeddings, then create and store the vectorstore locally
    print("Creating new vectorstore")
    documents = process_documents()
    print("Creating embeddings. May take some minutes...")
    embedding_function = SentenceTransformerEmbeddings(model_name=embeddings_model_name)
    if db_type == "chroma":
        from langchain.vectorstores import Chroma

        db = Chroma.from_documents(
            documents, embedding_function, persist_directory=output_dir
        )
        db.persist()
        db = None
    elif db_type == "faiss":
        from langchain.vectorstores import FAISS

        print("Creating new database")
        db = FAISS.from_documents(documents, embedding_function)
        # Merge with a previously saved index (not supported by the GPU build of faiss)
        if os.path.exists(output_dir):
            try:
                print("Loading existing database")
                old_db = FAISS.load_local(output_dir, embedding_function)
                print("Merging old and new databases")
                db.merge_from(old_db)
            except Exception as e:
                print(e)
        print("Saving")
        db.save_local(output_dir)
        db = None
    else:
        raise NotImplementedError(f"No implementation for db_type '{db_type}'.")


if __name__ == "__main__":
    main()
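
The script only builds and saves the index; the k value in CONFIG is not used during the build and appears to be intended for retrieval. Below is a minimal retrieval sketch for the FAISS case, not part of the original script: it assumes main() has already been run with db_type set to "faiss", that the same CONFIG dict is in scope, and the query string is a placeholder. Chroma offers the same similarity_search interface, so the Chroma case is analogous.

# Retrieval sketch (assumption, not from the original post): load the saved
# FAISS index and run a top-k similarity search with the same embedding model.
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import FAISS

embedding_function = SentenceTransformerEmbeddings(model_name=CONFIG["embedding_model"])
db = FAISS.load_local(CONFIG["db_source"], embedding_function)

# "example query" is a placeholder; CONFIG["k"] controls how many chunks are returned.
for doc in db.similarity_search("example query", k=CONFIG["k"]):
    print(doc.metadata.get("source"), doc.page_content[:80])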