前情提要
果蔬识别系统性能优化之路(四)
剩余问题
- 新建store_feature表,关联storeCode和featureId,对数据库进行规范化:创建一个新的表来映射storeCode与feature的关系,从而可以利用简单的WHERE条件来充分利用索引
- 实现对特征向量ivf的增删改查
解决方案
新建storeFeature表
- import { Entity, PrimaryGeneratedColumn, Column, OneToMany } from 'typeorm';
- import { StoreFeature } from '../../feature/entities/store-feature.entity';
/**
 * Store entity. Each row is identified by an auto-generated id and carries a
 * unique storeCode that StoreFeature rows reference.
 */
@Entity()
export class Store {
  /** Auto-generated primary key. */
  @PrimaryGeneratedColumn()
  id: number;

  /** Business code of the store; unique across the table. */
  @Column({ unique: true })
  storeCode: string;

  /** Display name; nullable. */
  @Column({ nullable: true })
  storeName: string;

  /** Link rows that attach features to this store. */
  @OneToMany(() => StoreFeature, (link) => link.store)
  storeFeatures: StoreFeature[];
}
- import { Entity, ManyToOne, JoinColumn, PrimaryGeneratedColumn } from 'typeorm';
- import { Store } from '../../store/entities/store.entity';
- import { Feature } from './feature.entity';
/**
 * Link entity between Store and Feature. Both foreign keys are declared with
 * ON DELETE CASCADE, so deleting either side removes the link row.
 */
@Entity()
export class StoreFeature {
  /** Auto-generated primary key. */
  @PrimaryGeneratedColumn()
  id: number;

  /** Owning store; joined on Store.storeCode rather than on its primary key. */
  @ManyToOne(() => Store, { onDelete: 'CASCADE' })
  @JoinColumn({ name: 'storeCode', referencedColumnName: 'storeCode' })
  store: Store;

