计算机视觉cv2入门之实时手势检测

打印 上一主题 下一主题

主题 1771|帖子 1771|积分 5315


        前边我们已经讲解了使用cv2进行图像预处置惩罚以及针对实时视频流文件的使用方法,这里我们通过实时手势检测这一案例来学习和实操一下。

大致思绪


  • 根据手势的种类以及指定手势图片数量来构建一个自己的手势图片数据集
  • CNN模型训练手势图片数据集
  • 使用训练好的模型进行实时猜测
手势图片数据集的构建

        经典的手势图片数据集有很多,但是都比较大,下载费时且模型训练时间长,因此这里我决定自行采集手势图片来构建一个小型数据集。手势图片的获取方法比较简单,就是使用cv2.VideoCapture函数打开摄像头来进行采集。这里我把我的方法分享给大家。
采集手势图片

  1. import cv2
  2. import os
  3. DATASET_DIR='GesturesPhotos'#保存所有待采集手势的图片的文件夹的路径
  4. gesture_kinds=5#手势种类:单手可以是1-10,我这里是1-5
  5. photo_num=10#图片数量
  6. classes=list(range(1,gesture_kinds+1,1))#使用1-gesture_kinds来表示所有待预测类别
  7. ###############################################
  8. gestures=photo_num//gesture_kinds*classes#photo_num//gesture_kinds=10//5=2,2*[1,2,3,4,5]=[1,2,3,4,5,1,2,3,4,5]
  9. gestures.extend(classes[:photo_num%gesture_kinds])#photo_num%5=10%5=0,extend([:0])相当于extend([])
  10. '''
  11. 经过这两步运算,gestures为长度与图片数量一致且由类别构成的列表
  12. gestures主要用来标定每次采集的种类
  13. 比如,gesture_kinds=5,photo_num=7,手势种类为5,那么这7次要采集的顺序为[1,2,3,4,5,1,2]
  14. '''
  15. ###############################################
  16. os.makedirs(DATASET_DIR, exist_ok=True)#exist_ok=True可以避免二次采集时重建新文件夹
  17. def capture_gestures(gesture:str,count:int):
  18.     '''
  19.     Args:
  20.         gesture:每次采集的手势,要标记在视频中,防止忘记采集的手势是多少导致实际类别与真实采集结果不一致从而成为噪声!\n
  21.         count:用来命名每次保存的图片,这里直接用记录图片数量来命名\n
  22.     '''
  23.     cv2.namedWindow('Data Collection', cv2.WND_PROP_FULLSCREEN)
  24.     cv2.setWindowProperty('Data Collection', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
  25.     cap=cv2.VideoCapture(0)
  26.     print(f'采集手势{gesture}(按ESC保存并退出)')
  27.     while True:
  28.         ret,frame=cap.read()
  29.         if not ret: break
  30.         roi=frame[160:440,50:250]#roi区域,可以自行修改
  31.         cv2.rectangle(frame, (50,160),(250,440),(0,255,0), 2)#roi区域处绘制方框
  32.         cv2.putText(frame,text=f'No.{count+1} Photo gesture {gesture}',org=(250,100),fontScale=2,thickness=5,color=(0,0,255),fontFace=1)
  33.         cv2.imshow(f'Data Collection',frame)
  34.         key=cv2.waitKey(1)
  35.         if key==27:#按下ESC保存并退出
  36.             img_path=f'{DATASET_DIR}/{count}.jpg'
  37.             cv2.imwrite(img_path,roi)
  38.             break
  39.     cap.release()
  40.     cv2.destroyAllWindows()
  41. for i in range(len(gestures)):
  42.     capture_gestures(gestures[i],i)
复制代码
         运行上述代码后,便可以开始采集手势图片了,这里我使用上述代码统共采集了200张图片用于后续CNN模型的训练。 
说明

        采集时,将右手放置在视频中的绿色框内,尽大概的放置在中央,gesture后的数字表示当前要表示的手势种类。如果采集时出现错误,那么只需要删除掉原来的图片,自行指定新的类别(gesture)以及原来图片的编号,调用一次capture_gestures函数重新采集即可。

采集结果 



采集结果(0-199 40组1-5的手势图片)

        这里我没有对背景进行太多处置惩罚,如果有大佬愿意,可以实验将采集到的图片的背景虚化,突出手掌主体。

 数据预处置惩罚

           这里的数据预处置惩罚主要就是将我们的图像数据划分训练集与测试集后转换为tensor范例的DataLoder。
  1. #数据预处理
  2. from torch.utils.data import Dataset, DataLoader
  3. import torch
  4. import torch.nn as nn
  5. import torch.optim as optim
  6. from torchvision import transforms
  7. class GestureDataset(Dataset):
  8.     def __init__(self, data_dir=DATASET_DIR,gesture_kinds=gesture_kinds,transform=None):
  9.         self.data_dir = data_dir
  10.         self.transform = transform
  11.         self.image_paths = []
  12.         self.labels = []
  13.         
  14.         # 读取数据集
  15.         for img_name in os.listdir(data_dir):
  16.             if img_name.endswith('.jpg'):
  17.                 self.image_paths.append(os.path.join(data_dir, img_name))
  18.                 self.labels.append(int(img_name.split('.')[0])%gesture_kinds)#0-4对于1-5
  19.    
  20.     def __len__(self):
  21.         return len(self.image_paths)
  22.    
  23.     def __getitem__(self, idx):
  24.         img_path=self.image_paths[idx]
  25.         image=cv2.imread(img_path)
  26.         image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # 转换为RGB
  27.         
  28.         label=self.labels[idx]
  29.         
  30.         if self.transform:
  31.             image=self.transform(image)
  32.             
  33.         return image, label
  34. def process_data(data_dir=DATASET_DIR, batch_size=4):
  35.     # 数据预处理
  36.     transform = transforms.Compose([
  37.         transforms.ToPILImage(),
  38.         transforms.Resize((64, 64)),
  39.         transforms.ToTensor(),
  40.         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  41.     ])
  42.    
  43.     dataset=GestureDataset(data_dir, transform=transform)
  44.     train_size=int(0.8 * len(dataset))
  45.     test_size=len(dataset) - train_size
  46.     train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
  47.     train_loader=DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
  48.     test_loader=DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
  49.     return train_loader, test_loader
复制代码
CNN模型训练

        考虑到我的数据集比较少且该分类题目比较简单,所以这里我的模型也没有太复杂只是使用了2层卷积使用。倘若你的数据集比较大,分类种类比较多,可以实验使用一些其他的CNN模型,比如mobilenet,resnet等。
  1. #CNN模型
  2. class GestureCNN(nn.Module):
  3.     def __init__(self, num_classes=5):
  4.         super(GestureCNN, self).__init__()
  5.         self.conv1=nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
  6.         self.relu=nn.ReLU()
  7.         self.maxpool=nn.MaxPool2d(kernel_size=2, stride=2)
  8.         self.conv2=nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
  9.         self.fc1=nn.Linear(32*16*16, 128)
  10.         self.fc2=nn.Linear(128, num_classes)
  11.         
  12.     def forward(self, x):
  13.         x=self.conv1(x)
  14.         x=self.relu(x)
  15.         x=self.maxpool(x)
  16.         x=self.conv2(x)
  17.         x=self.relu(x)
  18.         x=self.maxpool(x)
  19.         x=x.view(x.size(0), -1)
  20.         x=self.fc1(x)
  21.         x=self.relu(x)
  22.         x=self.fc2(x)
  23.         return x
  24. def train_model(train_loader, test_loader, num_epochs=10):
  25.     device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  26.     model=GestureCNN(num_classes=5).to(device)
  27.     criterion=nn.CrossEntropyLoss()
  28.     optimizer=optim.Adam(model.parameters(), lr=0.001)
  29.    
  30.     for epoch in range(num_epochs):
  31.         model.train()
  32.         running_loss=0.0
  33.         correct=0
  34.         total=0
  35.         
  36.         for images, labels in train_loader:
  37.             images=images.to(device)
  38.             labels=labels.to(device)
  39.             
  40.             optimizer.zero_grad()
  41.             outputs=model(images)
  42.             loss=criterion(outputs, labels)
  43.             loss.backward()
  44.             optimizer.step()
  45.             
  46.             running_loss+=loss.item()
  47.             _, predicted=torch.max(outputs.data, 1)
  48.             total+=labels.size(0)
  49.             correct+=(predicted==labels).sum().item()
  50.         
  51.         train_loss = running_loss / len(train_loader)
  52.         train_acc = 100 * correct / total
  53.         
  54.         # 测试集评估
  55.         model.eval()
  56.         test_correct = 0
  57.         test_total = 0
  58.         with torch.no_grad():
  59.             for images, labels in test_loader:
  60.                 images=images.to(device)
  61.                 labels=labels.to(device)
  62.                 outputs=model(images)
  63.                 _, predicted=torch.max(outputs.data, 1)
  64.                 test_total+=labels.size(0)
  65.                 test_correct+=(predicted==labels).sum().item()
  66.         
  67.         test_acc=100*test_correct/test_total
  68.         
  69.         print(f'Epoch [{epoch+1}/{num_epochs}], '
  70.               f'Train Loss: {train_loss:.4f}, '
  71.               f'Train Acc: {train_acc:.2f}%, '
  72.               f'Test Acc: {test_acc:.2f}%')
  73.    
  74.     # 保存模型
  75.     torch.save(model.state_dict(), 'gesture_cnn.pth')
  76.     print('训练完成,模型已保存为 gesture_cnn.pth')
  77.     return model
复制代码
实时猜测 

        实时猜测的思绪是:打开摄像头,获取实时视频流文件中的每一帧图片中的手势,使用训练好的模型猜测并将结果标注在视频流文件的每一帧上。
  1. #实时预测
  2. def realtime_prediction(model_path='gesture_cnn.pth'):
  3.     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  4.     #加载模型
  5.     model = GestureCNN(num_classes=5).to(device)
  6.     model.load_state_dict(torch.load(model_path))
  7.     model.eval()
  8.    
  9.     #预处理
  10.     transform=transforms.Compose([
  11.         transforms.ToPILImage(),
  12.         transforms.Resize((64, 64)),
  13.         transforms.ToTensor(),
  14.         transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  15.     ])
  16.     cap=cv2.VideoCapture(0)
  17.     cv2.namedWindow('Gesture Recognition', cv2.WND_PROP_FULLSCREEN)
  18.     cv2.setWindowProperty('Gesture Recognition', cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
  19.     CLASSES=gestures
  20.     with torch.no_grad():
  21.         while True:
  22.             ret, frame = cap.read()
  23.             if not ret:
  24.                 break  
  25.             # 手势检测区域
  26.             roi = frame[160:440, 50:250]
  27.             cv2.rectangle(frame, (50, 160), (250, 440), (0, 255, 0), 2)
  28.             
  29.             try:
  30.                 input_tensor = transform(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB)).unsqueeze(0).to(device)
  31.                 output = model(input_tensor)
  32.                 _, pred=torch.max(output, 1)
  33.                 probabilities=torch.nn.functional.softmax(output[0], dim=0)
  34.                 confidence, pred=torch.max(probabilities, 0)
  35.                 confidence=confidence.item()*100 #转换为百分比
  36.                 confidence=round(confidence,2)
  37.                 cv2.putText(frame, f'Prediction: {CLASSES[pred.item()]}', (50, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
  38.                 cv2.putText(frame,f'confidence:{confidence}',(70,70),cv2.FONT_HERSHEY_SIMPLEX,0.5, (0, 0, 255), 2)
  39.             except Exception as e:
  40.                 print(f"预测错误: {e}")
  41.             
  42.             cv2.imshow('Gesture Recognition', frame)
  43.             
  44.             if cv2.waitKey(1)==27:
  45.                 break
  46.    
  47.     cap.release()
  48.     cv2.destroyAllWindows()
  49. train_loader, test_loader = process_data()
  50. model=train_model(train_loader, test_loader, num_epochs=10)
  51. realtime_prediction()
复制代码
 
结果:

 
cv2不支持中文字体,因此只能使用英文来标注…… 
总结



        以上便是计算机视觉cv2入门之实时手势检测的全部内容,如果你感到本文对你有用,还劳驾各位一键三连支持一下博主。

免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

祗疼妳一个

论坛元老
这个人很懒什么都没写!
快速回复 返回顶部 返回列表