卖不甜枣 发表于 2024-9-15 10:11:18

果蔬识别系统性能优化之路(五)

前情提要

果蔬识别系统性能优化之路(四)
遗留问题


[*]新建store_feature表,通过storeCode和featureId关联store表和feature表,对数据库进行规范化:创建一个新的表来映射storeCode与feature的关系,从而可以使用简单的WHERE条件来充分利用索引
[*]实现对特征向量ivf的增删改查
解决方案

新建storeFeature表


[*]新建store表,storeFeature表
import { Entity, PrimaryGeneratedColumn, Column, OneToMany } from 'typeorm';
import { StoreFeature } from '../../feature/entities/store-feature.entity';

/**
 * TypeORM entity describing a store.
 * `storeCode` is declared unique, and each store exposes the StoreFeature
 * rows that point back at it through `storeFeatures`.
 */
@Entity()
export class Store {
// Auto-generated primary key.
@PrimaryGeneratedColumn()
id: number;

// Business identifier of the store; the column carries a unique constraint.
@Column({ unique: true })
storeCode: string;

// Optional display name (the column is nullable).
@Column({ nullable: true })
storeName: string;

// Inverse side of StoreFeature.store.
@OneToMany(() => StoreFeature, (storeFeature) => storeFeature.store)
storeFeatures: StoreFeature[];
}

import { Entity, ManyToOne, JoinColumn, PrimaryGeneratedColumn } from 'typeorm';
import { Store } from '../../store/entities/store.entity';
import { Feature } from './feature.entity';

/**
 * Join entity linking one Store to one Feature.
 * Both relations are declared with `onDelete: 'CASCADE'`.
 */
@Entity()
export class StoreFeature {
// Auto-generated primary key.
@PrimaryGeneratedColumn()
id: number;

// Owning store; the join column `storeCode` references Store.storeCode rather than Store.id.
@ManyToOne(() => Store, { onDelete: 'CASCADE' })
@JoinColumn({ name: 'storeCode', referencedColumnName: 'storeCode' })
store: Store;

// Linked feature; the join column `featureId` references Feature.id.
@ManyToOne(() => Feature, { onDelete: 'CASCADE' })
@JoinColumn({ name: 'featureId', referencedColumnName: 'id' })
feature: Feature;
}

storeFeature表关联store表和feature表

[*]feature.service大改造
import { Injectable } from '@nestjs/common';
import { CreateFeatureDto } from './dto/create-feature.dto';
import { Feature } from './entities/feature.entity';
import { InjectRepository } from '@nestjs/typeorm';
import { Repository, In } from 'typeorm';
import { RedisService } from '../redis/redis.service';
import { HttpService } from '@nestjs/axios';
import { firstValueFrom } from 'rxjs';
import * as FormData from 'form-data';
import { Img } from '../img/entities/img.entity';
import { Store } from '../store/entities/store.entity';
import { StoreFeature } from './entities/store-feature.entity';