  /** Linked feature; joined on Feature.id. */
  @ManyToOne(() => Feature, { onDelete: 'CASCADE' })
  @JoinColumn({ name: 'featureId', referencedColumnName: 'id' })
  feature: Feature;
}
storeFeature表关联store表和feature表
- import { Injectable } from '@nestjs/common';
- import { CreateFeatureDto } from './dto/create-feature.dto';
- import { Feature } from './entities/feature.entity';
- import { InjectRepository } from '@nestjs/typeorm';
- import { Repository, In } from 'typeorm';
- import { RedisService } from '../redis/redis.service';
- import { HttpService } from '@nestjs/axios';
- import { firstValueFrom } from 'rxjs';
- import * as FormData from 'form-data';
- import { Img } from '../img/entities/img.entity';
- import { Store } from '../store/entities/store.entity';
- import { StoreFeature } from './entities/store-feature.entity';
- @Injectable()
- export class FeatureService {
/**
 * Receives the repositories for Feature, Img, Store and StoreFeature plus the
 * HTTP client and the redis wrapper used by the methods of this service.
 */
constructor(
  @InjectRepository(Feature)
  private readonly featureRepository: Repository<Feature>,
  @InjectRepository(Img)
  private readonly imgRepository: Repository<Img>,
  @InjectRepository(Store)
  private readonly storeRepository: Repository<Store>,
  @InjectRepository(StoreFeature)
  private readonly storeFeatureRepository: Repository<StoreFeature>,
  private readonly httpService: HttpService,
  private readonly redisService: RedisService,
) {}
/**
 * Stores an uploaded image, the feature row that references it, and the
 * store/feature link row. The store is created when no row with the given
 * storeCode exists yet.
 *
 * @param file uploaded file; its buffer is saved as the image
 * @param createFeatureDto feature fields together with storeCode / storeName
 * @param needSync whether to call syncRedis for the store afterwards (default true)
 * @returns the saved feature
 */
async create(file: Express.Multer.File, createFeatureDto: CreateFeatureDto, needSync: boolean = true): Promise<Feature> {
  const img = this.imgRepository.create({
    img: file.buffer,
  });
  await this.imgRepository.save(img);

  // Plain async helpers instead of `new Promise(async (resolve) => ...)`:
  // inside a promise executor a rejected `await` never settles the outer
  // promise, so a failed save would leave Promise.all pending forever.
  const saveFeature = async (): Promise<Feature> => {
    const feature: Feature = this.featureRepository.create({
      ...createFeatureDto,
      imgId: img.id,
    });
    await this.featureRepository.save(feature);
    return feature;
  };

  const findOrCreateStore = async (): Promise<Store> => {
    let store = await this.storeRepository.findOne({ where: { storeCode: createFeatureDto.storeCode } });
    if (!store) {
      store = this.storeRepository.create({
        storeCode: createFeatureDto.storeCode,
        storeName: createFeatureDto.storeName,
      });
      await this.storeRepository.save(store);
    }
    return store;
  };

  // The feature insert and the store lookup do not depend on each other.
  const [feature, store] = await Promise.all([saveFeature(), findOrCreateStore()]);

  const storeFeature = this.storeFeatureRepository.create({
    feature,
    store,
  });
  await this.storeFeatureRepository.save(storeFeature);

  if (needSync) {
    await this.syncRedis(createFeatureDto.storeCode);
  }
  return feature;
}
/**
 * Asks the Python service to sync the given store, then writes the id list it
 * returns to redis under `<storeCode>-featureDatabase` and logs the elapsed time.
 *
 * @param storeCode store to synchronise
 */
async syncRedis(storeCode: string) {
  const startedAt = Date.now();
  const syncUrl = 'http://localhost:5000/sync'; // URL of the Python service
  const request$ = this.httpService.post(syncUrl, { storeCode });
  const { data } = await firstValueFrom(request$);
  const serializedIds = JSON.stringify(data.ids);
  await this.redisService.set(`${storeCode}-featureDatabase`, serializedIds);
  const finishedAt = Date.now();
  console.log(`门店:${storeCode},同步redis耗时:${finishedAt - startedAt}ms`);
}
/**
 * Lists the features linked to a store.
 *
 * @param storeCode store whose features are returned
 * @param selectP optional selection passed to the query builder's select()
 */
async findAll(storeCode: string, selectP?: string[]) {
  const query = this.featureRepository.createQueryBuilder('feature');
  query
    .select(selectP)
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode });
  const rows = await query.getMany();
  return rows;
}
/**
 * Lists the features linked to a store, each with its `img` relation loaded.
 *
 * @param storeCode store whose features are returned
 */
async findAllWithImage(storeCode: string): Promise<Feature[]> {
  const query = this.featureRepository.createQueryBuilder('feature');
  query
    .leftJoinAndSelect('feature.img', 'img')
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode });
  const rows = await query.getMany();
  return rows;
}
/**
 * Deletes a store, its link rows, and every feature that is left without any
 * link row afterwards; then clears and re-syncs the store's redis entry.
 * Does nothing when the store does not exist.
 *
 * @param storeCode store to delete
 */
