常规练手,图片搜索山寨版。拜读罗云大佬著作,结果只有操作层的东西可以上上手。
书中是自己写的向量数据库,这边直接用Python拼个现成的Milvus向量数据库。
1. 创建一个向量数据库以及对应的数据表(Collection):
# One-off setup script: (re)creates the Milvus collection that stores image
# paths and their embeddings, builds a vector index on it and loads it.
from pymilvus import connections
from pymilvus import utility
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

# Milvus setup arguments
COLLECTION_NAME = 'animal_search'
DIMENSION = 2048
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference arguments
BATCH_SIZE = 128

# Connect to the Milvus instance.
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

# Start from a clean slate: drop any previous collection with the same name.
if utility.has_collection(COLLECTION_NAME):
    utility.drop_collection(COLLECTION_NAME)

# The collection holds an auto-generated ID, the image file path and the
# image embedding.
id_field = FieldSchema(
    name='id', dtype=DataType.INT64, is_primary=True, auto_id=True
)
filepath_field = FieldSchema(
    name='filepath', dtype=DataType.VARCHAR, max_length=200
)
embedding_field = FieldSchema(
    name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION
)
fields = [id_field, filepath_field, embedding_field]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

# IVF_FLAT index over the embedding field, compared with L2 distance.
index_params = dict(
    metric_type='L2',
    index_type="IVF_FLAT",
    params=dict(nlist=16384),
)
collection.create_index(
    field_name="image_embedding", index_params=index_params
)
collection.load()
2. 写一堆图片进去存着,向量其实就是各种像素间的维度特征:
# Bulk-ingestion script, part 1: connect to Milvus, list the image files,
# and prepare the embedding model and the preprocessing pipeline.
from pymilvus import connections
import glob
import torch
from torchvision import transforms
from PIL import Image
from tqdm import tqdm

# Milvus setup arguments
COLLECTION_NAME = 'animal_search'
DIMENSION = 2048
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

# Inference arguments
BATCH_SIZE = 128

# Connect to the Milvus instance.
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

# Every entry directly under the image folder.
# NOTE(review): recursive=True only has an effect for patterns containing
# '**'; with this pattern sub-folders are not descended into -- confirm
# that is intended.
paths = glob.glob('/mcm/vectorDB_training/animals_db/*', recursive=True)

# Load the pretrained ResNet-50 and drop its last child module so the
# network outputs an embedding instead of class scores.
backbone = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model = torch.nn.Sequential(*list(backbone.children())[:-1])
model.eval()

# Preprocessing applied to every image before it is fed to the model.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
def embed(data):
    """Embed one batch of images and insert it into the Milvus collection.

    Args:
        data: A two-element list ``[tensors, filepaths]`` holding the
            preprocessed image tensors and, in the same order, the file
            path of each image.

    An empty batch is ignored.
    """
    from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

    tensors = data[0]
    filepaths = data[1]
    if not tensors:
        # Nothing to embed; torch.stack would raise on an empty list.
        return

    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),
        FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200),
        FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION),
    ]
    schema = CollectionSchema(fields=fields)
    collection = Collection(name=COLLECTION_NAME, schema=schema)

    with torch.no_grad():
        # Flatten everything except the batch axis. A bare squeeze() would
        # also drop the batch axis when the batch holds a single image,
        # turning the result into one flat vector instead of a list of
        # vectors and breaking the insert for a one-image remainder.
        output = model(torch.stack(tensors)).flatten(start_dim=1)

    # Column order follows the schema; 'id' is omitted because it is
    # generated automatically (auto_id=True).
    collection.insert([filepaths, output.tolist()])
    collection.flush()
# data_batch[0] collects preprocessed tensors, data_batch[1] the matching paths.
data_batch = [[], []]

# Read the images into batches for embedding and insertion.
for path in tqdm(paths):
    # The context manager closes the underlying file as soon as the RGB
    # copy has been made, instead of leaving it to garbage collection.
    with Image.open(path) as raw_image:
        im = raw_image.convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)
    if len(data_batch[0]) >= BATCH_SIZE:
        embed(data_batch)
        data_batch = [[], []]

# Embed and insert the remainder.
if data_batch[0]:
    embed(data_batch)
3. 向量化图片的函数要单独拎出来,做搜索功能的时候用它。
- import torch
- import torchvision.transforms as transforms
- from torchvision.models import resnet50
- from PIL import Image
# Shared, lazily-built feature extractor so the pretrained weights are
# loaded once per process rather than on every call.
_FEATURE_MODEL = None

# Image preprocessing pipeline, built once and reused.
_PREPROCESS = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _get_feature_model():
    """Return the shared ResNet-50 feature extractor, building it on first use."""
    global _FEATURE_MODEL
    if _FEATURE_MODEL is None:
        backbone = resnet50(pretrained=True)
        # Remove the final fc layer; with it the output would be the 1000
        # class scores instead of the 2048-dimensional feature.
        net = torch.nn.Sequential(*list(backbone.children())[:-1])
        net.eval()
        _FEATURE_MODEL = net
    return _FEATURE_MODEL