@Injectable()
export class FeatureService {
/**
 * Receives the TypeORM repositories for Feature, Img, Store and StoreFeature,
 * together with the HTTP client and the Redis service, as injected
 * read-only members of the service.
 */
constructor(
    @InjectRepository(Feature)
    private readonly featureRepository: Repository<Feature>,
    @InjectRepository(Img)
    private readonly imgRepository: Repository<Img>,
    @InjectRepository(Store)
    private readonly storeRepository: Repository<Store>,
    @InjectRepository(StoreFeature)
    private readonly storeFeatureRepository: Repository<StoreFeature>,
    private readonly httpService: HttpService,
    private readonly redisService: RedisService,
) {
}

/**
 * Stores an uploaded image together with its feature row and links the
 * feature to a store.
 *
 * The image is saved first. The feature row is then saved while the store is
 * looked up (and created when missing) concurrently, and finally a
 * StoreFeature row joining the two is written.
 *
 * @param file Uploaded image; its buffer is persisted through the img repository.
 * @param createFeatureDto Feature payload; `storeCode` / `storeName` identify the store.
 * @param needSync Whether to call syncRedis for the store afterwards (defaults to true).
 * @returns The saved feature entity.
 */
async create(file: Express.Multer.File, createFeatureDto: CreateFeatureDto, needSync: boolean = true): Promise<Feature> {
  const img = this.imgRepository.create({
    img: file.buffer,
  });
  await this.imgRepository.save(img);

  // Plain async functions replace `new Promise(async (resolve) => ...)`: a rejection
  // inside an async executor is never forwarded to the promise, so a failed save
  // would have left Promise.all pending forever instead of surfacing the error.
  const saveFeature = async (): Promise<Feature> => {
    const feature: Feature = this.featureRepository.create({
      ...createFeatureDto,
      imgId: img.id,
    });
    await this.featureRepository.save(feature);
    return feature;
  };

  const findOrCreateStore = async (): Promise<Store> => {
    let store = await this.storeRepository.findOne({ where: { storeCode: createFeatureDto.storeCode } });
    if (!store) {
      store = this.storeRepository.create({
        storeCode: createFeatureDto.storeCode,
        storeName: createFeatureDto.storeName,
      });
      await this.storeRepository.save(store);
    }
    return store;
  };

  const [feature, store] = await Promise.all([saveFeature(), findOrCreateStore()]);

  const storeFeature = this.storeFeatureRepository.create({
    feature,
    store,
  });
  await this.storeFeatureRepository.save(storeFeature);

  if (needSync) {
    await this.syncRedis(createFeatureDto.storeCode);
  }
  return feature;
}

/**
 * Synchronises the Redis cache entry of one store.
 *
 * POSTs `{ storeCode }` to the Python service's /sync endpoint, stores the
 * `ids` field of the reply as JSON under `${storeCode}-featureDatabase`, and
 * logs how long the round trip took.
 *
 * @param storeCode Code of the store to synchronise.
 */
async syncRedis(storeCode: string) {
  const syncEndpoint = 'http://localhost:5000/sync'; // URL of the Python service
  const startedAt = Date.now();
  const reply = await firstValueFrom(this.httpService.post(syncEndpoint, { storeCode }));
  const serializedIds = JSON.stringify(reply.data.ids);
  await this.redisService.set(`${storeCode}-featureDatabase`, serializedIds);
  const finishedAt = Date.now();
  console.log(`门店:${storeCode},同步redis耗时:${finishedAt - startedAt}ms`);
}

/**
 * Lists the features linked to a store.
 *
 * @param storeCode Code of the store whose features are returned.
 * @param selectP Optional selection list handed to the query builder's `select`.
 */
async findAll(storeCode: string, selectP?: string[]) {
  const selection = this.featureRepository
    .createQueryBuilder('feature')
    .select(selectP);
  // Walk feature -> store_feature -> store so the filter is applied on the store code.
  const scopedToStore = selection
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode });
  return await scopedToStore.getMany();
}

/**
 * Lists the features linked to a store, each with its `img` relation loaded.
 *
 * @param storeCode Code of the store whose features are returned.
 */
async findAllWithImage(storeCode: string): Promise<Feature[]> {
  const withImage = this.featureRepository
    .createQueryBuilder('feature')
    .leftJoinAndSelect('feature.img', 'img');
  const scopedToStore = withImage
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode });
  return await scopedToStore.getMany();
}

/**
 * Deletes a store together with all of its data.
 *
 * Removes the store's StoreFeature rows, then the store itself, then every
 * feature (and its image) that is no longer referenced by any StoreFeature
 * row, and finally refreshes the Redis entry for the store.
 * Does nothing when the store does not exist.
 *
 * @param storeCode Code of the store to delete.
 */
async removeAll(storeCode: string): Promise<void> {
  const store = await this.storeRepository.findOne({ where: { storeCode }, relations: ['storeFeatures'] });
  if (!store) {
    return;
  }

  // Delete the join rows in one statement. The previous raw
  // `DELETE ... WHERE id IN (?)` was issued without a parameter list,
  // so the placeholder was never bound.
  const storeFeatureIds = store.storeFeatures.map((storeFeature) => storeFeature.id);
  if (storeFeatureIds.length > 0) {
    await this.storeFeatureRepository.delete({ id: In(storeFeatureIds) });
  }
  await this.storeRepository.remove(store);

  // Features without any remaining StoreFeature row are not used by any store.
  const unreferencedFeatures = await this.featureRepository
    .createQueryBuilder('feature')
    .leftJoinAndSelect('feature.img', 'img')
    .leftJoin('feature.storeFeatures', 'storeFeature')
    .where('storeFeature.id IS NULL')
    .getMany();
  for (const feature of unreferencedFeatures) {
    await this.remove(feature);
  }

  await this.redisService.del(`${storeCode}-featureDatabase`);
  await this.syncRedis(storeCode);
}