async removeAll(storeCode: string): Promise<void> {
  const store = await this.storeRepository.findOne({ where: { storeCode }, relations: ['storeFeatures'] });
  if (!store) {
    return;
  }

  // Bulk-delete the link rows first, then the store itself.
  const linkIds = store.storeFeatures.map((link) => link.id);
  if (linkIds.length > 0) {
    await this.storeFeatureRepository.query('DELETE FROM store_feature WHERE id IN (?)', [linkIds]);
  }
  await this.storeRepository.remove(store);

  // Features with no link row at all (the query is not restricted to this store).
  const orphanQuery = this.featureRepository
    .createQueryBuilder('feature')
    .leftJoinAndSelect('feature.img', 'img')
    .leftJoin('feature.storeFeatures', 'storeFeature')
    .where('storeFeature.id IS NULL');
  const orphans = await orphanQuery.getMany();
  for (const orphan of orphans) {
    await this.remove(orphan);
  }

  await this.redisService.del(`${storeCode}-featureDatabase`);
  await this.syncRedis(storeCode);
}
/**
 * Sends an image to the Python service and, unless only the raw feature
 * vector is wanted, resolves the returned index positions to feature rows.
 *
 * @param file uploaded image
 * @param num size of the `top<num>` list as a decimal string; falls back to 5
 *   when it is not a positive integer
 * @param storeCode store whose index and cached id list are used
 * @param justPredict 'true' to skip the lookup and return only the vector
 * @param needList also return every matched row (without `features`) as `featureList`
 * @returns `{ predictTime, top<num>, features, totalTime }`, plus `featureList` when requested
 * @throws Error prefixed with "Prediction failed: " when any step fails
 */
async predict(
  file: Express.Multer.File,
  num: string = '5',
  storeCode: string,
  justPredict: string = 'false',
  needList: boolean = false,
) {
  const PYTHON_SERVICE_URL = 'http://localhost:5000/predict'; // Python service URL
  const REDIS_KEY_PREFIX = '-featureDatabase';
  const startTime = Date.now();
  const parsedNum = parseInt(num, 10);
  // A NaN or non-positive limit would produce a `topNaN` key and disable the cut-off.
  const numInt = Number.isInteger(parsedNum) && parsedNum > 0 ? parsedNum : 5;
  const isJustPredict = justPredict === 'true';
  try {
    // Prepare form data
    const formData = new FormData();
    formData.append('file', file.buffer, file.originalname);
    formData.append('storeCode', storeCode);
    formData.append('justPredict', justPredict);

    // Send request to Python service
    const response = await firstValueFrom(this.httpService.post(PYTHON_SERVICE_URL, formData));
    const { features, index, predictTime } = response.data;
    if (isJustPredict) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // Retrieve the cached id list (position in the index -> feature id)
    const featureDatabaseStr = await this.redisService.get(`${storeCode}${REDIS_KEY_PREFIX}`);
    if (!featureDatabaseStr) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }
    const featureDatabase = JSON.parse(featureDatabaseStr);
    const positions: number[] = Array.isArray(index) ? index : [];

    // Positions with no entry in the id list (e.g. -1 or out of range) yield
    // undefined; drop them so they never reach the SQL query.
    const ids = positions
      .map((idx: number) => featureDatabase[idx])
      .filter((id: unknown) => id !== undefined && id !== null);
    if (!ids.length) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // Remember the first position of every id so the rows can be put back
    // into the order returned by the index.
    const rank = new Map<string, number>();
    ids.forEach((id: unknown, position: number) => {
      const key = String(id);
      if (!rank.has(key)) {
        rank.set(key, position);
      }
    });

    // Ordering in memory instead of `ORDER BY FIELD(...)` avoids building SQL
    // by string interpolation.
    const rows = await this.featureRepository.createQueryBuilder('feature')
      .where('feature.id IN (:...ids)', { ids })
      .getMany();
    const featureList = rows.sort(
      (a, b) => (rank.get(String(a.id)) ?? 0) - (rank.get(String(b.id)) ?? 0),
    );

    // Filter to ensure unique labels
    const uniqueList = this.filterUniqueFeatures(featureList, numInt);
    const result = this.buildResponse(uniqueList, features, predictTime, startTime, numInt);
    return needList
      ? { ...result, featureList: featureList.map(({ features: _vector, ...rest }) => rest) }
      : result;
  } catch (error) {
    const message = error instanceof Error ? error.message : String(error);
    throw new Error(`Prediction failed: ${message}`);
  }
}
/**
 * Keeps the first feature for every distinct label, preserving input order,
 * and stops once `limit` features have been collected.
 *
 * @param featureList candidate features, best match first
 * @param limit maximum number of features to return
 */
