ToB企服应用市场:ToB评测及商务社交产业平台

标题: 以图搜图功能实现(ES/Milvus) [打印本页]

作者: 道家人    时间: 2024-11-5 20:20
标题: 以图搜图功能实现(ES/Milvus)
思绪方案

思绪

需要将非结构化数据→转为结构化→再完成搜索。将非结构化数据,转化为结构化的多维向量,用这些向量标识实体和实体间的关系。再盘算向量之间距离,通常情况下,距离越近、相似度越高,召回相似度最高的TOP结果,完成检索。

方案

给定一组查询图片和数据库图片。我们对数据库图片实行以图搜图操纵,在image embeddings(将图片数据转换为固定巨细的特征表示——矢量)上获取前k个最相似的数据库中的图片。
将接纳以下两种方法实行以图搜图功能:
Milvus 向量数据库

Milvus 在非结构化数据处理中的应用非常强大。Milvus 向量相似度检索引擎可以兼容各种深度学习平台,搜索十亿向量仅毫秒响应。
ElasticSearch 向量数据库(resnet50模型)

方案一(Milvus1.0)

功能介绍

以图搜图,涉及两大功能:1、提取图像特征向量。2、相似向量检索。
通过盘算特征向量来分析非结构化数据。使用ResNet-50举行特征提取,构建反向图像搜索体系。
情况搭建

yaml文件配置

  1. version: 0.5
  2. cluster:
  3.   enable: false
  4.   role: rw
  5. general:
  6.   timezone: UTC+8
  7.   meta_uri: sqlite://:@:/
  8. network:
  9.   bind.address: 0.0.0.0
  10.   bind.port: 19530
  11.   http.enable: true
  12.   http.port: 19121
  13. storage:
  14.   path: /var/lib/milvus
  15.   auto_flush_interval: 1
  16. wal:
  17.   enable: true
  18.   recovery_error_ignore: false
  19.   buffer_size: 256MB
  20.   path: /var/lib/milvus/wal
  21. cache:
  22.   cache_size: 256MB
  23.   insert_buffer_size: 256MB
  24.   preload_collection:
  25. gpu:
  26.   enable: false
  27.   cache_size: 256MB
  28.   gpu_search_threshold: 1000
  29.   search_devices:
  30.     - gpu0
  31.   build_index_devices:
  32.     - gpu0
  33. fpga:
  34.    enable: false
  35.    search_devices:
  36.      - fpga0
  37. logs:
  38.   level: debug
  39.   trace.enable: true
  40.   path: /var/lib/milvus/logs
  41.   max_log_file_size: 1073741824
  42.   log_rotate_num: 0
  43.   log_to_stdout: false
  44.   log_to_file: true
  45. metric:
  46.   enable: false
  47.   address: 127.0.0.1
  48.   port: 9091
复制代码
docker摆设

  1. docker run -d --name milvus_1 \
  2. -p 19530:19530 \
  3. -p 19121:19121 \
  4. -v /root/milvus/db:/var/lib/milvus/db \
  5. -v /root/milvus/conf:/var/lib/milvus/conf \
  6. -v /root/milvus/logs:/var/lib/milvus/logs \
  7. -v /root/milvus/wal:/var/lib/milvus/wal \
  8. milvusdb/milvus:1.0.0-cpu-d030521-1ea92e
  9. 2)
  10. docker run -d --name image_search \
  11. -v /root/milvus/pic:/tmp/pic1 \
  12. -p 35000:5000 \
  13. -e "DATA_PATH=/tmp/images-data" \
  14. -e "MILVUS_HOST=你的服务器ip地址" \
  15. milvusbootcamp/pic-search-webserver:1.0
  16. 3)
  17. docker run --name milvus_image_search_web -d --rm -p 8001:80 \
  18. -e API_URL=http://你的服务器ip地址:35000 \
  19. milvusbootcamp/pic-search-webclient:1.0
复制代码
结果图


测试

原图举行验证搜索


截图举行验证搜索


不相关图片举行验证


升级方案(Milvus2.X)

问题