/**
 * Sends an image to the Python service and resolves the returned index
 * positions to stored features.
 *
 * @param file Uploaded image to classify.
 * @param num Size of the `top<num>` list, as a decimal string (defaults to '5').
 * @param storeCode Code of the store whose cached id list is used.
 * @param justPredict 'true' to return only the extracted features without any lookup.
 * @param needList When true, the full ordered match list is added as `featureList`.
 * @returns The object built by buildResponse, optionally extended with `featureList`.
 * @throws Error prefixed with "Prediction failed:" when any step fails.
 */
async predict(
  file: Express.Multer.File,
  num: string = '5',
  storeCode: string,
  justPredict: string = 'false',
  needList: boolean = false,
) {
  const PYTHON_SERVICE_URL = 'http://localhost:5000/predict'; // Python service URL
  const REDIS_KEY_SUFFIX = '-featureDatabase';
  const startTime = Date.now();
  const numInt = parseInt(num, 10);
  const isJustPredict = justPredict === 'true';

  try {
    // Prepare the multipart payload for the Python service.
    const formData = new FormData();
    formData.append('file', file.buffer, file.originalname);
    formData.append('storeCode', storeCode);
    formData.append('justPredict', justPredict);

    const response = await firstValueFrom(this.httpService.post(PYTHON_SERVICE_URL, formData));
    const { features, index, predictTime } = response.data;

    if (isJustPredict) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // The cached value is the JSON array written by syncRedis.
    const featureDatabaseStr = await this.redisService.get(`${storeCode}${REDIS_KEY_SUFFIX}`);
    if (!featureDatabaseStr) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // Translate each returned position into the id cached at that position and
    // drop positions that have no cached id (out of range or negative).
    const featureDatabase = JSON.parse(featureDatabaseStr);
    const positions: number[] = Array.isArray(index) ? index : [];
    const ids = positions
      .map((idx: number) => featureDatabase[idx])
      .filter((id: unknown) => id !== undefined && id !== null);

    if (!ids.length) {
      return this.buildResponse([], features, predictTime, startTime, numInt);
    }

    // FIELD(...) keeps the rows in the order of `ids`.
    // NOTE(review): the ids are interpolated into the ORDER BY clause; they come
    // from the Redis entry written by syncRedis — confirm nothing else writes that key.
    const featureList = await this.featureRepository.createQueryBuilder('feature')
      .where('feature.id IN (:...ids)', { ids })
      .orderBy(`FIELD(feature.id, ${ids.map((id: any) => `'${id}'`).join(', ')})`, 'ASC')
      .getMany();

    // Keep only the first match per label, up to numInt entries.
    const uniqueList = this.filterUniqueFeatures(featureList, numInt);

    const result = this.buildResponse(uniqueList, features, predictTime, startTime, numInt);
    return needList ? { ...result, featureList: featureList.map(({ features, ...rest }) => rest) } : result;
  } catch (error) {
    const message = error instanceof Error ? error.message : String(error);
    throw new Error(`Prediction failed: ${message}`);
  }
}

/**
 * Keeps the first feature of every distinct label, preserving input order.
 *
 * @param featureList Features in priority order.
 * @param limit Stop as soon as this many features have been collected.
 * @returns The de-duplicated features.
 */
private filterUniqueFeatures(featureList: any[], limit: number) {
  // A Set lookup replaces the linear `some` scan over the result list.
  const seenLabels = new Set<unknown>();
  const uniqueList: any[] = [];
  for (const feature of featureList) {
    if (!seenLabels.has(feature.label)) {
      seenLabels.add(feature.label);
      uniqueList.push(feature);
    }
    if (uniqueList.length === limit) break;
  }
  return uniqueList;
}

/**
 * Assembles the prediction response.
 *
 * @param list Matched features; their `features` property is stripped from the output.
 * @param features Value returned under the `features` key.
 * @param predictTime Value returned under the `predictTime` key.
 * @param startTime Timestamp (ms) from which `totalTime` is measured.
 * @param num Number used to name the `top<num>` key.
 */