private filterUniqueFeatures(featureList: any[], limit: number) {
  const seenLabels = new Set<unknown>();
  const uniqueList = [];
  for (const feature of featureList) {
    // Set lookup replaces a linear scan of the result list per candidate.
    if (!seenLabels.has(feature.label)) {
      seenLabels.add(feature.label);
      uniqueList.push(feature);
    }
    if (uniqueList.length === limit) break;
  }
  return uniqueList;
}
/**
 * Assembles the prediction response: the Python-side timing, the `top<num>`
 * list with each entry's `features` property removed, the query vector, and
 * the time elapsed since `startTime`.
 */
private buildResponse(list: any[], features: any, predictTime: string, startTime: number, num: number) {
  const elapsedMs = Date.now() - startTime;
  const withoutVectors = list.map(({ features: _vector, ...rest }) => rest);
  return {
    predictTime,
    [`top${num}`]: withoutVectors,
    features,
    totalTime: `${elapsedMs}ms`,
  };
}
/**
 * Cosine similarity of two vectors.
 *
 * @param vecA first vector
 * @param vecB second vector, same length as vecA
 * @returns dot(vecA, vecB) / (|vecA| * |vecB|), or 0 when either vector has
 *   zero magnitude (the quotient would otherwise be NaN)
 * @throws Error when the vectors differ in length
 */
cosineSimilarity(vecA: number[], vecB: number[]): number {
  if (vecA.length !== vecB.length) {
    throw new Error('Vectors must be of the same length');
  }
  let dotProduct = 0;
  let squaredA = 0;
  let squaredB = 0;
  for (let i = 0; i < vecA.length; i++) {
    dotProduct += vecA[i] * vecB[i];
    squaredA += vecA[i] * vecA[i];
    squaredB += vecB[i] * vecB[i];
  }
  const denominator = Math.sqrt(squaredA) * Math.sqrt(squaredB);
  if (denominator === 0) {
    return 0;
  }
  return dotProduct / denominator;
}
/**
 * Ranks the entries cached in redis for a store by cosine similarity to the
 * input vector and returns the best entry per label, at most `num` of them,
 * with similarities rounded to two decimals.
 *
 * NOTE(review): this reads `{ features, label }` objects from
 * `<storeCode>-featureDatabase`, while syncRedis writes a plain id list under
 * the same key; confirm this method is still used.
 *
 * @param inputFeatures query vector
 * @param num number of distinct labels to return
 * @param storeCode store whose cached entries are searched
 */
async findTopNSimilar(inputFeatures: number[], num: number, storeCode: string): Promise<{
  label: string;
  similarity: number
}[]> {
  const cached = await this.redisService.get(`${storeCode}-featureDatabase`);
  if (!cached) {
    return [];
  }
  const scored = JSON.parse(cached).map(({ features, label }) => ({
    label: label as string,
    // Entries without a vector score 0.
    similarity: (features ? this.cosineSimilarity(inputFeatures, features) : 0) as number,
  }));
  scored.sort((left: { similarity: number; }, right: { similarity: number; }) => right.similarity - left.similarity);

  const seenLabels = new Set<string>();
  const best: { label: string; similarity: number; }[] = [];
  for (const entry of scored) {
    if (seenLabels.has(entry.label as string)) {
      continue;
    }
    seenLabels.add(entry.label);
    entry.similarity = Math.round(entry.similarity * 100) / 100;
    best.push(entry);
    if (best.length === num) break;
  }
  return best;
}
/**
 * Finds the features of a store that carry the given label, each with its
 * `img` relation loaded.
 *
 * @param label label to match exactly
 * @param storeCode store to search in
 */
