milvus+flask山寨《从零构建向量数据库》第7章case1

打印 上一主题 下一主题

主题 1883|帖子 1883|积分 5649

常规练手,图片搜索山寨版。拜读罗云大佬著作,结果只有操作层的东西可以上上手。
书中是自己写的向量数据库,这边直接用python拼个现成的milvus向量数据库。
1. 创建一个向量数据库以及相应的数据表(Collection):
  # Bootstrap script: (re)creates the Milvus collection that stores image
  # file paths together with their embeddings, builds a vector index on it
  # and loads it.
  from pymilvus import Collection, CollectionSchema, DataType, FieldSchema
  from pymilvus import connections, utility

  # Milvus setup arguments
  COLLECTION_NAME = 'animal_search'
  DIMENSION = 2048
  MILVUS_HOST = "localhost"
  MILVUS_PORT = "19530"

  # Inference arguments
  BATCH_SIZE = 128

  # Open the connection to the Milvus instance.
  connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

  # Drop any earlier collection with the same name so the schema below is
  # always applied to a fresh collection.
  if utility.has_collection(COLLECTION_NAME):
      utility.drop_collection(COLLECTION_NAME)

  # Collection layout: auto-generated primary key, image file path, embedding.
  primary_key = FieldSchema(
      name='id',
      dtype=DataType.INT64,
      is_primary=True,
      auto_id=True,
  )
  path_field = FieldSchema(
      name='filepath',
      dtype=DataType.VARCHAR,
      max_length=200,
  )
  vector_field = FieldSchema(
      name='image_embedding',
      dtype=DataType.FLOAT_VECTOR,
      dim=DIMENSION,
  )
  fields = [primary_key, path_field, vector_field]
  schema = CollectionSchema(fields=fields)
  collection = Collection(name=COLLECTION_NAME, schema=schema)

  # IVF_FLAT index over the embedding field, compared with L2 distance.
  index_params = dict(
      metric_type='L2',
      index_type="IVF_FLAT",
      params=dict(nlist=16384),
  )
  collection.create_index(
      field_name="image_embedding",
      index_params=index_params,
  )
  collection.load()
复制代码
2. 写一堆图片进去存着。向量其实就是模型从图片像素中提取出来的多维特征:
  # Ingestion script: embeds every image matched by the dataset glob with a
  # ResNet-50 whose last layer is removed, then inserts (filepath, embedding)
  # rows into the Milvus collection.
  import glob

  import torch
  from PIL import Image
  from pymilvus import Collection, CollectionSchema, DataType, FieldSchema
  from pymilvus import connections
  from torchvision import transforms
  from tqdm import tqdm

  # Milvus setup arguments
  COLLECTION_NAME = 'animal_search'
  DIMENSION = 2048
  MILVUS_HOST = "localhost"
  MILVUS_PORT = "19530"

  # Inference arguments
  BATCH_SIZE = 128

  # Connect to the instance
  connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

  # Open the collection handle once instead of rebuilding the schema and the
  # handle for every batch.
  fields = [
      FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
      FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
      FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION),
  ]
  schema = CollectionSchema(fields=fields)
  collection = Collection(name=COLLECTION_NAME, schema=schema)

  # The pattern contains no '**', so glob's recursive flag would have no
  # effect; only direct children of the directory are matched.
  paths = glob.glob('/mcm/vectorDB_training/animals_db/*')

  # Load the embedding model with the last layer removed
  model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
  model = torch.nn.Sequential(*(list(model.children())[:-1]))
  model.eval()

  # Preprocessing for images
  preprocess = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])


  def embed(data):
      """Embed one batch of images and insert it into the collection.

      Args:
          data: Two parallel lists ``[tensors, filepaths]``; ``data[0]`` holds
              the preprocessed image tensors and ``data[1]`` the matching
              file paths.

      The caller is responsible for flushing the collection once all batches
      have been inserted.
      """
      with torch.no_grad():
          # Flatten everything after the batch axis. Unlike squeeze(), this
          # keeps the batch axis when the batch holds a single image, so the
          # result is always one vector per file path.
          output = model(torch.stack(data[0])).flatten(start_dim=1)
      collection.insert([data[1], output.tolist()])


  data_batch = [[], []]
  # Read the images into batches for embedding and insertion
  for path in tqdm(paths):
      # The context manager closes the file handle as soon as the image has
      # been converted and preprocessed.
      with Image.open(path) as im:
          data_batch[0].append(preprocess(im.convert('RGB')))
      data_batch[1].append(path)
      if len(data_batch[0]) >= BATCH_SIZE:
          embed(data_batch)
          data_batch = [[], []]

  # Embed and insert the remainder
  if data_batch[0]:
      embed(data_batch)

  # Flush once at the end rather than after every batch.
  collection.flush()