private buildResponse(list: any[], features: any, predictTime: string, startTime: number, num: number) {
  const elapsedMs = Date.now() - startTime;
  const topKey = `top${num}`;
  const withoutVectors = list.map((entry) => {
    const { features: _vector, ...remaining } = entry;
    return remaining;
  });
  return {
    predictTime,
    [topKey]: withoutVectors,
    features,
    totalTime: `${elapsedMs}ms`,
  };
}

/**
 * Computes the cosine similarity of two vectors.
 *
 * @param vecA First vector.
 * @param vecB Second vector; must have the same length as `vecA`.
 * @returns dot(vecA, vecB) / (|vecA| * |vecB|), or 0 when either vector has zero magnitude.
 * @throws Error when the vectors differ in length.
 */
cosineSimilarity(vecA: number[], vecB: number[]): number {
  if (vecA.length !== vecB.length) {
    throw new Error('Vectors must be of the same length');
  }
  let dotProduct = 0;
  let squaredA = 0;
  let squaredB = 0;
  for (let index = 0; index < vecA.length; index++) {
    // Multiply matching components; the previous code multiplied by the whole
    // `vecB` array, which always produced NaN.
    dotProduct += vecA[index] * vecB[index];
    squaredA += vecA[index] * vecA[index];
    squaredB += vecB[index] * vecB[index];
  }
  const denominator = Math.sqrt(squaredA) * Math.sqrt(squaredB);
  // A zero vector has no direction; return 0 instead of dividing by zero.
  if (denominator === 0) {
    return 0;
  }
  return dotProduct / denominator;
}

/**
 * Ranks the cached entries of a store by cosine similarity to the input and
 * returns the best entry per label.
 *
 * NOTE(review): this reads `{ features, label }` objects from the
 * `${storeCode}-featureDatabase` key, while syncRedis in this file writes an
 * array of ids under the same key — confirm this method is still in use.
 *
 * @param inputFeatures Query vector.
 * @param num Number of distinct labels at which collection stops.
 * @param storeCode Code of the store whose cache entry is read.
 * @returns Label/similarity pairs, best first, similarity rounded to two decimals;
 *          an empty array when the cache entry is missing.
 */
async findTopNSimilar(inputFeatures: number[], num: number, storeCode: string): Promise<{
  label: string;
  similarity: number
}[]> {
  const cached = await this.redisService.get(`${storeCode}-featureDatabase`);
  if (!cached) {
    return [];
  }

  const entries = JSON.parse(cached);
  const scored = entries.map(({ features, label }) => {
    // Entries without a vector score 0.
    const similarity = features ? this.cosineSimilarity(inputFeatures, features) : 0;
    return { label: label as string, similarity: similarity as number };
  });

  // Highest similarity first.
  scored.sort((left: { similarity: number; }, right: { similarity: number; }) => right.similarity - left.similarity);

  const seenLabels = new Set<string>();
  const best: { label: string; similarity: number; }[] = [];
  for (const candidate of scored) {
    if (seenLabels.has(candidate.label as string)) {
      continue;
    }
    seenLabels.add(candidate.label);
    candidate.similarity = Math.round(candidate.similarity * 100) / 100;
    best.push(candidate);
    if (best.length === num) break;
  }
  return best;
}

/**
 * Finds the features of a store that carry the given label, with `img` loaded.
 *
 * @param label Label to match exactly.
 * @param storeCode Code of the store to search in.
 */
async getByName(label: string, storeCode: string): Promise<Feature[]> {
  const withImage = this.featureRepository
    .createQueryBuilder('feature')
    .leftJoinAndSelect('feature.img', 'img');
  const scopedToStore = withImage
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode });
  const matches = await scopedToStore
    .andWhere('feature.label = :label', { label })
    .getMany();
  return matches;
}

/**
 * Counts the features of a store that carry the given label.
 *
 * @param label Label to match exactly.
 * @param storeCode Code of the store to search in.
 * @returns Number of matching features.
 */
async getCountByLabel(label: string, storeCode: string): Promise<number> {
  // The image relation is not joined: a count never reads the image.
  return await this.featureRepository
    .createQueryBuilder('feature')
    .innerJoin(StoreFeature, 'storeFeature', 'feature.id = storeFeature.featureId')
    .innerJoin(Store, 'store', 'storeFeature.storeCode = store.storeCode')
    .where('store.storeCode = :storeCode', { storeCode })
    .andWhere('feature.label = :label', { label })
    .getCount();
}