async getByName(label: string, storeCode: string): Promise<Feature[]> {
  const query = this.featureRepository.createQueryBuilder('feature');
  query
    .leftJoinAndSelect('feature.img', 'img')
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode })
    .andWhere('feature.label = :label', { label });
  const matches = await query.getMany();
  return matches;
}
/**
 * Counts the features of a store that carry the given label.
 *
 * @param label label to match exactly
 * @param storeCode store to search in
 * @returns number of matching features
 */
async getCountByLabel(label: string, storeCode: string): Promise<number> {
  // The image relation is not joined: a LEFT JOIN cannot change the count and
  // would only add work to the query.
  return await this.featureRepository
    .createQueryBuilder('feature')
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode })
    .andWhere('feature.label = :label', { label })
    .getCount();
}
/**
 * Learns a batch of images: for each file the vector is obtained through
 * predict (justPredict = 'true') and stored through create without syncing.
 * A file that fails is logged and skipped. The store is synced once at the end.
 *
 * @param files uploaded images
 * @param createFeatureDto fields shared by every created feature
 * @returns the created features without their `features` property
 */
async batchStudy(files: Express.Multer.File[], createFeatureDto: CreateFeatureDto) {
  const created = [];
  for (const file of files) {
    try {
      const prediction = await this.predict(file, '5', createFeatureDto.storeCode, 'true');
      const dto = { ...createFeatureDto, features: prediction.features };
      const saved = await this.create(file, dto, false);
      // Copy the saved feature without its `features` property before collecting it.
      const { features, ...summary } = saved;
      created.push(summary);
    } catch (e) {
      console.error(e);
    }
  }
  await this.syncRedis(createFeatureDto.storeCode);
  return created;
}
/**
 * Deletes a feature and, when its `img` relation is loaded, the image it
 * references.
 *
 * @param feature feature to delete
 */
async remove(feature: Feature) {
  // Capture the image first; it may be absent when the relation was not loaded.
  const img = feature.img;
  await this.featureRepository.remove(feature);
  if (img) {
    await this.imgRepository.remove(img);
  }
}
/**
 * Deletes several features, their images and their link rows, then syncs the
 * store.
 *
 * @param ids comma-separated feature ids; blank or non-integer tokens are ignored
 * @param storeCode store to sync afterwards
 */
async batchRemove(ids: string, storeCode: string) {
  // `+''` is 0 and `+'abc'` is NaN, so tokens are validated before use.
  const list = ids
    .split(',')
    .map((token) => token.trim())
    .filter((token) => token !== '')
    .map((token) => Number(token))
    .filter((id) => Number.isInteger(id));
  if (list.length === 0) {
    return;
  }

  // Load every requested feature together with its image and link rows.
  const features = await this.featureRepository.find({
    where: { id: In(list) },
    relations: ['img', 'storeFeatures'],
  });
  for (const feature of features) {
    // Link rows go first so the deletion does not rely on the database cascade.
    if (feature.storeFeatures && feature.storeFeatures.length > 0) {
      await this.storeFeatureRepository.remove(feature.storeFeatures);
    }
    await this.remove(feature);
  }
  await this.syncRedis(storeCode);
}
/**
 * Links existing features to a target store. Features already linked to the
 * target are excluded; when `sourceStoreCode` is given only features linked to
 * that store are imported. The target store is created when missing, and the
 * store is synced afterwards.
 *
 * @param storeCode target store
 * @param sourceStoreCode optional store to import from
 * @param storeName name used when the target store has to be created
 * @returns a message stating how many features were imported
 */
