向量数据库Milvus字符串查询

十念  金牌会员 | 2024-7-31 18:03:07 | 显示全部楼层 | 阅读模式
打印 上一主题 下一主题

主题 567|帖子 567|积分 1701

        由于项目需要,用到了向量数据库Milvus,刚开始都没有遇到问题,直到一个表的主键是字符串(VARCHAR),在查询时刚好要以该主键作为查询条件,此时会出现异常,特此记录一下。
        记取,字符串查询,构建表达式时要加上单引号,比如下面的'{face_id}',实在face_id本来就是一个字符串范例了,假如不加会出现如下的异常:
        # pymilvus.exceptions.MilvusException: <MilvusException: (code=65535, message=cannot parse expression: face_id == 2_0, error: invalid expression: face_id == 2_0)>
  详细看下面的代码(milvus_demo.py),其中exists()函数中构建查询表达式时做了特别处理:
  1. from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility, Partition
  2. import time
  3. from datetime import datetime
  4. from typing import List
  5. #用于测试字符串查询的demo
  6. # MILVUS向量数据库地址
  7. MILVUS_HOST_ONLINE = '127.0.0.1'
  8. MILVUS_PORT = 19530
  9. # 检索时返回的匹配内容条数
  10. VECTOR_SEARCH_TOP_K = 100
  11. class MilvusAvatar:
  12.     # table_name 表名
  13.     # partition_names  分区名,使用默认即可
  14.     def __init__(self, mode, table_name, *, partition_names=["default"], threshold=1.1, client_timeout=3):
  15.         self.table_name = table_name
  16.         self.partition_names = partition_names
  17.         
  18.         self.host = MILVUS_HOST_ONLINE
  19.         self.port = MILVUS_PORT
  20.         self.client_timeout = client_timeout
  21.         self.threshold = threshold
  22.         self.sess: Collection = None
  23.         self.partitions: List[Partition] = []
  24.         self.top_k = VECTOR_SEARCH_TOP_K
  25.         self.search_params = {"metric_type": "L2", "params": {"nprobe": 256}}
  26.         self.create_params = {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 2048}}
  27.                
  28.         self.init()
  29.     @property
  30.     def fields(self):
  31.         fields = [
  32.             FieldSchema(name='face_id', dtype=DataType.VARCHAR, max_length=640, is_primary=True, auto_id = False),
  33.             FieldSchema(name='media_id', dtype=DataType.INT64),
  34.             FieldSchema(name='file_path', dtype=DataType.VARCHAR, max_length=640),  #原图片保存路径
  35.             FieldSchema(name='name', dtype=DataType.VARCHAR, max_length=640),  #姓名
  36.             FieldSchema(name='count', dtype=DataType.INT64),  #数量
  37.             FieldSchema(name='save_path', dtype=DataType.VARCHAR, max_length=640),  #现保存的绝对路径,包含文件名
  38.             FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, dim=512)
  39.         ]
  40.         return fields
  41.     @property
  42.     def output_fields(self):
  43.         return ['face_id','media_id', 'file_path', 'name', 'count', 'save_path','embedding']
  44.     def init(self):
  45.         try:
  46.             connections.connect(host=self.host, port=self.port)  # timeout=3 [cannot set]
  47.             if utility.has_collection(self.table_name):
  48.                 self.sess = Collection(self.table_name)
  49.                 print(f'collection {self.table_name} exists')
  50.             else:
  51.                 schema = CollectionSchema(self.fields)
  52.                 print(f'create collection {self.table_name} {schema}')
  53.                 self.sess = Collection(self.table_name, schema)
  54.                 self.sess.create_index(field_name="embedding", index_params=self.create_params)
  55.             for index in self.partition_names:
  56.                 if not self.sess.has_partition(index):
  57.                     self.sess.create_partition(index)
  58.             self.partitions = [Partition(self.sess, index) for index in self.partition_names]
  59.             print('partitions: %s', self.partition_names)
  60.             self.sess.load()
  61.         except Exception as e:
  62.             print(e)
  63.         
  64.     def query_expr_sync(self, expr, output_fields=None, client_timeout=None):
  65.         if client_timeout is None:
  66.             client_timeout = self.client_timeout
  67.         if not output_fields:
  68.             output_fields = self.output_fields
  69.         print(f"MilvusAvatar query_expr_sync:{expr},output_fields:{output_fields}")
  70.         print(f"MilvusAvatar num_entities:{self.sess.num_entities}")
  71.         if self.sess.num_entities == 0:
  72.             return []
  73.             
  74.         return  self.sess.query(partition_names=self.partition_names,
  75.                                 output_fields=output_fields,
  76.                                 expr=expr,
  77.                                 _async= False,
  78.                                 offset=0,
  79.                                 limit=1000)
  80.         
  81.     # emb 为一个人脸特征向量
  82.     def insert_avatar_sync(self, face_id, media_id, file_path, name, save_path, embedding):
  83.         print(f'now insert_avatar {file_path}')
  84.         print(f'now insert_avatar {file_path}')
  85.                        
  86.         data = [[] for _ in range(len(self.sess.schema))]
  87.         data[0].append(face_id)
  88.         data[1].append(media_id)
  89.         data[2].append(file_path)
  90.         data[3].append(name)
  91.         data[4].append(1)
  92.         data[5].append(save_path)
  93.         data[6].append(embedding)
  94.         # 执行插入操作
  95.         try:
  96.             print('Inserting into Milvus...')
  97.             self.partitions[0].insert(data=data)
  98.             print(f'{file_path}')
  99.             
  100.             print(f"MilvusAvatar insert_avatar num_entities:{self.sess.num_entities}")
  101.         except Exception as e:
  102.             print(f'Milvus insert media_id:{media_id}, file_path:{file_path} failed: {e}')
  103.             print(f'Milvus insert media_id:{media_id}, file_path:{file_path} failed: {e}')
  104.             return False
  105.         return True   
  106.         
  107.   
  108.     # embs是一个数组
  109.     def search_emb_sync(self, embs, expr='', top_k=None, client_timeout=None):
  110.         if self.sess is None:
  111.             return None
  112.    
  113.         if not top_k:
  114.             top_k = self.top_k
  115.         milvus_records = self.sess.search(data=embs, partition_names=self.kb_ids, anns_field="embedding",
  116.                                           param=self.search_params, limit=top_k,
  117.                                           output_fields=self.output_fields, expr=expr, timeout=client_timeout)
  118.         print(f"milvus_records:{milvus_records}")
  119.         return milvus_records   
  120.         
  121.       
  122.     def exists(self,face_id):
  123.         print(f"exists:{face_id},{type(face_id)}")
  124.         # 记住,字符串查询,构建表达式时要加上单引号,比如下面的'{face_id}',其实face_id本来就是一个字符串类型了,如果不加会出现如下的异常:
  125.         # pymilvus.exceptions.MilvusException: <MilvusException: (code=65535, message=cannot parse expression: face_id == 2_0, error: invalid expression: face_id == 2_0)>
  126.         res = self.query_expr_sync(expr=f"face_id == '{face_id}'", output_fields=self.output_fields)
  127.         #print(f"exists:{res},{len(res)}")
  128.         if len(res) > 0:
  129.             return True
  130.         
  131.         return False
  132.    
  133.    
  134.     # 修改照片数   
  135.     def add_count(self, face_id):
  136.         res = self.query_expr_sync(expr=f"face_id == '{face_id}'", output_fields=self.output_fields)
  137.         self.sess.delete(expr=f"face_id == '{face_id}'")
  138.         for result in res:
  139.             media_id = result['media_id']
  140.             file_path = result['file_path']
  141.             name = result['name']
  142.             count = int(result['count'])
  143.             save_path = result['save_path']
  144.             embedding = result['embedding']
  145.             
  146.             data = [[] for _ in range(len(self.sess.schema))]
  147.             data[0].append(face_id)
  148.             data[1].append(media_id)
  149.             data[2].append(file_path)
  150.             data[3].append(name)
  151.             data[4].append(count + 1)
  152.             data[5].append(save_path)
  153.             data[6].append(embedding)   
  154.             print(f"add_count face_id:{face_id},file_path:{file_path}, count:{count}")
  155.             
  156.             # 执行插入操作
  157.             try:
  158.                 print('Inserting into Milvus...')
  159.                 self.partitions[0].insert(data=data)
  160.             except Exception as e:
  161.                 print(f'Milvus insert media_id:{media_id}, file_path:{file_path} failed: {e}')
  162.                 return False               
  163.         
  164.     def delete_collection(self):
  165.         print("delete_collection")
  166.         self.sess.release()
  167.         utility.drop_collection(self.table_name)
  168.     def delete_partition(self, partition_name):
  169.         print("delete_partition")
  170.         part = Partition(self.sess, partition_name)
  171.         part.release()
  172.         self.sess.drop_partition(partition_name)
  173.         
  174.     def query_all(self,limit=None):
  175.         res = self.sess.query(partition_names = self.partition_names,
  176.                                 output_fields = ["face_id","media_id", "name", "count", "save_path"],
  177.                                 expr= f"face_id != ''",
  178.                                 _async = False,
  179.                                 offset = 0,
  180.                                 limit = None)
  181.                                 
  182.         print(res)
  183.         return res
  184. if __name__ == "__main__":
  185.     milvus_avatar= MilvusAvatar("local", "avatar", partition_names=["avatar"])
  186.     media_id = 2
  187.     index = 0
  188.     face_id = f"{media_id}_{index}"
  189.     file_path = "/home/data/bbh.jpg"
  190.     save_path = "/home/data/bbh_avatar.jpg"
  191.     embedding = [i/1000 for i in range(512)]
  192.     milvus_avatar.insert_avatar_sync(face_id, media_id, file_path, "bbh", save_path, embedding)
  193.     #result = milvus_avatar.query_all()
  194.     #print(result)
  195.     print(milvus_avatar.exists(face_id))
  196.    
复制代码
实行:python milvus_demo.py
假如是针对非字符串字段进行查询,则无需做上面的特别处理。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

十念

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表