/**
 * Learns a batch of images for one store.
 *
 * Each file is processed in turn: its features are extracted through predict
 * (in just-predict mode) and stored through create without a per-file sync.
 * A failing file is logged and skipped. One syncRedis call is made at the end.
 *
 * @param files Uploaded images.
 * @param createFeatureDto Payload shared by every created feature.
 * @returns The created features, each without its `features` property.
 */
async batchStudy(files: Express.Multer.File[], createFeatureDto: CreateFeatureDto) {
  const learned = [];
  for (const file of files) {
    try {
      const prediction = await this.predict(file, '5', createFeatureDto.storeCode, 'true');
      const payload = {
        ...createFeatureDto,
        features: prediction.features,
      };
      const created = await this.create(file, payload, false);
      // Report the entity without its feature vector.
      const { features: _vector, ...summary } = created;
      learned.push(summary);
    } catch (err) {
      console.error(err);
    }
  }
  await this.syncRedis(createFeatureDto.storeCode);
  return learned;
}

/**
 * Deletes a feature and, when present, its image.
 *
 * @param feature Feature to delete; `feature.img` is removed too when it is loaded.
 */
async remove(feature: Feature) {
  await this.featureRepository.remove(feature);
  // The relation may be unloaded or empty; passing undefined/null to remove() would throw.
  if (feature.img) {
    await this.imgRepository.remove(feature.img);
  }
}

/**
 * Deletes several features by id and refreshes the store's Redis entry.
 *
 * @param ids Comma-separated feature ids; blank or non-numeric entries are ignored.
 * @param storeCode Code of the store passed to syncRedis afterwards.
 */
async batchRemove(ids: string, storeCode: string) {
  const list = ids
    .split(',')
    .map((id) => id.trim())
    .filter((id) => id !== '')
    .map((id) => +id)
    .filter((id) => !Number.isNaN(id));

  if (list.length > 0) {
    // Load every requested feature with the relations needed for deletion.
    const features = await this.featureRepository.find({
      where: { id: In(list) },
      relations: ['img', 'storeFeatures'],
    });
    for (const feature of features) {
      // Remove the join rows before the feature they point at.
      if (feature.storeFeatures && feature.storeFeatures.length > 0) {
        await this.storeFeatureRepository.remove(feature.storeFeatures);
      }
      await this.remove(feature);
    }
  }
  await this.syncRedis(storeCode);
}

/**
 * Links existing features to a target store.
 *
 * Collects the feature ids already linked to `storeCode`, selects the distinct
 * feature ids from store_feature that are not among them, loads those features
 * (restricted to the ones linked to `sourceStoreCode` when it is given),
 * creates the target store when it does not exist, and writes one new
 * StoreFeature row per loaded feature. Existing Feature rows are reused, not
 * copied. Ends with a syncRedis call for the target store.
 *
 * @param storeCode Code of the target store.
 * @param sourceStoreCode Optional code of the store to import from.
 * @param storeName Name used when the target store has to be created.
 * @returns A message containing the number of imported features.
 */
