- 支持faiss、chroma两种向量数据库(使用faiss-cpu时支持与旧数据库合并)
- 支持防止重复文件的embedding(基于文件内容hash去重)
- 支持众多文件格式
- 支持huggingface的embedding模型
- 优化了切分chunk的策略
- 支持多进程并行处理文件
修改自langchain-chatchat, 增加了一些功能, 优化了splitter的策略
- import os
- import glob
- import hashlib
- from typing import List
- from functools import partial
- from tqdm import tqdm
- from multiprocessing import Pool
- from langchain.document_loaders import (
- CSVLoader,
- EverNoteLoader,
- PDFMinerLoader,
- TextLoader,
- UnstructuredEmailLoader,
- UnstructuredEPubLoader,
- UnstructuredHTMLLoader,
- UnstructuredMarkdownLoader,
- UnstructuredODTLoader,
- UnstructuredPowerPointLoader,
- UnstructuredWordDocumentLoader,
- UnstructuredExcelLoader,
- )
- from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
- from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
- from langchain.docstore.document import Document
# Script configuration -- edit these values before running.
CONFIG = dict(
    doc_source="./docs",  # directory containing the documents to vectorize
    embedding_model="hugging_models/text2vec-base-chinese_config",  # passed as model_name to SentenceTransformerEmbeddings
    db_source="./db",  # directory the vector store is written to
    db_type="faiss",  # vector store backend: "faiss" or "chroma"
    chunk_size=200,  # target chunk length handed to the text splitter
    chunk_overlap=20,  # overlap between consecutive chunks
    k=3,  # number of documents to return per query (appears unused in this script)
    merge_rows=5,  # how many table rows are merged into one Document
    hash_file_path="hash_file.txt",  # records content hashes of already-processed files (relative to CWD)
)

# Expose the configuration as module-level names used by the functions below.
source_directory, embeddings_model_name, output_dir, db_type = (
    CONFIG["doc_source"],
    CONFIG["embedding_model"],
    CONFIG["db_source"],
    CONFIG["db_type"],
)
chunk_size, chunk_overlap = CONFIG["chunk_size"], CONFIG["chunk_overlap"]
k, merge_rows = CONFIG["k"], CONFIG["merge_rows"]
hash_file_path = CONFIG["hash_file_path"]
# Custom document loaders
class MyElmLoader(UnstructuredEmailLoader):
    """Email (.eml) loader that falls back to the plain-text body.

    If the base loader raises a ValueError whose message contains
    "text/html content not found in email", the load is retried with
    ``content_source="text/plain"``.
    """

    def load(self) -> List[Document]:
        """Load the email, retrying with the text/plain part when there is no HTML part.

        Returns:
            The Documents produced by UnstructuredEmailLoader.load().

        Raises:
            Exception: whatever the underlying loader raised, as the same
                exception type with the file path prefixed to the message
                when that type can be rebuilt from a single message string;
                otherwise the original exception unchanged.
        """
        try:
            try:
                docs = UnstructuredEmailLoader.load(self)
            except ValueError as e:
                if "text/html content not found in email" not in str(e):
                    raise
                # No HTML body: retry using the plain-text part instead.
                self.unstructured_kwargs["content_source"] = "text/plain"
                docs = UnstructuredEmailLoader.load(self)
        except Exception as e:
            # Prefix the file path so the failing email can be identified.
            # Some exception classes (e.g. UnicodeDecodeError) need more than
            # one constructor argument; rebuilding them would raise a TypeError
            # that hides the real error, so re-raise the original instead.
            try:
                wrapped = type(e)(f"{self.file_path}: {e}")
            except Exception:
                wrapped = None
            if wrapped is None:
                raise
            raise wrapped from e
        return docs
# Supported file extensions, each mapped to (loader class, extra loader kwargs).
# load_documents() globs for every key in this order, and
# load_single_document() builds the loader as loader_class(file_path, **kwargs).
# NOTE(review): ".txt" files are opened as UTF-8 only. The original
# "GBK2312 GB18030" remark suggests Chinese text in GB2312/GBK/GB18030
# encodings was a known concern -- presumably such files fail to load; confirm.
LOADER_MAPPING = {
    ".csv": (CSVLoader, dict()),
    ".doc": (UnstructuredWordDocumentLoader, dict()),
    ".docx": (UnstructuredWordDocumentLoader, dict()),
    ".enex": (EverNoteLoader, dict()),
    ".eml": (MyElmLoader, dict()),
    ".epub": (UnstructuredEPubLoader, dict()),
    ".html": (UnstructuredHTMLLoader, dict()),
    ".md": (UnstructuredMarkdownLoader, dict()),
    ".odt": (UnstructuredODTLoader, dict()),
    ".pdf": (PDFMinerLoader, dict()),
    ".ppt": (UnstructuredPowerPointLoader, dict()),
    ".pptx": (UnstructuredPowerPointLoader, dict()),
    ".txt": (TextLoader, dict(encoding="utf8")),
    ".xls": (UnstructuredExcelLoader, dict()),
    ".xlsx": (UnstructuredExcelLoader, dict()),
}
def read_hash_file(path="hash_file.txt"):
    """Return the hashes stored in ``path``, one per line, whitespace-stripped.

    A missing file yields an empty list.
    """
    if not os.path.exists(path):
        return []
    with open(path, "r") as handle:
        return [line.strip() for line in handle]