def extract_features(image_path):
    """Compute the embedding of a single image file.

    Args:
        image_path: Path of the image to embed.

    Returns:
        A one-dimensional NumPy array holding the flattened model output.
    """
    # Read the image. Converting to RGB matches what the bulk-ingestion
    # script does and guarantees three channels, which the 3-value
    # Normalize above requires (grayscale/RGBA/palette images would
    # otherwise fail or be embedded inconsistently).
    with Image.open(image_path) as img:
        img_t = _PREPROCESS(img.convert('RGB'))
    # Add the batch axis the model expects.
    batch_t = torch.unsqueeze(img_t, 0)

    # Extract the features without tracking gradients.
    with torch.no_grad():
        out = _get_feature_model()(batch_t)

    # Flatten to a one-dimensional array and return it.
    return out.flatten().numpy()
4. 用Flask做的界面:
- from flask import Flask,request,jsonify
- from flask import render_template
- from image_eb import extract_features
- #from pymilvus import connections
- from pymilvus import MilvusClient
- import logging
- import os
- import shutil
# Milvus connection settings and search configuration.
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"
COLLECTION_NAME = 'animal_search'
TOP_K = 3  # number of nearest images returned by a search

app = Flask(__name__)

# Build the URI from the constants above so the host/port are configured in
# exactly one place (the resulting value is "http://localhost:19530").
milvus_client = MilvusClient(uri=f"http://{MILVUS_HOST}:{MILVUS_PORT}")
@app.route("/")
def index():
    """Serve the landing page rendered from the ``index.html`` template."""
    page = render_template("index.html")
    return page
@app.route("/upload", methods=["POST"])
def upload_image():
    """Store an uploaded image and add its embedding to the collection.

    Expects a multipart form with an ``image`` file and an integer
    ``image_id`` field. The file is saved under ``static/images`` using the
    id as its file name, embedded, and inserted into Milvus together with
    its path.

    Returns:
        A JSON success message with the parsed id, or a JSON error message
        with status 400 when the file or the id is missing or invalid.
    """
    image_file = request.files.get("image")
    if image_file is None or not image_file.filename:
        return jsonify({"message": "Image file is required"}), 400

    # Check that image_id is present.
    image_id_str = request.form.get("image_id")
    if not image_id_str:
        return jsonify({"message": "Image ID is required"}), 400

    # Convert image_id to an integer.
    try:
        image_id = int(image_id_str)
    except ValueError:
        return jsonify({"message": "Invalid image ID. It must be an integer"}), 400

    # Name the file after the parsed integer rather than the raw form text:
    # int() tolerates surrounding whitespace, signs and underscores, so the
    # raw string is not guaranteed to be a clean file name.
    image_path = os.path.join("static/images", str(image_id))
    image_file.save(image_path)

    # Convert the embedding to a plain list, as the search endpoint does.
    image_features = extract_features(image_path).tolist()

    # Add the record to the database. Only the path and the embedding are
    # sent; image_id is used for the file name and the response only.
    record = {"filepath": image_path, "image_embedding": image_features}
    milvus_client.insert(collection_name=COLLECTION_NAME, data=[record])

    return jsonify({"message": "Image uploaded successfully", "id": image_id})
@app.route("/search", methods=["POST"])
def search_image():
    """Find the stored images most similar to an uploaded query image.

    Expects a multipart form with an ``image`` file. The file is saved as
    ``static/images/temp_image.jpg``, embedded, and used to query Milvus for
    the ``TOP_K`` nearest entries. Each hit's file is copied into the static
    image folder so it can be served.

    Returns:
        JSON of the form ``{"image_urls": [...]}`` listing the
        ``/static/images/<name>`` URL of every hit, or a JSON error message
        with status 400 when no file was sent.
    """
    image_file = request.files.get("image")
    if image_file is None or not image_file.filename:
        return jsonify({"message": "Image file is required"}), 400

    # NOTE(review): every request writes to the same temp file name, so
    # concurrent searches can overwrite each other's query image.
    image_path = os.path.join("static/images", "temp_image.jpg")
    image_file.save(image_path)

    # Embed the query image once (inference is the expensive step).
    query_vector = extract_features(image_path).tolist()

    search_result = milvus_client.search(
        collection_name=COLLECTION_NAME,
        data=[query_vector],
        output_fields=["filepath"],
        limit=TOP_K,
        search_params={'metric_type': 'L2', 'params': {}},
    )

    # One query vector was sent, so the hits are in the first result entry.
    hits = search_result[0]

    destination_folder = '/mcm/vectorDB_training/static/images'
    image_urls = []
    for hit in hits:
        source_file = hit["entity"]["filepath"]
        base_file_name = os.path.basename(source_file)
        destination_file = os.path.join(destination_folder, base_file_name)
        try:
            shutil.copy(source_file, destination_file)
        except shutil.SameFileError:
            # The hit already lives in the static folder (e.g. an image
            # added through /upload); nothing to copy.
            pass
        image_urls.append(os.path.join("/static/images", base_file_name))

    return jsonify({"image_urls": image_urls})
if __name__ == "__main__":
    # Start the development server on all interfaces, port 5020.
    # NOTE(review): debug=True together with host 0.0.0.0 makes the debug
    # server reachable from other machines; fine for a local exercise,
    # but confirm this is never used on an exposed host.
    server_options = dict(host='0.0.0.0', port=5020, debug=True)
    app.run(**server_options)
谁私信我说脸谱不是猫,只是缩略图里没放清楚,睁大go眼瞧瞧:
小网站结构,以及其他杂代码,可以查看以及直接下载:https://www.ituring.com.cn/book/3305
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。