async importData(storeCode: string, sourceStoreCode?: string, storeName?: string) {
    let storeFeatures = [];
    // Step 1: fetch every featureId already linked to the target storeCode.
    const storeFeatureIds = await this.storeFeatureRepository
      .createQueryBuilder('storeFeature')
      .select('storeFeature.featureId')
      .where('storeFeature.storeCode = :storeCode', { storeCode })
      .getRawMany();

    // Extract the list of featureIds to exclude.
    const featureIdsToExclude = storeFeatureIds.map(row => row.featureId);
    let distinctFeatureIds = [];
    if (featureIdsToExclude.length === 0) {
      // Nothing to exclude: take every distinct featureId.
      distinctFeatureIds = await this.storeFeatureRepository
      .createQueryBuilder('storeFeature')
      .select('DISTINCT storeFeature.featureId')// keep featureId unique
      .getRawMany();
    } else {
      // Step 2: exclude those featureIds while keeping featureId unique.
      distinctFeatureIds = await this.storeFeatureRepository
      .createQueryBuilder('storeFeature')
      .select('DISTINCT storeFeature.featureId')// keep featureId unique
      .where('storeFeature.featureId NOT IN (:...featureIdsToExclude)', { featureIdsToExclude })// exclude featureIds
      .getRawMany();
    }
    const featureIds = distinctFeatureIds.map(record => record.featureId);
    if (!sourceStoreCode) {
      // No source store given: load all candidate features.
      storeFeatures = await this.featureRepository
      .createQueryBuilder('feature')
      .leftJoinAndSelect('feature.img', 'img')
      .whereInIds(featureIds)
      .getMany();
    } else {
      // Source store given: keep only candidates linked to that store.
      storeFeatures = await this.featureRepository
      .createQueryBuilder('feature')
      .leftJoinAndSelect('feature.img', 'img')
      .innerJoin('feature.storeFeatures', 'storeFeatures')
      .whereInIds(featureIds)
      .andWhere('storeFeatures.storeCode = :storeCode', { storeCode: sourceStoreCode })// parameterized query
      .getMany();
    }
    let targetStore = await this.storeRepository.findOne({ where: { storeCode: storeCode } });
    if (!targetStore) {
      targetStore = this.storeRepository.create({
      storeCode: storeCode,
      storeName: storeName,
      });
      await this.storeRepository.save(targetStore);
    }
    // Create new StoreFeature records for the target storeCode
    const newStoreFeatures = storeFeatures.map((feature: Feature) => ({
      store: targetStore,
      feature, // Reuse the existing feature
    }));
    // Save new StoreFeature records
    const storeFeatureInstances = this.storeFeatureRepository.create(newStoreFeatures);
    await this.storeFeatureRepository.save(storeFeatureInstances);
    await this.syncRedis(storeCode);
    return `同步完成,共导入${storeFeatures.length}条数据`;
}

/**
 * Runs syncRedis for every distinct store code in the store table,
 * all in parallel, and logs once they have all finished.
 */
async init() {
  const rows = await this.storeRepository
    .createQueryBuilder('store')
    .select('store.storeCode')
    .distinct(true)
    .getRawMany();
  // Raw rows expose the selected column as `store_storeCode`.
  const pendingSyncs = rows.map((row) => this.syncRedis(row.store_storeCode));
  await Promise.all(pendingSyncs);
  console.log('初始化完成');
}
}


[*]效果:并没有提升多少,但好在关系更清楚,为之后的拓展打下了基础
实现ivf的动态增删改查


[*]结论:ivf无法在不重新训练、只追加向量的情况下对新增向量进行识别,所以每次新增向量后必须重新进行训练和添加
[*]python端ivf改造
detect.py(识别和同步方法)
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
import numpy as np
import time
import gc
from ivf import IVFPQ
from feature import get_feature_by_store_code
import orjson
from concurrent.futures import ThreadPoolExecutor

# Load MobileNetV2 with ImageNet weights and without the top classification layer;
# pooling='avg' is requested and the expected input shape is 224x224x3.
model = MobileNetV2(input_shape=(224, 224, 3), weights='imagenet', include_top=False, pooling='avg')