async importData(storeCode: string, sourceStoreCode?: string, storeName?: string) {
  // Step 1: feature ids already linked to the target store.
  const alreadyLinked = await this.storeFeatureRepository
    .createQueryBuilder('storeFeature')
    .select('storeFeature.featureId')
    .where('storeFeature.storeCode = :storeCode', { storeCode })
    .getRawMany();
  const featureIdsToExclude = alreadyLinked.map((row) => row.featureId);

  // Step 2: distinct feature ids across all link rows, minus the excluded ones.
  const distinctQuery = this.storeFeatureRepository
    .createQueryBuilder('storeFeature')
    .select('DISTINCT storeFeature.featureId');
  if (featureIdsToExclude.length !== 0) {
    distinctQuery.where('storeFeature.featureId NOT IN (:...featureIdsToExclude)', { featureIdsToExclude });
  }
  const distinctRows = await distinctQuery.getRawMany();
  const featureIds = distinctRows.map((record) => record.featureId);

  // Step 3: load the candidate features, optionally restricted to the source store.
  const featureQuery = this.featureRepository
    .createQueryBuilder('feature')
    .leftJoinAndSelect('feature.img', 'img');
  if (sourceStoreCode) {
    featureQuery
      .innerJoin('feature.storeFeatures', 'storeFeatures')
      .whereInIds(featureIds)
      .andWhere('storeFeatures.storeCode = :storeCode', { storeCode: sourceStoreCode });
  } else {
    featureQuery.whereInIds(featureIds);
  }
  const importedFeatures = await featureQuery.getMany();

  // Step 4: make sure the target store exists.
  let targetStore = await this.storeRepository.findOne({ where: { storeCode: storeCode } });
  if (!targetStore) {
    targetStore = this.storeRepository.create({ storeCode: storeCode, storeName: storeName });
    await this.storeRepository.save(targetStore);
  }

  // Step 5: one new link row per imported feature; the features themselves are reused.
  const linkInstances = this.storeFeatureRepository.create(
    importedFeatures.map((feature: Feature) => ({ store: targetStore, feature })),
  );
  await this.storeFeatureRepository.save(linkInstances);

  await this.syncRedis(storeCode);
  return `同步完成,共导入${importedFeatures.length}条数据`;
}
/**
 * Syncs every store found in the store table, all syncs running concurrently,
 * and logs once they have all finished.
 */
async init() {
  const rows = await this.storeRepository
    .createQueryBuilder('store')
    .select('store.storeCode')
    .distinct(true)
    .getRawMany();
  // Raw rows expose the selected column as `store_storeCode`.
  const pendingSyncs = rows.map((row) => this.syncRedis(row.store_storeCode));
  await Promise.all(pendingSyncs);
  console.log('初始化完成');
}
- }
复制代码
- 效果:并没有提升多少,但好在关系更清晰,为之后的拓展打下了基础
实现ivf的动态增删改查
- 结论:ivf无法在不重新训练、只追加向量的情况下对新增向量进行识别,所以每次新增向量都必须重新进行训练和添加
- python端ivf改造
detect.py(识别和同步方法)
- import tensorflow as tf
- from tensorflow.keras.applications import MobileNetV2
- import numpy as np
- import time
- import gc
- from ivf import IVFPQ
- from feature import get_feature_by_store_code
- import orjson
- from concurrent.futures import ThreadPoolExecutor
# Pretrained MobileNetV2 loaded without its top classification layers
model = MobileNetV2(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')


class MainDetect:
    """Extracts feature vectors from images and keeps one IVF index per store."""

    def __init__(self):
        super().__init__()
        self.image_id = None
        self.image_features = None
        # NOTE(review): loaded here, but classify_image uses the module-level
        # `model`; the only call on self.model is commented out.
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        # Maps "<store_code>-featureDatabase" to the store's IVFPQ index
        self.ivfObj = {}

    def classify_image(self, image_data, store_code, just_predict):
        """Compute the feature vector of an image and optionally search the store index.

        :param image_data: encoded image bytes
        :param store_code: store whose index is searched
        :param just_predict: the string "false" enables the index search
        :return: dict with "outputs" (flattened vector), "time" (elapsed ms as
            text) and "index" (flattened search result, empty when no search ran)
        """
        # Load and preprocess image
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)  # Add batch dimension

        # Run model prediction
        start_time = time.time()
        outputs = model.predict(img)

        i = []
        key = store_code + '-featureDatabase'
        if just_predict == "false" and key in self.ivfObj:
            i = self.ivfObj[key].search(outputs)
            i = i.flatten().tolist()
        end_time = time.time()
        elapsed_time = end_time - start_time

        # Flatten the outputs so they can be returned as a plain list
        output_data = outputs.flatten().tolist()

        # Drop the large intermediates and force a collection to free memory
        del img, outputs, end_time, start_time
        gc.collect()
        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}

    def sync(self, store_code):
        """Rebuild the index of a store from its stored features.

        The new index is built first and only then swapped in, so the previous
        index stays available while the rebuild runs and is kept if it fails.

        :param store_code: store to rebuild
        :return: feature ids in the order their vectors were added, or an
            empty list when the store has no features (its index is dropped)
        """
        key = store_code + '-featureDatabase'
        data = get_feature_by_store_code(store_code)
        if len(data) == 0:
            self.ivfObj.pop(key, None)
            return []

        def parse_features(item):
            return orjson.loads(item['features'])

        with ThreadPoolExecutor() as executor:
            features_list = list(executor.map(parse_features, data))

        # Stack all vectors into one float32 matrix
        features = np.array(features_list, dtype=np.float32)
        self.ivfObj[key] = IVFPQ(features)
        return [item['id'] for item in data]
