大连全瓷种植牙齿制作中心 发表于 2024-7-11 16:41:07

【langchain 创建向量数据库非常完美的代码】

- 支持faiss、chroma两种数据库(faiss-cpu 支持与旧数据库合并)
- 支持避免对重复文件进行embedding(基于文件hash)
- 支持众多文件格式
- 支持huggingface的embedding模型
- 优化了切分chunk的策略
- 支持多进程处理文件

   修改自langchain-chatchat, 增加了一些功能, 优化了splitter的策略
import glob
import hashlib
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional

from langchain.docstore.document import Document
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredExcelLoader,
)
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from tqdm import tqdm




# Central settings for the vector-store build script.
CONFIG = dict(
    # Directory holding the documents to vectorize.
    doc_source="./docs",
    # Embedding model passed to the embedding wrapper as model_name.
    embedding_model="hugging_models/text2vec-base-chinese_config",
    # Directory where the vector database is written.
    db_source="./db",
    # Vector store backend; main() handles "chroma" and "faiss".
    db_type="faiss",
    # Chunk size handed to the text splitter.
    chunk_size=200,
    # Overlap between neighbouring chunks.
    chunk_overlap=20,
    # Number of documents to retrieve per query.
    k=3,
    # Number of table rows merged into one document.
    merge_rows=5,
    # File that records the hashes of already-processed files.
    hash_file_path="hash_file.txt",
)


# Module-level aliases for the settings above (paths first, then tuning knobs).
source_directory = CONFIG["doc_source"]
output_dir = CONFIG["db_source"]
hash_file_path = CONFIG["hash_file_path"]

embeddings_model_name = CONFIG["embedding_model"]
db_type = CONFIG["db_type"]

chunk_size = CONFIG["chunk_size"]
chunk_overlap = CONFIG["chunk_overlap"]
merge_rows = CONFIG["merge_rows"]
k = CONFIG["k"]


# Custom document loaders
class MyElmLoader(UnstructuredEmailLoader):
    """Email loader that retries as plain text when an email has no HTML part."""

    def load(self) -> List:
        """Load the email, falling back to the text/plain body if needed.

        Returns:
            The documents produced by ``UnstructuredEmailLoader.load``.

        Raises:
            Exception: any loader failure, re-raised with ``self.file_path``
                prefixed to the message. The original exception type is kept
                when it can be rebuilt from a single message argument;
                otherwise a ``RuntimeError`` is raised. In both cases the
                original exception is chained as ``__cause__``.
        """
        try:
            try:
                doc = super().load()
            except ValueError as e:
                if "text/html content not found in email" not in str(e):
                    raise
                # No HTML part: retry with the plain-text body instead.
                self.unstructured_kwargs["content_source"] = "text/plain"
                doc = super().load()
        except Exception as e:
            # Add file_path to the exception message.
            message = f"{self.file_path}: {e}"
            try:
                wrapped = type(e)(message)
            except Exception:
                # Some exception classes (e.g. UnicodeDecodeError) cannot be
                # built from a single message; do not let that mask the error.
                wrapped = RuntimeError(message)
            raise wrapped from e

        return doc


# Map file extensions to document loaders and their arguments.
# Each value is a (loader_class, loader_kwargs) pair; the keys include the
# leading dot and are all lower-case.
# NOTE(review): the "GBK2312 GB18030" note below looks like a reminder of
# alternative encodings for the ".txt" entry — confirm.
# GBK2312 GB18030
LOADER_MAPPING = {
    ".csv": (CSVLoader, {}),
    ".doc": (UnstructuredWordDocumentLoader, {}),
    ".docx": (UnstructuredWordDocumentLoader, {}),
    ".enex": (EverNoteLoader, {}),
    ".eml": (MyElmLoader, {}),
    ".epub": (UnstructuredEPubLoader, {}),
    ".html": (UnstructuredHTMLLoader, {}),
    ".md": (UnstructuredMarkdownLoader, {}),
    ".odt": (UnstructuredODTLoader, {}),
    ".pdf": (PDFMinerLoader, {}),
    ".ppt": (UnstructuredPowerPointLoader, {}),
    ".pptx": (UnstructuredPowerPointLoader, {}),
    ".txt": (TextLoader, {"encoding": "utf8"}),
    ".xls": (UnstructuredExcelLoader, {}),
    ".xlsx": (UnstructuredExcelLoader, {}),
}