def save_hash_file(hash_list, path="hash_file.txt"):
    """Overwrite ``path`` with ``hash_list``, one hash per line.

    Lines are joined with "\\n" and no trailing newline is written.
    """
    with open(path, mode="w") as handle:
        payload = "\n".join(hash_list)
        handle.write(payload)
def get_hash_from_file(path, block_size=1 << 20):
    """Return the hex MD5 digest of the contents of the file at ``path``.

    The file is hashed incrementally in ``block_size``-byte reads, so large
    documents are never loaded into memory all at once. The digest is only
    used to recognise identical file contents, not for security.

    Raises:
        ValueError: if ``block_size`` is not positive (read(0) would return
            b"" immediately and silently hash an empty file).
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for block in iter(lambda: f.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()
def load_single_document(
    file_path: str, splitter: TextSplitter, merge_rows: int
) -> List[Document]:
    """Load one file and turn it into Documents ready for embedding.

    Non-table files: the text of every Document the loader returns is
    concatenated into one Document, which is then split with ``splitter``.
    Table files (.csv/.xls/.xlsx): the loader's Documents are grouped
    ``merge_rows`` at a time into new Documents (joined with blank lines) and
    are not passed through the splitter.

    Args:
        file_path: path of the file; its extension selects the loader from
            LOADER_MAPPING.
        splitter: text splitter applied to non-table content.
        merge_rows: number of consecutive table Documents merged into one.

    Returns:
        The resulting Documents; an empty list if the loader produced none.

    Raises:
        ValueError: if the extension is not in LOADER_MAPPING, or if
            ``merge_rows`` < 1 for a table file.
    """
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext not in LOADER_MAPPING:
        raise ValueError(f"Unsupported file extension '{ext}'")

    loader_class, loader_args = LOADER_MAPPING[ext]
    docs = loader_class(file_path, **loader_args).load()
    if not docs:
        # Nothing extracted (e.g. an empty file); docs[0] below would fail.
        return []
    # Fall back to the path itself in case a loader omits "source".
    source = docs[0].metadata.get("source", file_path)

    if ext not in {".csv", ".xls", ".xlsx"}:
        # Merge the text of all loader Documents, then split into chunks.
        # "pages" is the number of Documents the loader returned.
        contents = [d.page_content for d in docs]
        merged = Document(
            page_content="".join(contents).strip(),
            metadata={"source": source, "pages": len(contents)},
        )
        return splitter.split_documents([merged])

    # Table data: merge every `merge_rows` consecutive rows into one Document.
    if merge_rows < 1:
        raise ValueError(f"merge_rows must be >= 1, got {merge_rows}")
    return [
        Document(
            page_content="\n\n".join(
                d.page_content for d in docs[start : start + merge_rows]
            ),
            metadata={"source": source},
        )
        for start in range(0, len(docs), merge_rows)
    ]
def load_documents(
    source_dir: str, ignored_files: "List[str] | None" = None
) -> List[Document]:
    """Load and chunk every supported, not-yet-processed file under ``source_dir``.

    Files are found recursively by extension (the keys of LOADER_MAPPING). A
    file is skipped when its path is in ``ignored_files`` or when the MD5 hash
    of its content is already recorded in ``hash_file_path``, either from an
    earlier run or from an identical file earlier in this run. The hashes of
    newly accepted files are appended to that file, which is created the
    first time anything new is found.

    The accepted files are loaded in a process pool via load_single_document.

    NOTE(review): hashes are saved here, before the caller embeds and
    persists the chunks. If that later step fails, these files will be
    skipped on the next run; remove their lines from the hash file to retry.

    Returns:
        All chunks from the newly loaded files, in completion order (not
        file order).
    """
    ignored = set(ignored_files or ())

    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    candidates = [path for path in all_files if path not in ignored]

    # Content-hash de-duplication. Hashes are computed even when no hash file
    # exists yet, so that the first run creates it.
    known_hashes = read_hash_file(hash_file_path)
    seen = set(known_hashes)
    filtered_files = []
    for path in candidates:
        digest = get_hash_from_file(path)
        if digest in seen:
            continue
        seen.add(digest)
        known_hashes.append(digest)
        filtered_files.append(path)
    if not filtered_files:
        return []
    save_hash_file(known_hashes, hash_file_path)

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    load_document = partial(
        load_single_document, splitter=splitter, merge_rows=merge_rows
    )

    results = []
    with Pool(processes=os.cpu_count()) as pool:
        with tqdm(
            total=len(filtered_files), desc="Loading new documents", ncols=80
        ) as pbar:
            for docs in pool.imap_unordered(load_document, filtered_files):
                results.extend(docs)
                pbar.update()
    return results
def process_documents(ignored_files: "List[str] | None" = None) -> List[Document]:
    """Load every new document in ``source_directory`` as a list of chunks.

    Args:
        ignored_files: file paths to skip; ``None`` means skip nothing.

    Returns:
        The chunks produced by load_documents().

    Exits the process with status 0 when there is nothing new to load.
    """
    import sys  # only needed for the early exit below

    print(f"Loading documents from {source_directory}")
    # Pass a fresh list rather than a shared mutable default.
    documents = load_documents(
        source_directory, list(ignored_files) if ignored_files is not None else []
    )
    if not documents:
        print("No new documents to load")
        # sys.exit instead of the site-provided exit(), which is missing
        # under `python -S` and in some embedded/frozen interpreters.
        sys.exit(0)
    # load_documents already returns chunks, so report a single count.
    # RecursiveCharacterTextSplitter measures length in characters by default.
    print(
        f"Loaded {len(documents)} new chunks of text from {source_directory} "
        f"(text split into chunks of up to ~{chunk_size} characters; "
        f"table rows merged {merge_rows} at a time)."
    )
    return documents
def main():
    """Embed all new documents and store them in the configured vector store.

    Uses the module-level ``db_type``, ``output_dir`` and
    ``embeddings_model_name``:

    * ``"chroma"``: builds a Chroma store with ``persist_directory=output_dir``
      and persists it.
    * ``"faiss"``: builds a FAISS index from the new chunks, merges in the
      index previously saved in ``output_dir`` (if present) and saves the
      result to ``output_dir``.

    Raises:
        NotImplementedError: if ``db_type`` is neither "chroma" nor "faiss".
        RuntimeError: (faiss) if an existing index is present but cannot be
            loaded or merged; nothing is saved, so the old index survives.
    """
    print("Creating new vectorstore")
    # process_documents() exits the process when there is nothing new.
    documents = process_documents()
    print("Creating embeddings. May take some minutes...")
    embedding_function = SentenceTransformerEmbeddings(model_name=embeddings_model_name)
    if db_type == "chroma":
        from langchain.vectorstores import Chroma

        db = Chroma.from_documents(
            documents, embedding_function, persist_directory=output_dir
        )
        db.persist()
        db = None
    elif db_type == "faiss":
        from langchain.vectorstores import FAISS

        print("创建新数据库")
        db = FAISS.from_documents(documents, embedding_function)
        # Merge with the previously saved index. save_local() with the default
        # index_name="index" writes index.faiss + index.pkl, so only attempt a
        # merge when that index exists. Per the original author, merging
        # old data is not supported with the faiss-gpu build.
        if os.path.exists(os.path.join(output_dir, "index.faiss")):
            print("读取旧数据库")
            try:
                # SECURITY: load_local unpickles index.pkl -- only load a
                # directory you trust. (Newer langchain versions also require
                # allow_dangerous_deserialization=True here.)
                old_db = FAISS.load_local(output_dir, embedding_function)
                print("融合新旧数据库")
                db.merge_from(old_db)
            except Exception as e:
                # Do not fall through to save_local(): that would overwrite the
                # old index with only the new chunks. Files already recorded
                # in the hash file are skipped by load_documents(), so they
                # would never be re-embedded.
                raise RuntimeError(
                    f"Failed to load/merge the existing FAISS index in "
                    f"{output_dir!r}; aborting without saving so it is not "
                    f"overwritten"
                ) from e
        print("保存")
        db.save_local(output_dir)
        db = None
    else:
        raise NotImplementedError(f'未定义数据库 {db_type} 的实现.')


# Keep the __main__ guard: load_documents() uses multiprocessing.Pool, and
# with the "spawn" start method worker processes re-import this module.
if __name__ == "__main__":
    main()
复制代码 免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。 |