复制代码
3. 向量化图片的函数要单独拎出来,做搜索功能的时候用它。
  import torch
  import torchvision.transforms as transforms
  from torchvision.models import resnet50
  from PIL import Image

  # Image preprocessing pipeline, built once and shared by every call.
  _PREPROCESS = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])

  # Feature extractor, created on first use by _get_model().
  _MODEL = None


  def _get_model():
      """Return the shared ResNet-50 feature extractor, building it on first use."""
      global _MODEL
      if _MODEL is None:
          backbone = resnet50(pretrained=True)
          # Drop the final fc layer: with it the output would have 1000
          # values instead of the 2048-value feature vector.
          extractor = torch.nn.Sequential(*list(backbone.children())[:-1])
          extractor.eval()
          _MODEL = extractor
      return _MODEL


  def extract_features(image_path):
      """Return the ResNet-50 feature vector of an image as a 1-D numpy array.

      Args:
          image_path: Path of the image file to embed.

      Returns:
          The flattened model output as a numpy array.
      """
      # Convert to RGB so that grayscale, palette or RGBA files match the
      # 3-channel normalisation above; the context manager closes the file.
      with Image.open(image_path) as img:
          img_t = _PREPROCESS(img.convert('RGB'))
      # Add the batch axis the model expects.
      batch_t = torch.unsqueeze(img_t, 0)
      with torch.no_grad():
          out = _get_model()(batch_t)
      # Flatten the features into a one-dimensional array.
      return out.flatten().numpy()
复制代码
4. 用flask做的界面
  # Flask front end for the image search demo: uploads images into the Milvus
  # collection and searches the collection by example image.
  from flask import Flask,request,jsonify
  from flask import render_template
  from image_eb import extract_features
  from pymilvus import MilvusClient
  import logging
  import os
  import shutil

  MILVUS_HOST = "localhost"
  MILVUS_PORT = "19530"
  COLLECTION_NAME = 'animal_search'
  TOP_K = 3
  # Relative directory where uploads and the search query image are written.
  UPLOAD_DIR = "static/images"
  # Absolute directory that search hits are copied into; presumably the
  # folder Flask serves under /static/images (verify the deployment layout).
  STATIC_IMAGE_DIR = '/mcm/vectorDB_training/static/images'

  app = Flask(__name__)
  milvus_client = MilvusClient(uri=f"http://{MILVUS_HOST}:{MILVUS_PORT}")


  @app.route("/")
  def index():
      """Render the landing page."""
      return render_template("index.html")


  @app.route("/upload", methods=["POST"])
  def upload_image():
      """Store an uploaded image and insert its embedding into Milvus.

      Expects a multipart form with an ``image`` file and an integer
      ``image_id`` field. Responds with HTTP 400 and a JSON message when
      either is missing or the id is not an integer.
      """
      image_file = request.files.get("image")
      if image_file is None:
          return jsonify({"message": "Image file is required"}), 400

      # Check that image_id is present.
      image_id_str = request.form.get("image_id")
      if not image_id_str:
          return jsonify({"message": "Image ID is required"}), 400

      # Convert image_id to an integer.
      try:
          image_id = int(image_id_str)
      except ValueError:
          return jsonify({"message": "Invalid image ID. It must be an integer"}), 400

      # Name the file after the parsed integer, not the raw form string, so
      # only a canonical number can end up in the path.
      image_path = os.path.join(UPLOAD_DIR, str(image_id))
      image_file.save(image_path)

      # Insert as a plain list, the same form the search endpoint sends.
      image_features = extract_features(image_path)
      row = {"filepath": image_path, "image_embedding": image_features.tolist()}
      # NOTE(review): image_id is not part of the inserted row, so the id
      # echoed back below is the caller's value, not a key stored in Milvus.
      milvus_client.insert(collection_name=COLLECTION_NAME, data=[row])
      return jsonify({"message": "Image uploaded successfully", "id": image_id})


  @app.route("/search", methods=["POST"])
  def search_image():
      """Return URLs of the TOP_K stored images closest to the uploaded one.

      Expects a multipart form with an ``image`` file. Each hit is copied into
      STATIC_IMAGE_DIR and reported as a ``/static/images/<name>`` URL.
      """
      image_file = request.files.get("image")
      if image_file is None:
          return jsonify({"message": "Image file is required"}), 400

      # NOTE(review): every request writes the same temp file, so concurrent
      # searches can overwrite each other's query image.
      image_path = os.path.join(UPLOAD_DIR, "temp_image.jpg")
      image_file.save(image_path)

      # Run the model once and reuse the vector for the query.
      query_vector = extract_features(image_path).tolist()
      search_result = milvus_client.search(
          collection_name=COLLECTION_NAME,
          data=[query_vector],
          output_fields=["filepath"],
          limit=TOP_K,
          search_params={'metric_type': 'L2', 'params': {}},
      )

      image_urls = []
      # One query vector was sent, so the hits are in the first result list.
      for hit in search_result[0]:
          source_file = hit["entity"]["filepath"]
          base_file_name = os.path.basename(source_file)
          destination_file = os.path.join(STATIC_IMAGE_DIR, base_file_name)
          try:
              shutil.copy(source_file, destination_file)
          except shutil.SameFileError:
              # The hit already lives in the destination directory (e.g. an
              # image stored by /upload); nothing to copy.
              pass
          image_urls.append(os.path.join("/static/images", base_file_name))
      return jsonify({"image_urls": image_urls})


  if __name__=="__main__":
      # NOTE(review): debug=True together with host='0.0.0.0' exposes the
      # interactive debugger to the network; keep this for local use only.
      app.run(host='0.0.0.0',port=5020,debug=True)
复制代码
谁私信我说脸谱不是猫,只是缩略图里没放清楚,睁大go眼瞧瞧:

小网站结构,以及其他杂代码,可以查看以及直接下载:https://www.ituring.com.cn/book/3305

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

曂沅仴駦

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表