在调用模型时无法连接至hugging face 无法将图片转为向量
方案二(ElasticSearch + ResNet-50模型)

功能介绍

以图搜图,涉及两大功能:1、提取图像特征向量。2、相似向量检索。
第一个功能通过pytorch下载保存resnet50模型并在java端借助djl调用实现,第二个功能通过elasticsearch7.12.2的dense_vector、cosineSimilarity实现。
情况摆设(通过编写pytorch模型并在java端借助djl调用实现)

提取图像特征下载模型到本地(resnet50模型)

  1. import torch
  2. import torch.nn as nn
  3. import torchvision.models as models
  4. class ImageFeatureExtractor(nn.Module):
  5.     def __init__(self):
  6.         super(ImageFeatureExtractor, self).__init__()
  7.         self.resnet = models.resnet50(pretrained=True)
  8.         #最终输出维度1024的向量,下文elastic search要设置dims为1024
  9.         self.resnet.fc = nn.Linear(2048, 1024)
  10.     def forward(self, x):
  11.         x = self.resnet(x)
  12.         return x
  13. if __name__ == '__main__':
  14.     model = ImageFeatureExtractor()
  15.     model.eval()
  16.     #根据模型随便创建一个输入
  17.     input = torch.rand([1, 3, 224, 224])
  18.     output = model(input)
  19.     #以这种方式保存
  20.     script = torch.jit.trace(model, input)
  21.     script.save("model.pt")
复制代码
保存好的model.pt文件放入java项目的resources中。
摆设elasticsearch kibana

  1. es版本:7.6.2
  2. docker部署:
  3. docker run -p 9200:9200 -p 9300:9300 \
  4. --privileged=true --name es7.6.2 \
  5. -e "discovery.type=single-node" \
  6. -e ES_JAVA_OPTS="-Xms512m -Xmx1024m" \
  7. -e "http.max_content_length=500mb" \
  8. -v /root/mydata/plugins:/usr/share/elasticsearch/plugins \
  9. -v /root/mydata/data:/usr/share/elasticsearch/data \
  10. -v /root/mydata/logs:/usr/share/elasticsearch/logs \
  11. -d elasticsearch:7.6.2
  12. docker run -d \
  13. --name kibana \
  14. --restart=always \
  15. -p 5601:5601 \
  16. -v /data/kibana/config/kibana.yml:/usr/share/kibana/config/kibana.yml \
  17. kibana:7.6.2
复制代码
创建索引库

  1. PUT /isi
  2. {
  3.   "mappings": {
  4.     "properties": {
  5.       "vector": {
  6.         "type": "dense_vector",
  7.         "dims": 1024
  8.       },
  9.       "url" : {
  10.         "type" : "keyword"
  11.       },
  12.       "user_id": {
  13.           "type": "keyword"
  14.       }
  15.     }
  16.   }
  17. }
复制代码
相似向量上传、检索

创建调用resnet模型 转化格式

  1. public class Test {
  2.   private static final String INDEX = "isi";
  3.   private static final int IMAGE_SIZE = 224;
  4.   private static Model model; //模型
  5.   private static Predictor<Image, float[]> predictor;
  6.   //predictor.predict(input)相当于python中model(input)
  7.   static {
  8.     try {
  9.       model = Model.newInstance("model");
  10.       //这里的model.pt是上面代码展示的那种方式保存的
  11.       model.load(Test.class.getClassLoader().getResourceAsStream("model.pt"));
  12.       Transform resize = new Resize(IMAGE_SIZE);
  13.       Transform toTensor = new ToTensor();
  14.       Transform normalize = new Normalize(new float[]{0.485f, 0.456f, 0.406f}, new float[]{0.229f, 0.224f, 0.225f});
  15.       //Translator处理输入Image转为tensor、输出转为float[]
  16.       Translator<Image, float[]> translator = new Translator<Image, float[]>() {
  17.         @Override
  18.         public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
  19.           NDManager ndManager = ctx.getNDManager();
  20.           System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
  21.           NDArray transform = normalize.transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
  22.           System.out.println(transform.getShape());
  23.           NDList list = new NDList();
  24.           list.add(transform);
  25.           return list;
  26.         }
  27.         @Override
  28.         public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
  29.           return ndList.get(0).toFloatArray();
  30.         }
  31.       };
  32.       predictor = new Predictor<>(model, translator, Device.cpu(), true);
  33.     } catch (Exception e) {
  34.       e.printStackTrace();
  35.     }
  36.   }
  37. }