def read_hash_file(path="hash_file.txt"):
    """Return the hashes recorded in *path*, one per line.

    Args:
        path: Location of the hash file.

    Returns:
        A list of hash strings with surrounding whitespace removed and blank
        lines skipped. An empty list is returned when the file does not exist.
    """
    try:
        with open(path, "r") as f:
            stripped = (line.strip() for line in f)
            return [entry for entry in stripped if entry]
    except FileNotFoundError:
        return []


def save_hash_file(hash_list, path="hash_file.txt"):
    """Overwrite *path* with the given hashes, one per line.

    The entries are separated by newlines; no trailing newline is written.
    """
    with open(path, "w") as out:
        payload = "\n".join(hash_list)
        out.write(payload)


def get_hash_from_file(path, block_size=1024 * 1024):
    """Return the hexadecimal MD5 digest of the file at *path*.

    The file is hashed in blocks so that large documents are never held in
    memory all at once; the digest is identical to hashing the whole content.
    MD5 is used here only as a content fingerprint for de-duplication.

    Args:
        path: File to fingerprint.
        block_size: Number of bytes read per iteration.

    Returns:
        The digest as a lowercase hex string.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # iter() with a sentinel stops at the first empty read (end of file).
        for block in iter(lambda: f.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()


def load_single_document(
    file_path: str, splitter: TextSplitter, merge_rows: int
) -> List:
    """Load one file and turn it into a list of chunk documents.

    Non-tabular files have all their pages concatenated into one document,
    which is then split with *splitter*. Tabular files (.xlsx, .xls, .csv)
    are not split; instead every *merge_rows* consecutive loader documents
    are joined into one document.

    Args:
        file_path: Path of the file to load.
        splitter: Text splitter applied to non-tabular files.
        merge_rows: Number of loader documents merged per output document for
            tabular files. Values below 1 are treated as 1.

    Returns:
        The resulting documents; an empty list if the loader produced nothing.

    Raises:
        ValueError: if the file extension has no entry in LOADER_MAPPING.
    """
    # Lower-case the extension so "REPORT.PDF" resolves like "report.pdf".
    ext = os.path.splitext(file_path)[1].lower()
    if ext not in LOADER_MAPPING:
        raise ValueError(f"Unsupported file extension '{ext}'")

    loader_class, loader_args = LOADER_MAPPING[ext]
    loader = loader_class(file_path, **loader_args)
    docs = loader.load()
    if not docs:
        return []

    source = docs[0].metadata.get("source", file_path)

    if ext not in (".xlsx", ".xls", ".csv"):
        # Merge the page_content of every page, then split the merged text.
        pages = [d.page_content for d in docs]
        merged = Document(
            page_content="".join(pages).strip(),
            metadata={"source": source, "pages": len(pages)},
        )
        return splitter.split_documents([merged])

    # Tabular data: join each run of `step` consecutive rows into one document.
    step = max(1, merge_rows)
    return [
        Document(
            page_content="\n\n".join(
                d.page_content for d in docs[start : start + step]
            ),
            metadata={"source": source},
        )
        for start in range(0, len(docs), step)
    ]


def _skip_already_hashed(file_paths: List, hash_path: str) -> List:
    """Return the files whose content hash is not yet recorded in *hash_path*.

    Newly seen hashes are appended to the hash file. Files with identical
    content inside the same run are kept only once.
    """
    recorded = read_hash_file(hash_path)
    seen = set(recorded)
    fresh = []
    for file_path in file_paths:
        digest = get_hash_from_file(file_path)
        if digest in seen:
            continue
        seen.add(digest)
        recorded.append(digest)
        fresh.append(file_path)

    if fresh:
        # NOTE(review): hashes are recorded before the files are embedded, so a
        # later failure leaves them marked as processed — confirm this is
        # acceptable or move the save after the vector store is written.
        save_hash_file(recorded, hash_path)
    return fresh


def load_documents(source_dir: str, ignored_files: Optional[List] = None) -> List:
    """
    Loads all documents from the source documents directory, ignoring specified files

    Files are discovered recursively for every extension in LOADER_MAPPING.
    Files whose content hash is already recorded in ``hash_file_path`` are
    skipped, and the hash file is created on the first run. The remaining
    files are loaded in a process pool.

    Args:
        source_dir: Root directory searched for documents.
        ignored_files: Paths to leave out; ``None`` means ignore nothing.

    Returns:
        The chunk documents from all newly seen files, in completion order;
        an empty list when there is nothing new to load.
    """
    ignored = set(ignored_files or ())

    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    candidates = [file_path for file_path in all_files if file_path not in ignored]

    # hash filter
    filtered_files = _skip_already_hashed(candidates, hash_file_path)
    if not filtered_files:
        # Nothing new: avoid starting a worker pool for no work.
        return []

    # splitter
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    load_document = partial(
        load_single_document, splitter=splitter, merge_rows=merge_rows
    )

    # load
    results = []
    with Pool(processes=os.cpu_count()) as pool:
        with tqdm(
            total=len(filtered_files), desc="Loading new documents", ncols=80
        ) as pbar:
            for docs in pool.imap_unordered(load_document, filtered_files):
                results.extend(docs)
                pbar.update()

    return results


def process_documents(ignored_files: Optional[List] = None) -> List:
    """
    Load documents and split in chunks

    Args:
        ignored_files: Paths to leave out; ``None`` means ignore nothing.

    Returns:
        The chunk documents produced by ``load_documents``.

    Raises:
        SystemExit: with status 0 when there are no new documents.
    """
    print(f"Loading documents from {source_directory}")
    documents = load_documents(source_directory, ignored_files or [])
    if not documents:
        print("No new documents to load")
        # Same effect as exit(0), but does not rely on the site-provided helper.
        raise SystemExit(0)
    # Each returned item is a chunk, so only the chunk count is known here;
    # chunk_size is passed to the character splitter, not a token count.
    print(
        f"Loaded new documents from {source_directory}."
        f"\nSplit into {len(documents)} chunks of text (max. {chunk_size} characters each)"
    )
    return documents


def main():
    """Build the vector store from newly seen documents and save it.

    With ``db_type == "chroma"`` the documents are written to a Chroma store
    persisted in ``output_dir``. With ``db_type == "faiss"`` a new FAISS index
    is built, merged with an existing index in ``output_dir`` if one is
    present, and saved back to ``output_dir``.

    Raises:
        NotImplementedError: if ``db_type`` is neither "chroma" nor "faiss".
    """
    # Validate the backend before any work is done: process_documents() records
    # file hashes, so failing afterwards would mark files as processed without
    # storing them.
    if db_type not in ("chroma", "faiss"):
        raise NotImplementedError(f'未定义数据库 {db_type} 的实现.')

    # Create and store locally vectorstore
    print("Creating new vectorstore")
    documents = process_documents()
    print("Creating embeddings. May take some minutes...")
    embedding_function = SentenceTransformerEmbeddings(model_name=embeddings_model_name)

    if db_type == "chroma":
        from langchain.vectorstores import Chroma

        db = Chroma.from_documents(
            documents, embedding_function, persist_directory=output_dir
        )
        db.persist()
    else:
        from langchain.vectorstores import FAISS

        print("创建新数据库")
        db = FAISS.from_documents(documents, embedding_function)

        # Merge the previously saved index, if any (the original comment notes
        # that the GPU build does not support this).
        # save_local() writes "index.faiss" by default, so its presence marks
        # an existing store; a bare or empty directory is treated as "no old db".
        if os.path.exists(os.path.join(output_dir, "index.faiss")):
            # Errors are deliberately not swallowed here: continuing after a
            # failed load would overwrite the old index with only the new data.
            # NOTE(review): newer LangChain releases require
            # allow_dangerous_deserialization=True for load_local — confirm
            # against the installed version.
            print("读取旧数据库")
            old_db = FAISS.load_local(output_dir, embedding_function)

            print("融合新旧数据库")
            db.merge_from(old_db)

        print("保存")
        db.save_local(output_dir)


if __name__ == "__main__":
    main()

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: 【langchain 创建向量数据库非常完美的代码】