【langchain 创建向量数据库非常完善的代码】
- 支持 faiss、chroma 两种数据库(faiss-cpu 支持旧数据库的合并)
- 支持防止重复文件 embedding(hash)
- 支持多种文件格式
- 支持huggingface的embedding模型
- 优化了切分chunk的策略
- 支持多进程处理文件
修改自langchain-chatchat, 增加了一些功能, 优化了splitter的策略
import glob
import hashlib
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional

from tqdm import tqdm
from langchain.document_loaders import (
    CSVLoader,
    EverNoteLoader,
    PDFMinerLoader,
    TextLoader,
    UnstructuredEmailLoader,
    UnstructuredEPubLoader,
    UnstructuredHTMLLoader,
    UnstructuredMarkdownLoader,
    UnstructuredODTLoader,
    UnstructuredPowerPointLoader,
    UnstructuredWordDocumentLoader,
    UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.docstore.document import Document
# Runtime configuration for the ingestion script.
CONFIG = dict(
    doc_source="./docs",  # directory holding the documents to vectorize
    embedding_model="hugging_models/text2vec-base-chinese_config",  # embedding model
    db_source="./db",  # output directory of the vector database
    db_type="faiss",  # vector store backend: "faiss" or "chroma"
    chunk_size=200,  # chunk size handed to the text splitter
    chunk_overlap=20,  # overlap between consecutive chunks
    k=3,  # number of documents returned per query
    merge_rows=5,  # number of table rows merged into one document
    hash_file_path="hash_file.txt",  # file recording hashes of ingested files
)

# Expose each setting as a module-level name (directories, embedding model,
# splitting parameters, storage backend).
(
    source_directory,
    embeddings_model_name,
    chunk_size,
    chunk_overlap,
    output_dir,
    k,
    merge_rows,
    hash_file_path,
    db_type,
) = (
    CONFIG[setting]
    for setting in (
        "doc_source",
        "embedding_model",
        "chunk_size",
        "chunk_overlap",
        "db_source",
        "k",
        "merge_rows",
        "hash_file_path",
        "db_type",
    )
)
# Custom document loaders.
class MyElmLoader(UnstructuredEmailLoader):
    """``UnstructuredEmailLoader`` that falls back to the plain-text body.

    When the base loader raises a ``ValueError`` saying the email has no
    ``text/html`` content, the load is retried with
    ``content_source="text/plain"``.
    """

    def load(self) -> List:
        """Load the email, retrying with plain text if no HTML part exists.

        Returns:
            Whatever ``UnstructuredEmailLoader.load`` returns.

        Raises:
            Exception: any loading error, re-raised with ``self.file_path``
                prefixed to its message. The original exception type is kept
                when it can be rebuilt from a single message argument;
                otherwise a ``RuntimeError`` is raised. Either way the
                original error is chained as ``__cause__``.
        """
        try:
            try:
                doc = UnstructuredEmailLoader.load(self)
            except ValueError as e:
                if "text/html content not found in email" in str(e):
                    # Retry using the plain-text part of the email.
                    self.unstructured_kwargs["content_source"] = "text/plain"
                    doc = UnstructuredEmailLoader.load(self)
                else:
                    raise
        except Exception as e:
            message = f"{self.file_path}: {e}"
            try:
                wrapped = type(e)(message)
            except Exception:
                # Some exception types (e.g. UnicodeDecodeError) need several
                # constructor arguments; don't let that mask the real error.
                raise RuntimeError(message) from e
            raise wrapped from e
        return doc
# File extension -> (loader class, keyword arguments for the loader).
# NOTE(review): the original note here mentioned GB2312 / GB18030; presumably
# Chinese .txt files in those encodings will fail with the utf8 TextLoader
# below. TODO confirm and adjust the encoding if needed.
# Entries are listed in a fixed order because callers iterate this mapping.
LOADER_MAPPING = dict(
    [
        (".csv", (CSVLoader, {})),
        (".doc", (UnstructuredWordDocumentLoader, {})),
        (".docx", (UnstructuredWordDocumentLoader, {})),
        (".enex", (EverNoteLoader, {})),
        (".eml", (MyElmLoader, {})),
        (".epub", (UnstructuredEPubLoader, {})),
        (".html", (UnstructuredHTMLLoader, {})),
        (".md", (UnstructuredMarkdownLoader, {})),
        (".odt", (UnstructuredODTLoader, {})),
        (".pdf", (PDFMinerLoader, {})),
        (".ppt", (UnstructuredPowerPointLoader, {})),
        (".pptx", (UnstructuredPowerPointLoader, {})),
        (".txt", (TextLoader, {"encoding": "utf8"})),
        (".xls", (UnstructuredExcelLoader, {})),
        (".xlsx", (UnstructuredExcelLoader, {})),
    ]
)
def read_hash_file(path="hash_file.txt"):
    """Return the list of file hashes recorded in *path*.

    The file is expected to contain one hash per line, as written by
    ``save_hash_file``. Surrounding whitespace and blank lines are ignored.

    Args:
        path: location of the hash file.

    Returns:
        The recorded hashes in file order; an empty list if *path* does not
        exist or contains no hashes.
    """
    if not os.path.exists(path):
        return []
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]
def save_hash_file(hash_list, path="hash_file.txt"):
    """Overwrite *path* with the given hashes, one per line.

    No trailing newline is written after the last hash.

    Args:
        hash_list: iterable of hash strings.
        path: destination file.
    """
    with open(path, mode="w") as out_file:
        payload = "\n".join(hash_list)
        out_file.write(payload)
