以下代码实现了基于深度学习的鱼类数量检测功能。它使用预训练的 Faster R-CNN 模型识别图像中的鱼类，并用边界框标出每条鱼的位置。代码包含以下功能：
1. 单张图片检测
2. 文件夹批量检测
3. 检测结果可视化
4. 统计分析
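使用前需要安装必要的依赖库，如 PyTorch、torchvision、OpenCV、Pillow、Matplotlib 等。注意：COCO 数据集中并没有“鱼”这一类别（torchvision 的 COCO 类别 ID 16 实际对应 “bird”），因此仅用 COCO 预训练权重无法可靠地检测鱼类。如果需要可用且更高精度的检测效果，应使用专门针对鱼类训练的模型，或者在自己的鱼类数据集上对现有模型进行微调，并通过 model_path、num_classes 和 fish_class_id 参数加载。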
import cv2
import numpy as np
import torch
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt
import os
from datetime import datetime
class FishDetector:
    """Count fish in images with a torchvision Faster R-CNN detector.

    Every successful :meth:`detect_fish` call appends a result dict to
    ``self.results``; :meth:`visualize_result` and :meth:`get_statistics`
    read from that list.

    NOTE: the default COCO-pretrained weights have no "fish" category.
    Label 16 (the default ``fish_class_id``) is "bird" in torchvision's COCO
    category list. For meaningful fish counts, pass ``model_path`` pointing
    to weights fine-tuned on fish, together with the matching ``num_classes``
    and ``fish_class_id``.
    """

    def __init__(self, model_path=None,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 fish_class_id=16, confidence_threshold=0.7,
                 nms_threshold=0.3, num_classes=91):
        """Load the detection model and set the detection parameters.

        Args:
            model_path: Optional path to a state dict saved with
                ``torch.save(model.state_dict(), path)`` for a
                ``fasterrcnn_resnet50_fpn`` model. When omitted, the
                COCO-pretrained weights are used.
            device: Torch device string used for inference.
            fish_class_id: Output label id that is counted as "fish".
            confidence_threshold: Minimum score for a detection to be kept.
            nms_threshold: IoU threshold for the extra NMS pass.
            num_classes: Class count (including background) of the checkpoint
                at ``model_path``; ignored when ``model_path`` is not given.
        """
        self.device = device
        # torchvision detection models resize and normalize internally and
        # return boxes in the input image's pixel coordinates. So we only
        # convert to a [0, 1] float tensor here. Resizing or normalizing
        # beforehand would double-normalize the input and put the boxes in
        # the wrong coordinate frame for drawing on the original image.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        if model_path:
            # Build the bare architecture (no weight download) and fill it
            # from the user's fine-tuned state dict.
            self.model = models.detection.fasterrcnn_resnet50_fpn(
                pretrained=False,
                pretrained_backbone=False,
                num_classes=num_classes,
            )
            state_dict = torch.load(model_path, map_location=device)
            self.model.load_state_dict(state_dict)
        else:
            self.model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        self.model.to(device)
        self.model.eval()

        # Label id treated as "fish" (see the class docstring about COCO).
        self.fish_class_id = fish_class_id

        # Detection parameters.
        self.confidence_threshold = confidence_threshold
        self.nms_threshold = nms_threshold

        # One result dict per processed image, in processing order.
        self.results = []

    def _apply_nms(self, boxes, scores):
        """Run an extra NMS pass and return the kept boxes as an (N, 4) xyxy array.

        Args:
            boxes: ``(N, 4)`` array of ``[x1, y1, x2, y2]`` boxes.
            scores: ``(N,)`` array of confidence scores for ``boxes``.
        """
        if len(boxes) == 0:
            return np.empty((0, 4), dtype=np.float32)
        # cv2.dnn.NMSBoxes expects [x, y, width, height]; the detector emits
        # [x1, y1, x2, y2], so convert before computing overlaps.
        xywh = np.array(boxes, dtype=np.float32, copy=True)
        xywh[:, 2:] -= xywh[:, :2]
        keep = cv2.dnn.NMSBoxes(
            xywh.tolist(),
            np.asarray(scores, dtype=np.float32).tolist(),
            self.confidence_threshold,
            self.nms_threshold,
        )
        # Depending on the OpenCV version, this is an (N, 1) array, a flat
        # array, or an empty tuple. Flatten it so indexing yields (N, 4).
        keep = np.asarray(keep, dtype=int).reshape(-1)
        return np.asarray(boxes)[keep]

    def detect_fish(self, image_path):
        """Detect fish in a single image and record the result.

        Args:
            image_path: Path of the image file to analyse.

        Returns:
            A dict with ``image_path``, ``fish_count``, ``boxes`` (an (N, 4)
            array of ``[x1, y1, x2, y2]`` pixel coordinates in the original
            image) and ``timestamp``; or ``None`` if processing failed (the
            error is printed, so batch runs continue).
        """
        try:
            # Close the file handle promptly; PIL opens lazily otherwise.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
            image_tensor = self.transform(image).to(self.device)

            with torch.no_grad():
                prediction = self.model([image_tensor])[0]

            boxes = prediction['boxes'].cpu().numpy()
            scores = prediction['scores'].cpu().numpy()
            labels = prediction['labels'].cpu().numpy()

            # Keep only confident detections of the fish class.
            mask = (labels == self.fish_class_id) & (scores > self.confidence_threshold)
            final_boxes = self._apply_nms(boxes[mask], scores[mask])

            result = {
                'image_path': image_path,
                'fish_count': len(final_boxes),
                'boxes': final_boxes,
                'timestamp': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
            }
            self.results.append(result)
            return result
        except Exception as e:
            print(f"Error processing image {image_path}: {str(e)}")
            return None

    def detect_folder(self, folder_path):
        """Run :meth:`detect_fish` on every .png/.jpg/.jpeg file in ``folder_path``.

        Files are processed in sorted filename order so runs are reproducible.
        Images that fail to process are skipped.

        Returns:
            The list of result dicts for the images that succeeded.
        """
        all_results = []
        for filename in sorted(os.listdir(folder_path)):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                file_path = os.path.join(folder_path, filename)
                result = self.detect_fish(file_path)
                if result is not None:
                    all_results.append(result)
        return all_results

    def _find_result(self, image_path):
        """Return the most recently stored result for ``image_path``, or ``None``."""
        # Search newest-first so a re-detected image shows its latest result.
        return next(
            (r for r in reversed(self.results) if r['image_path'] == image_path),
            None,
        )

    def visualize_result(self, image_path, output_path=None):
        """Draw the stored boxes and fish count for ``image_path``.

        Saves the annotated image to ``output_path`` if given; otherwise
        shows it in a matplotlib window. Prints a message and returns early
        if no result is stored or the image cannot be read.
        """
        result = self._find_result(image_path)
        if result is None:
            print(f"No result found for {image_path}")
            return

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals unreadable files by returning None.
            print(f"Could not read image {image_path}")
            return
        # Work in RGB so matplotlib shows or saves the colours correctly.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        for box in result['boxes']:
            x1, y1, x2, y2 = map(int, box)
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)

        cv2.putText(
            image,
            f"Fish Count: {result['fish_count']}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 0, 255),
            2
        )

        if output_path:
            plt.imsave(output_path, image)
            print(f"Visualization saved to {output_path}")
        else:
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            plt.axis('off')
            plt.show()

    def get_statistics(self):
        """Summarize all stored detection results.

        Returns:
            The string ``"No results available"`` when nothing has been
            processed yet. Otherwise, a dict with ``total_images``,
            ``total_fish``, ``average_fish_per_image``,
            ``max_fish_in_single_image`` and ``min_fish_in_single_image``.
        """
        if not self.results:
            return "No results available"
        counts = [r['fish_count'] for r in self.results]
        total_fish = sum(counts)
        return {
            'total_images': len(counts),
            'total_fish': total_fish,
            'average_fish_per_image': total_fish / len(counts),
            'max_fish_in_single_image': max(counts),
            'min_fish_in_single_image': min(counts),
        }
def main():
    """Example entry point: detect one image, then a folder, then print statistics.

    Replace the placeholder paths below with real ones before running.
    """
    detector = FishDetector()

    # Single-image detection plus a saved visualization.
    single_image = "path/to/your/fish_image.jpg"
    single_result = detector.detect_fish(single_image)
    if single_result:
        print(f"Detected {single_result['fish_count']} fish in the image.")
        detector.visualize_result(single_image, "output.jpg")

    # Batch detection; os.listdir would raise on a missing folder.
    folder = "path/to/your/image/folder"
    if os.path.isdir(folder):
        folder_results = detector.detect_folder(folder)
        for result in folder_results:
            print(f"Image: {os.path.basename(result['image_path'])}, Fish Count: {result['fish_count']}")
    else:
        print(f"Folder not found: {folder}")

    stats = detector.get_statistics()
    print("\nDetection Statistics:")
    if isinstance(stats, dict):
        for key, value in stats.items():
            print(f"{key}: {value}")
    else:
        # get_statistics returns a message string when nothing was processed.
        print(stats)
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()