ivf.py(ivf构造)
- import faiss
- import numpy as np
num_threads = 8
faiss.omp_set_num_threads(num_threads)


class IVFPQ:
    """IVF index over a feature matrix.

    Despite the name, the index is an IndexIVFFlat with an L2 flat quantizer;
    the IndexIVFPQ variant is not used, so `m` and `n_bits` are accepted only
    for interface compatibility.
    """

    # Training points per centroid below which faiss k-means warns (its default)
    _MIN_POINTS_PER_CENTROID = 39

    def __init__(self, features, nlist=100, m=16, n_bits=8):
        """Train the index and add `features` to it.

        :param features: float32 array of shape (n, d)
        :param nlist: number of inverted lists
        :param m: unused
        :param n_bits: unused
        """
        n, d = features.shape[0], features.shape[1]
        quantizer = faiss.IndexFlatL2(d)  # L2 distance for the coarse quantizer
        self.index = faiss.IndexIVFFlat(quantizer, d, nlist)

        # 3900 for the default nlist of 100
        min_train = self._MIN_POINTS_PER_CENTROID * nlist
        if n >= min_train:
            training = features
        else:
            # Pad the training set with seeded random vectors. A local
            # generator is used so the global NumPy random state is untouched.
            points = min_train - n
            rng = np.random.RandomState(points)
            filler = rng.random_sample((points, d)).astype('float32')
            training = np.vstack((features, filler))
        self.index.train(training)

        # Only real vectors are added: padding in the index would make searches
        # return positions that have no feature id behind them.
        batch_size = 1000
        for start in range(0, n, batch_size):
            self.index.add(features[start:start + batch_size])

    def search(self, xq, k=100):
        """Return the index positions of the k nearest vectors for each query row."""
        _distances, positions = self.index.search(xq, k)
        return positions

    def add(self, xb):
        """Add the rows of a 2-D float32 array to the index."""
        self.index.add(xb)

    def train(self, xb):
        """Train the index on the rows of a 2-D float32 array."""
        self.index.train(xb)

    def sync(self, features):
        """Add one vector or a matrix of vectors to the index."""
        features = np.ascontiguousarray(features, dtype=np.float32)
        if features.ndim == 1:
            # The index expects a 2-D array, one vector per row.
            features = features.reshape(1, -1)
        if len(features) > 0:
            self.add(features)
结语
这个项目优化到这差不多告一段落了,后续还有啥优化点会继续跟进,稍后会把整个架构图和功能点都梳理一遍
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。