def get_hash_from_file(path):
    """Return the hex MD5 digest of the file at *path*.

    The digest is used to recognise files whose content was already ingested.
    It is not used for security. The file is streamed in 1 MiB chunks, so large
    documents are never held in memory all at once.

    Args:
        path: file to hash.

    Returns:
        The 32-character lowercase hex digest.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
def load_single_document(
    file_path: str, splitter: TextSplitter, merge_rows: int
) -> List:
    """Load one file and turn it into a list of ``Document`` objects.

    Non-tabular files: the page contents returned by the loader are
    concatenated into a single document, which is then split with *splitter*.

    Tabular files (.csv/.xls/.xlsx): every *merge_rows* consecutive documents
    returned by the loader (presumably one per row) are joined with blank lines
    into one document. No further splitting is applied.

    Args:
        file_path: path of the file; its extension selects the loader from
            ``LOADER_MAPPING``.
        splitter: text splitter applied to non-tabular files.
        merge_rows: number of table rows merged into each document (>= 1).

    Returns:
        The resulting documents; empty if the loader returned nothing.

    Raises:
        ValueError: if the extension is unsupported or ``merge_rows < 1``.
    """
    ext = "." + file_path.rsplit(".", 1)[-1]
    if ext not in LOADER_MAPPING:
        raise ValueError(f"Unsupported file extension '{ext}'")
    loader_class, loader_args = LOADER_MAPPING[ext]
    docs = loader_class(file_path, **loader_args).load()
    if not docs:
        return []
    source = docs[0].metadata.get("source", file_path)
    if ext in (".csv", ".xls", ".xlsx"):
        return _merge_table_rows(docs, source, merge_rows)
    return _merge_and_split(docs, source, splitter)


def _merge_and_split(docs, source, splitter):
    """Concatenate all page contents into one document, then split it."""
    pages = [d.page_content for d in docs]
    merged = Document(
        page_content="".join(pages).strip(),
        metadata={"source": source, "pages": len(pages)},
    )
    return splitter.split_documents([merged])


def _merge_table_rows(docs, source, merge_rows):
    """Join every ``merge_rows`` consecutive row documents into one document."""
    if merge_rows < 1:
        raise ValueError(f"merge_rows must be >= 1, got {merge_rows}")
    return [
        Document(
            page_content="\n\n".join(
                d.page_content for d in docs[start : start + merge_rows]
            ),
            metadata={"source": source},
        )
        for start in range(0, len(docs), merge_rows)
    ]
def load_documents(source_dir: str, ignored_files: Optional[List] = None) -> List:
    """Load and split every new, supported file under *source_dir*.

    Files are found recursively by the extensions in ``LOADER_MAPPING``. A file
    is skipped if:
      - it is listed in *ignored_files*;
      - its MD5 hash is already recorded in ``hash_file_path``;
      - it duplicates the content of another file found in this run.

    The hashes of newly loaded files are written to ``hash_file_path`` only
    after every file has loaded successfully. Files are loaded in a process
    pool with one worker per CPU.

    Args:
        source_dir: root directory to scan.
        ignored_files: paths to skip, compared with the glob results as-is.

    Returns:
        A flat list of documents from all newly loaded files. The list is empty
        when there is nothing new.
    """
    all_files = []
    for ext in LOADER_MAPPING:
        all_files.extend(
            glob.glob(os.path.join(source_dir, f"**/*{ext}"), recursive=True)
        )
    ignored = set(ignored_files or ())
    candidates = [path for path in all_files if path not in ignored]
    new_files, hash_list = _filter_already_hashed(candidates, hash_file_path)
    if not new_files:
        return []

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    load_document = partial(
        load_single_document, splitter=splitter, merge_rows=merge_rows
    )
    results = []
    with Pool(processes=os.cpu_count()) as pool, tqdm(
        total=len(new_files), desc="Loading new documents", ncols=80
    ) as pbar:
        for docs in pool.imap_unordered(load_document, new_files):
            results.extend(docs)
            pbar.update()
    # Record hashes only after all files loaded, so a failed run is retried.
    # NOTE(review): the caller still builds the vector store afterwards; a
    # failure there leaves these files marked as ingested.
    save_hash_file(hash_list, hash_file_path)
    return results


def _filter_already_hashed(files, hash_path):
    """Return (files with unseen content hashes, known + new hashes)."""
    hash_list = read_hash_file(hash_path)
    seen = set(hash_list)
    kept = []
    for path in files:
        digest = get_hash_from_file(path)
        if digest not in seen:
            seen.add(digest)
            hash_list.append(digest)
            kept.append(path)
    return kept, hash_list
def process_documents(ignored_files: Optional[List] = None) -> List:
    """Load all new documents from ``source_directory`` as split chunks.

    Args:
        ignored_files: paths to skip (see ``load_documents``).

    Returns:
        The list of document chunks.

    Exits the process with status 0 (``SystemExit``) when there is nothing
    new to load.
    """
    print(f"Loading documents from {source_directory}")
    documents = load_documents(source_directory, ignored_files or [])
    if not documents:
        print("No new documents to load")
        # ``exit`` is injected by the ``site`` module and may be absent.
        raise SystemExit(0)
    # The loaded documents are already chunks, so only one count is reported.
    # The splitter measures chunk_size in characters; tabular files are
    # merged by rows rather than split.
    print(
        f"Loaded new documents from {source_directory}."
        f"\nSplit into {len(documents)} chunks of text "
        f"(text files: max. {chunk_size} characters each)"
    )
    return documents
def main():
    """Embed all new documents and store them in the configured vector store.

    ``db_type`` selects the backend:
      - ``"chroma"`` writes to a Chroma store persisted in ``output_dir``;
      - ``"faiss"`` builds an index from the new documents, merges in the
        index already saved in ``output_dir`` (if any), and saves the result.

    Raises:
        NotImplementedError: for any other ``db_type``. This is checked before
            documents are loaded.
        RuntimeError: if an existing FAISS index cannot be loaded or merged.
            The old index is left untouched in that case.
    """
    if db_type not in ("chroma", "faiss"):
        # Fail early: loading documents records their hashes, so failing later
        # would mark the files as ingested without storing them.
        raise NotImplementedError(f'未定义数据库 {db_type} 的实现.')

    print("Creating new vectorstore")
    documents = process_documents()
    print("Creating embeddings. May take some minutes...")
    embedding_function = SentenceTransformerEmbeddings(model_name=embeddings_model_name)
    if db_type == "chroma":
        from langchain.vectorstores import Chroma

        db = Chroma.from_documents(
            documents, embedding_function, persist_directory=output_dir
        )
        db.persist()
    else:
        _save_faiss(documents, embedding_function)


def _save_faiss(documents, embedding_function):
    """Build a FAISS index for *documents*, merge any saved index, then save."""
    from langchain.vectorstores import FAISS

    print("创建新数据库")
    db = FAISS.from_documents(documents, embedding_function)
    # Merge the previously saved index (per the original note, this is not
    # supported by the GPU build of faiss). FAISS.save_local writes
    # "index.faiss" by default, so its presence marks an existing index.
    if os.path.exists(os.path.join(output_dir, "index.faiss")):
        print("读取旧数据库")
        try:
            # NOTE(review): load_local unpickles the stored docstore; only
            # point output_dir at trusted data.
            old_db = FAISS.load_local(output_dir, embedding_function)
            print("融合新旧数据库")
            db.merge_from(old_db)
        except Exception as e:
            # Saving now would overwrite the old index with only the new
            # documents, whose hashes are already recorded -> silent data loss.
            raise RuntimeError(
                f"Could not merge the existing FAISS index in {output_dir}; "
                "refusing to overwrite it"
            ) from e
    print("保存")
    db.save_local(output_dir)


if __name__ == "__main__":
    main()
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页:
[1]