复制代码
批量上传图片到es

  1. public static void upload() throws Exception {
  2.         RestHighLevelClient client = new RestHighLevelClient(
  3.                 RestClient.builder(new HttpHost("192.168.110.132", 9200, "http")));
  4.         //批量上传请求
  5.         File file = new File("E:\\javacode\\javaes\\src\\main\\resources\\test");
  6.         File[] files = file.listFiles();
  7.         if (files == null) return;
  8.         int batchSize = 1000;
  9.         for (int i = 0; i < files.length; i += batchSize) {
  10.             BulkRequest bulkRequest = new BulkRequest(INDEX);
  11.             for (int j = i; j < i + batchSize && j < files.length; j++) {
  12.                 File listFile = files[j];
  13.                 float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(new FileInputStream(listFile)));
  14.                 Map<String, Object> jsonMap = new HashMap<>();
  15.                 jsonMap.put("url", listFile.getAbsolutePath());
  16.                 jsonMap.put("vector", vector);
  17.                 jsonMap.put("user_id", "user123");
  18.                 IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
  19.                 bulkRequest.add(request);
  20.             }
  21.             client.bulk(bulkRequest, RequestOptions.DEFAULT);
  22.             /*for (File listFile : file.listFiles()) {
  23.             float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(Test2.class.getClassLoader().getResourceAsStream("test/" + listFile.getName())));
  24.             // 构建文档
  25.             Map<String, Object> jsonMap = new HashMap<>();
  26.             jsonMap.put("url", listFile.getAbsolutePath());
  27.             jsonMap.put("vector", vector);
  28.             jsonMap.put("user_id", "user123");
  29.             IndexRequest request = new IndexRequest(INDEX).source(jsonMap, XContentType.JSON);
  30.             bulkRequest.add(request);*/
  31.         }
  32.         client.close();
  33.     }
复制代码
搜索(将图片转为向量与es文档库匹配)

  1. public static List<SearchResult> search(InputStream input) throws Throwable {
  2.         float[] vector = predictor.predict(ImageFactory.getInstance().fromInputStream(input));
  3.         System.out.println(Arrays.toString(vector));
  4.         //展示k个结果
  5.         int k = 50;
  6.         // 连接Elasticsearch服务器
  7.         RestHighLevelClient client = new RestHighLevelClient(
  8.                 RestClient.builder(new HttpHost("192.168.110.132", 9200, "http")));
  9.         SearchRequest searchRequest = new SearchRequest(INDEX);
  10.         Script script = new Script(
  11.                 ScriptType.INLINE,
  12.                 "painless",
  13.                 "cosineSimilarity(params.queryVector, doc['vector'])",
  14.                 Collections.singletonMap("queryVector", vector));
  15.         FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders.functionScoreQuery(
  16.                 QueryBuilders.matchAllQuery(),
  17.                 ScoreFunctionBuilders.scriptFunction(script));
  18.         SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
  19.         searchSourceBuilder.query(functionScoreQueryBuilder)
  20.                 .fetchSource(null, "vector") //不返回vector字段,没用还耗时
  21.                 .size(k);
  22.         searchRequest.source(searchSourceBuilder);
  23.         SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
  24.         SearchHits hits = searchResponse.getHits();
  25.         List<SearchResult> list = new ArrayList<>();
  26.         for (SearchHit hit : hits) {
  27.             // 处理搜索结果
  28.             System.out.println(hit.toString());
  29.             SearchResult result = new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore());
  30.             list.add(result);
  31.         }
  32.         client.close();
  33.         return list;
  34.     }
复制代码
结果图


测试

原图举行验证搜索


截图举行验证搜索


不相关图片举行验证



免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4