class MainDetect:
    """Extracts image features and searches them in per-store IVF indexes.

    ``ivfObj`` maps ``"<store_code>-featureDatabase"`` to the IVFPQ index built
    for that store by ``sync``.
    """

    def __init__(self):
        super().__init__()
        self.image_id = None
        self.image_features = None
        # Custom model loaded from disk; classify_image currently uses the
        # module-level ``model`` instead (the call using this one is disabled).
        self.model = tf.keras.models.load_model("models/custom/my-model.h5")
        # store key -> IVFPQ index
        self.ivfObj = {}

    def classify_image(self, image_data, store_code, just_predict):
        """Run the feature extractor on one image and optionally search the store index.

        :param image_data: Encoded image bytes.
        :param store_code: Store whose index is searched.
        :param just_predict: The string "false" enables the index search; any
            other value returns the features only.
        :return: dict with ``outputs`` (flattened feature list), ``time``
            (prediction + search time in ms, formatted) and ``index``
            (flattened search result, empty when no search was done).
        """
        # Decode, resize to the model's 224x224 input and add the batch dimension.
        img = tf.image.decode_image(image_data, channels=3)
        img = tf.image.resize(img, [224, 224])
        img = tf.expand_dims(img, axis=0)

        start_time = time.time()
        outputs = model.predict(img)
        # Disabled post-processing kept from the original:
        # outputs = self.model.predict(outputs)
        # prediction = tf.divide(outputs, tf.norm(outputs))
        i = []
        if just_predict == "false":
            key = store_code + '-featureDatabase'
            # Search only the index of the requested store, if one was built.
            if key in self.ivfObj:
                i = self.ivfObj[key].search(outputs)
                i = i.flatten().tolist()

        end_time = time.time()
        elapsed_time = end_time - start_time

        # output_data = prediction.numpy().flatten().tolist()
        output_data = outputs.flatten().tolist()

        # Drop the large intermediates and force a collection to free memory.
        del img, outputs, end_time, start_time
        gc.collect()

        return {"outputs": output_data, "time": f"{elapsed_time * 1000:.2f}ms", "index": i}

    def sync(self, store_code):
        """Rebuild the index of one store from its stored features.

        Any existing index for the store is discarded first. When the store has
        no features, no index is kept and an empty list is returned.

        :param store_code: Store to rebuild.
        :return: The ids of the indexed items, in index order.
        """
        key = store_code + '-featureDatabase'
        self.ivfObj.pop(key, None)
        data = get_feature_by_store_code(store_code)

        if len(data) == 0:
            return []

        def parse_features(item):
            return orjson.loads(item['features'])

        with ThreadPoolExecutor() as executor:
            features_list = list(executor.map(parse_features, data))
        # Stack all parsed vectors into one float32 matrix for the index.
        features = np.array(features_list, dtype=np.float32)
        self.ivfObj[key] = IVFPQ(features)
        # NOTE(review): the comprehension was truncated in the source; the 'id'
        # key is assumed from how the caller uses the result — confirm against
        # get_feature_by_store_code.
        ids = [item['id'] for item in data]
        return ids

ivf.py(ivf构造)
import faiss
import numpy as np

# Number of OpenMP threads handed to faiss at import time.
num_threads = 8
faiss.omp_set_num_threads(num_threads)


class IVFPQ:
    """Thin wrapper around a faiss IndexIVFFlat built from a feature matrix.

    Despite the name, the product-quantised variant is disabled; ``m`` and
    ``n_bits`` are accepted but unused.
    """

    def __init__(self, features, nlist=100, m=16, n_bits=8):
        """Train the index and add every row of ``features``.

        :param features: 2-D float32 array, one vector per row.
        :param nlist: Number of inverted lists passed to IndexIVFFlat.
        :param m: Unused (kept for the disabled IndexIVFPQ variant).
        :param n_bits: Unused (kept for the disabled IndexIVFPQ variant).
        """
        n, d = features.shape
        # Flat L2 index used as the coarse quantizer.
        quantizer = faiss.IndexFlatL2(d)
        self.index = faiss.IndexIVFFlat(quantizer, d, nlist)
        # self.index = faiss.IndexIVFPQ(quantizer, d, nlist, m, n_bits)

        # Minimum number of vectors used for training; smaller inputs are padded.
        count = 3900
        if n >= count:
            self.index.train(features)
            # Add in slices of 1000 rows; each slice is added exactly once.
            batch_size = 1000
            for start in range(0, n, batch_size):
                self.index.add(features[start:start + batch_size])
        else:
            points = count - n
            # Deterministic padding seeded by its own size. A private RandomState
            # yields the same values as np.random.seed(points) + np.random.random
            # without mutating the global RNG state.
            xb = np.random.RandomState(points).random_sample((points, d)).astype('float32')
            # Real vectors come first, so their row positions are unchanged.
            combined_features = np.vstack((features, xb))
            self.index.train(combined_features)
            self.index.add(combined_features)

    def search(self, xq, k=100):
        """Return the index positions of the ``k`` nearest neighbours of each query row."""
        d, i = self.index.search(xq, k)
        return i

    def add(self, xb):
        """Add the rows of ``xb`` to the index."""
        self.index.add(xb)

    def train(self, xb):
        """Train the index on the rows of ``xb``."""
        self.index.train(xb)

    def sync(self, features):
        """Add the rows of ``features`` to the index one at a time."""
        for i in range(len(features)):
            # Slice rather than index so each addition stays a 2-D, one-row array.
            self.add(features[i:i + 1])

结语

这个项目优化到这差不多告一段落了,后续还有啥优化点会继续跟进,稍后会把整个架构图和功能点都梳理一遍

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
页: [1]
查看完整版本: 果蔬识别系统性能优化之路(五)