Commit 1e0485c8 by Yuhaibo

feat: 完成模型训练流程优化和结果格式转换

- 优化了模型保存流程,训练完成的模型直接保存在detection_model目录,避免重复移动操作
- 实现了检测结果格式转换功能,确保新旧格式兼容性,支持多区域液位检测数据转换
- 增强了页面同步机制,训练完成后自动刷新模型集管理、模型测试和训练页面的模型列表
- 添加了自动生成模型描述文件功能,包含模型基础信息、技术特性和使用说明
- 改
parents c3d7c4a2 ed53cf16
epoch,time,train/box_loss,train/seg_loss,train/cls_loss,train/dfl_loss,metrics/precision(B),metrics/recall(B),metrics/mAP50(B),metrics/mAP50-95(B),metrics/precision(M),metrics/recall(M),metrics/mAP50(M),metrics/mAP50-95(M),val/box_loss,val/seg_loss,val/cls_loss,val/dfl_loss,lr/pg0,lr/pg1,lr/pg2
1,928.936,0.83523,1.3075,0.46573,1.05864,0.87981,0.85416,0.91902,0.70839,0.8784,0.85371,0.91642,0.68473,0.89994,1.39776,0.62225,1.11979,0.0700103,0.00333218,0.00333218
2,1745.63,0.87012,1.35322,0.51061,1.07759,0.92936,0.89401,0.95387,0.72198,0.93047,0.89384,0.9537,0.70371,0.88167,1.35963,0.54179,1.08754,0.0393505,0.00600563,0.00600563
3,2568.99,0.9039,1.39713,0.56215,1.10027,0.92337,0.89082,0.95288,0.72543,0.92112,0.89467,0.95309,0.71126,0.88062,1.34921,0.54834,1.08423,0.00803057,0.00801908,0.00801908
......@@ -705,7 +705,7 @@ class ModelSetHandler:
if model_path.endswith('.pt'):
model_type = "PyTorch (YOLOv5/v8)"
elif model_path.endswith('.dat'):
model_type = "加密模型 (.dat)"
model_type = ".dat"
else:
model_type = "未知格式"
......
......@@ -136,6 +136,7 @@ class ModelTestThread(QThread):
# 读取标注数据
self.progress_updated.emit(45, "正在读取标注数据...")
import yaml
with open(annotation_file, 'r', encoding='utf-8') as f:
try:
annotation_data = yaml.safe_load(f)
......@@ -158,6 +159,7 @@ class ModelTestThread(QThread):
# 配置标注数据
self.progress_updated.emit(70, "正在配置检测参数...")
<<<<<<< HEAD
# 提取标注数据
test_model_data = annotation_data.get('test_model', {})
boxes = test_model_data.get('boxes', [])
......@@ -172,6 +174,14 @@ class ModelTestThread(QThread):
actual_heights = test_model_data.get('actual_heights', [20.0] * len(boxes))
# 配置检测引擎
=======
# 从annotation_data中提取配置参数
boxes = annotation_data.get('boxes', [])
fixed_bottoms = annotation_data.get('fixed_bottoms', [])
fixed_tops = annotation_data.get('fixed_tops', [])
actual_heights = annotation_data.get('actual_heights', [20.0] * len(boxes))
>>>>>>> ed53cf16e35be94b62ee54dddc99b5e7e2dbe200
detection_engine.configure(boxes, fixed_bottoms, fixed_tops, actual_heights)
# 执行检测
......@@ -182,10 +192,14 @@ class ModelTestThread(QThread):
if detection_result is None or not detection_result.get('success', False):
raise RuntimeError("检测结果为空或检测失败")
self._detection_result = detection_result
# 转换检测结果格式以兼容现有代码
converted_result = self._convertDetectionResult(detection_result)
self._detection_result = converted_result
self.progress_updated.emit(90, "正在保存测试结果...")
self._saveImageTestResults(model_path, test_frame, detection_result, annotation_file)
# 通过handler调用保存方法
handler = self.test_params['handler']
handler._saveImageTestResults(model_path, test_frame, converted_result, annotation_file)
except Exception as e:
raise
......@@ -212,6 +226,59 @@ class ModelTestThread(QThread):
def get_detection_result(self):
"""获取检测结果"""
return self._detection_result
def _convertDetectionResult(self, detection_result):
"""
转换检测结果格式以兼容现有代码
Args:
detection_result: detect方法返回的结果格式
{
'liquid_line_positions': {
0: {'y': y坐标, 'height_mm': 高度毫米, 'height_px': 高度像素},
1: {...},
...
},
'success': bool
}
Returns:
dict: 转换后的结果格式,包含 liquid_level_mm 字段
"""
try:
if not detection_result or not detection_result.get('success', False):
return {'liquid_level_mm': 0.0, 'success': False}
liquid_positions = detection_result.get('liquid_line_positions', {})
if not liquid_positions:
return {'liquid_level_mm': 0.0, 'success': False}
# 取第一个检测区域的液位高度
first_position = next(iter(liquid_positions.values()))
liquid_level_mm = first_position.get('height_mm', 0.0)
# 构建兼容格式的结果
converted_result = {
'liquid_level_mm': liquid_level_mm,
'success': True,
'areas': {}
}
# 转换所有区域的数据
for idx, position_data in liquid_positions.items():
area_name = f'区域{idx + 1}'
converted_result['areas'][area_name] = {
'liquid_height': position_data.get('height_mm', 0.0),
'y_position': position_data.get('y', 0),
'height_px': position_data.get('height_px', 0)
}
return converted_result
except Exception as e:
print(f"[结果转换] 转换检测结果失败: {e}")
return {'liquid_level_mm': 0.0, 'success': False}
class ModelTestHandler:
......@@ -1793,7 +1860,6 @@ class ModelTestHandler:
<h4 style="margin: 0 0 10px 0; color: #333333; font-size: 16px; font-weight: 500;">检测结果视频</h4>
<video width="100%" height="auto" controls style="border: none; border-radius: 6px; max-height: 400px; background: #f8f9fa;">
<source src="file:///{video_path_formatted}" type="video/mp4">
您的浏览器不支持视频播放
</video>
</div>
......
This source diff could not be displayed because it is too large. You can view the blob instead.
# -*- coding: utf-8 -*-
"""
训练工作线程
处理模型训练的后台线程
"""
import os
import yaml
import json
import struct
import hashlib
from pathlib import Path
from qtpy import QtCore
# 尝试导入 pyqtSignal,如果失败则使用 Signal
try:
from PyQt5.QtCore import pyqtSignal
except ImportError:
try:
from PyQt6.QtCore import pyqtSignal
except ImportError:
# 如果都失败,使用 QtCore.Signal
from qtpy.QtCore import Signal as pyqtSignal
from qtpy.QtCore import QThread
# 导入统一的路径管理函数
try:
from ...database.config import get_project_root, get_temp_models_dir, get_train_dir
except (ImportError, ValueError):
try:
from database.config import get_project_root, get_temp_models_dir, get_train_dir
except ImportError:
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from database.config import get_project_root, get_temp_models_dir, get_train_dir
MODEL_FILE_SIGNATURE = b'LDS_MODEL_FILE'
MODEL_FILE_VERSION = 1
MODEL_ENCRYPTION_KEY = "liquid_detection_system_2024"
class TrainingWorker(QThread):
"""训练工作线程"""
# 信号定义
log_output = pyqtSignal(str) # 日志输出信号
training_finished = pyqtSignal(bool) # 训练完成信号
training_progress = pyqtSignal(int, dict) # 训练进度信号 (epoch, loss_dict)
def __init__(self, training_params):
super().__init__()
self.training_params = training_params
self.is_running = True
self.train_config = None
self.training_report = {
"status": "init",
"start_time": None,
"end_time": None,
"exp_name": training_params.get("exp_name"),
"params": training_params,
"device": training_params.get("device"),
"weights_dir": None,
"converted_dat_files": [],
"error": None,
}
# 加载训练配置
self._loadTrainingConfig()
def _loadTrainingConfig(self):
"""加载训练配置"""
try:
import os
import json
current_dir = os.path.dirname(os.path.abspath(__file__))
config_dir = os.path.join(current_dir, "..", "..", "database", "config", "train_configs")
config_file_path = os.path.join(config_dir, "default_config.json")
if not os.path.exists(config_file_path):
# 尝试使用项目根目录
try:
from database.config import get_project_root
project_root = get_project_root()
config_file_path = os.path.join(project_root, "database", "config", "train_configs", "default_config.json")
except:
pass
if os.path.exists(config_file_path):
with open(config_file_path, 'r', encoding='utf-8') as f:
self.train_config = json.load(f)
else:
self.train_config = None
except Exception as e:
self.train_config = None
def _decode_dat_model(self, dat_path):
"""
将加密的 .dat 模型解密为临时 .pt 文件
Args:
dat_path (str): .dat 模型路径
Returns:
str: 解密后的 .pt 模型路径
"""
dat_path = Path(dat_path)
if not dat_path.exists():
raise FileNotFoundError(f"模型文件不存在: {dat_path}")
# 检查文件签名,判断是否为加密文件
with open(dat_path, 'rb') as f:
signature = f.read(len(MODEL_FILE_SIGNATURE))
# 如果签名不匹配,说明这是一个直接重命名的 .pt 文件
if signature != MODEL_FILE_SIGNATURE:
print(f"[警告] {dat_path.name} 不是加密的 .dat 文件,将直接作为 .pt 文件使用")
# 直接返回原路径,YOLO 可以直接加载
return str(dat_path)
# 继续解密流程
version = struct.unpack('<I', f.read(4))[0]
if version != MODEL_FILE_VERSION:
raise ValueError(f"不支持的模型文件版本: {version}")
filename_len = struct.unpack('<I', f.read(4))[0]
_ = f.read(filename_len) # 原始文件名,当前不使用
data_len = struct.unpack('<Q', f.read(8))[0]
encrypted_data = f.read(data_len)
key_hash = hashlib.sha256(MODEL_ENCRYPTION_KEY.encode('utf-8')).digest()
decrypted = bytearray(len(encrypted_data))
key_len = len(key_hash)
for idx, byte in enumerate(encrypted_data):
decrypted[idx] = byte ^ key_hash[idx % key_len]
decrypted = bytes(decrypted)
temp_dir = Path(get_temp_models_dir())
temp_dir.mkdir(parents=True, exist_ok=True)
path_hash = hashlib.md5(str(dat_path).encode('utf-8')).hexdigest()[:8]
temp_model_path = temp_dir / f"train_{dat_path.stem}_{path_hash}.pt"
with open(temp_model_path, 'wb') as f:
f.write(decrypted)
return str(temp_model_path)
def _validateTrainingDataInThread(self, save_liquid_data_path):
"""
在线程中验证训练数据(简化版,避免UI操作)
Returns:
tuple: (是否有效, 消息)
"""
try:
if not os.path.exists(save_liquid_data_path):
return False, f"数据集配置文件不存在: {save_liquid_data_path}"
if not save_liquid_data_path.endswith('.yaml'):
return False, "数据集配置文件必须是 .yaml 格式"
# 读取配置
with open(save_liquid_data_path, 'r', encoding='utf-8') as f:
data_config = yaml.safe_load(f)
if not data_config:
return False, "数据集配置文件为空"
# 获取data.yaml所在目录
data_yaml_dir = os.path.dirname(os.path.abspath(save_liquid_data_path))
train_dir = data_config.get('train', '')
val_dir = data_config.get('val', '')
if not train_dir:
return False, "训练集路径为空"
if not val_dir:
return False, "验证集路径为空"
# 如果是相对路径,转换为相对于data.yaml的绝对路径
if not os.path.isabs(train_dir):
train_dir = os.path.join(data_yaml_dir, train_dir)
if not os.path.isabs(val_dir):
val_dir = os.path.join(data_yaml_dir, val_dir)
if not os.path.exists(train_dir):
return False, f"训练集路径不存在: {train_dir}"
if not os.path.exists(val_dir):
return False, f"验证集路径不存在: {val_dir}"
# 检查是否有图片文件
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']
train_count = sum(1 for f in os.listdir(train_dir)
if any(f.lower().endswith(ext) for ext in image_extensions))
val_count = sum(1 for f in os.listdir(val_dir)
if any(f.lower().endswith(ext) for ext in image_extensions))
if train_count == 0:
return False, f"训练集目录为空: {train_dir}"
if val_count == 0:
return False, f"验证集目录为空: {val_dir}"
return True, f"数据集验证通过 (训练: {train_count} 张, 验证: {val_count} 张)"
except Exception as e:
return False, f"验证过程出错: {str(e)}"
def run(self):
"""执行训练"""
# 初始化变量(确保finally块能访问)
original_stdout = None
original_stderr = None
temp_model_path = None
try:
import os
import sys
import io
import logging
# 根据训练设备设置环境变量
device = self.training_params.get('device', 'cpu')
if device.lower() == 'cpu':
os.environ["CUDA_VISIBLE_DEVICES"] = '-1' # 强制使用 CPU
else:
# GPU 设备:支持 '0', '0,1' 等格式
os.environ["CUDA_VISIBLE_DEVICES"] = device
# 优化环境变量设置
os.environ['YOLO_VERBOSE'] = 'True' # 允许显示训练进度
os.environ['ULTRALYTICS_AUTODOWNLOAD'] = 'False' # 禁用自动下载
os.environ['ULTRALYTICS_DATASETS_DIR'] = os.path.join(os.getcwd(), 'database', 'dataset')
# 设置日志级别以支持进度条显示
import logging
logging.getLogger('ultralytics').setLevel(logging.INFO)
logging.getLogger('yolov8').setLevel(logging.INFO)
# 确保进度条能正常显示
os.environ['TERM'] = 'xterm-256color' # 支持颜色和进度条
# 先导入YOLO,但不立即设置离线模式
# 离线模式会在验证模型文件存在后设置
from ultralytics import YOLO
# 创建日志捕获类(同步终端和UI,只显示原生进度条,单行实时更新,每轮换行)
class LogCapture:
"""捕获训练进度,同步显示到终端和UI(与终端完全一致)
- 训练过程中:单行实时更新进度条(缓存进度条,只发送最新的)
- 每轮完成(100%):保留该行并换行,下一轮从新行开始
"""
def __init__(self, signal, original_stream, log_file_path=None):
self.signal = signal
self.original = original_stream
self.buffer = ""
self._log_file_path = log_file_path
self._is_progress_line = False # 标记当前是否是进度条行
self._cached_progress = None # 缓存最新的进度条行
self._last_epoch = None # 记录上一个 epoch
def write(self, text):
import re
# 始终写入终端(保证终端显示完整)
if self.original:
try:
self.original.write(text)
self.original.flush()
except:
pass
# 同步写入到日志文件(追加)
if self._log_file_path:
try:
with open(self._log_file_path, "a", encoding="utf-8", errors="ignore") as lf:
lf.write(text)
except:
pass
# 处理文本:清理ANSI代码并发送到UI
# 移除ANSI转义序列(颜色代码等)
clean_text = re.sub(r'\x1B\[[0-?]*[ -/]*[@-~]', '', text)
# 过滤掉YOLO自动打印的验证指标行(包含mAP等)
# 这些行通常包含:Epoch, GPU_mem, box_loss, cls_loss, dfl_loss, Instances, Size, mAP50, mAP50-95等
# 示例:Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
# 1/100 3.72G 1.173 1.920 1.506 29 640
if re.search(r'(Epoch\s+GPU_mem|metrics/mAP|val/box_loss|val/cls_loss|val/dfl_loss|mAP50|mAP50-95)', clean_text, re.IGNORECASE):
# 跳过这些验证指标行,不发送到UI
return
# 检查是否包含回车符(进度条通常使用\r来覆盖同一行)
has_carriage_return = '\r' in text
# 移除回车符,但记住这是进度条行
if has_carriage_return:
clean_text = re.sub(r'\r', '', clean_text)
self._is_progress_line = True
# 如果有换行符,说明进度条行结束
if '\n' in clean_text:
self._is_progress_line = False
# 先检查是否需要过滤(扫描信息、调试日志等)
# 只过滤明确不需要的信息
skip_patterns = [
'scanning', # 数据集扫描信息
'labels.cache', # 缓存文件信息
'duplicate', # 重复标签信息
'warning:', # 警告信息
'[trainingpage]', # UI 调试日志
'[应用]', # 应用调试日志
]
should_skip = False
for pattern in skip_patterns:
if pattern in clean_text.lower():
should_skip = True
break
if should_skip:
return # 跳过这条信息
# 再检查是否是训练进度条行(优先级最高,不过滤)
# 训练进度条格式:epoch/batch 显存 损失值... 进度条 速度
# 例如:1/100 3.72G 1.173 1.92 1.506 1.253 29 640: 4% ──────────── 109/2901
# 关键特征:包含 epoch/batch、显存(G)、多个损失值、百分比
is_progress_bar = (
# 最准确的特征:包含 epoch/batch 格式、显存信息(G)、百分比和进度符号
(not '\n' in clean_text and
re.search(r'\d+/\d+', clean_text) is not None and
re.search(r'\d+\.?\d*G', clean_text) is not None and
'%' in clean_text and
('|' in clean_text or '━' in clean_text or '─' in clean_text))
)
# 发送所有有效文本到UI,包括训练信息和进度条
if clean_text.strip():
# 发送进度条或普通文本到UI
if is_progress_bar:
try:
# 检查是否达到100%(一轮完成)
is_complete = '100%' in clean_text
# 提取当前 epoch 号(格式:1/100, 2/100 等)
epoch_match = re.search(r'(\d+)/(\d+)', clean_text)
current_epoch = int(epoch_match.group(1)) if epoch_match else None
# 使用特殊标记来标识进度条
if is_complete:
# 如果达到100%,标记为完成,UI会保留这一行并换行
marked_text = "__PROGRESS_BAR_COMPLETE__" + clean_text
self.signal.emit(marked_text)
self._cached_progress = None # 清空缓存
self._last_epoch = current_epoch # 更新 epoch 记录
else:
# 关键修复:实时发送进度条,而不是缓存
# 这样用户可以看到实时的训练进度
marked_text = "__PROGRESS_BAR__" + clean_text
self.signal.emit(marked_text) # 立即发送,不缓存
self._cached_progress = marked_text # 保留缓存备用
except Exception as e:
# 如果处理进度条出错,作为普通文本发送
self.signal.emit(clean_text)
else:
# 发送之前缓存的进度条(如果有的话)
if self._cached_progress:
self.signal.emit(self._cached_progress)
self._cached_progress = None
# 发送普通训练信息到UI
self.signal.emit(clean_text)
def flush(self):
# 刷新终端
if self.original:
try:
self.original.flush()
except:
pass
# 如果缓冲区有内容,尝试发送到UI
if self.buffer and self.buffer.strip():
try:
import re
clean_text = re.sub(r'\x1B\[[0-?]*[ -/]*[@-~]', '', self.buffer)
clean_text = re.sub(r'\r', '', clean_text)
if clean_text.strip():
self.signal.emit(clean_text)
self.buffer = ""
except:
pass
# 保存原始stdout/stderr
original_stdout = sys.stdout
original_stderr = sys.stderr
# 预先准备日志目录与日志文件
try:
train_root_for_log = get_train_dir()
exp_name_for_log = self.training_params.get('exp_name', 'training_experiment')
exp_dir_for_log = os.path.join(train_root_for_log, "runs", "train", exp_name_for_log)
os.makedirs(exp_dir_for_log, exist_ok=True)
log_file_path = os.path.join(exp_dir_for_log, "training_ui.log")
# 记录到报告(存绝对路径)
self.training_report["weights_dir"] = os.path.abspath(os.path.join(exp_dir_for_log, "weights"))
except Exception:
log_file_path = None
# 重定向stdout和stderr(附带文件记录)
sys.stdout = LogCapture(self.log_output, sys.__stdout__, log_file_path)
sys.stderr = LogCapture(self.log_output, sys.__stderr__, log_file_path)
# 输出训练开始信息(简化版,不打印详细参数)
self.log_output.emit("=" * 70 + "\n")
self.log_output.emit("开始升级模型\n")
self.log_output.emit("=" * 70 + "\n\n")
# 报告:开始时间
import time as _time_mod
self.training_report["status"] = "running"
self.training_report["start_time"] = _time_mod.time()
# 验证数据集(在训练线程中再次验证,确保数据可用)
self.log_output.emit("正在验证数据集...\n")
try:
validation_result, validation_msg = self._validateTrainingDataInThread(self.training_params['save_liquid_data_path'])
if not validation_result:
self.log_output.emit(f"[ERROR] 数据集验证失败: {validation_msg}\n")
self.log_output.emit("=" * 60 + "\n")
self.training_finished.emit(False)
return
else:
self.log_output.emit(f"{validation_msg}\n\n")
except Exception as e:
self.log_output.emit(f"[WARNING] 数据集验证过程出错: {str(e)}\n")
self.log_output.emit("继续尝试训练...\n\n")
# 处理模型文件
model_path = self.training_params['base_model']
temp_model_path = None
if model_path.endswith('.dat'):
self.log_output.emit("正在处理.dat模型文件...\n")
try:
decoded_path = self._decode_dat_model(model_path)
model_path = decoded_path
temp_model_path = decoded_path
self.log_output.emit("模型处理完成\n")
except Exception as decode_error:
self.log_output.emit(f"[ERROR] 模型处理失败: {decode_error}\n")
self.training_finished.emit(False)
return
# 检查停止标志
if not self.is_running:
self.log_output.emit("[WARNING] 训练在开始前被停止\n")
return
# 加载模型
self.log_output.emit("正在加载模型...\n")
try:
# 在加载模型前验证文件存在,并设置离线模式
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
# 验证通过后,设置离线模式防止ultralytics尝试下载其他模型
os.environ['YOLO_OFFLINE'] = '1'
os.environ['ULTRALYTICS_OFFLINE'] = 'True'
model = YOLO(model_path)
self.log_output.emit("模型加载成功\n\n")
except Exception as model_error:
self.log_output.emit(f"[ERROR] 模型加载失败: {str(model_error)}\n")
raise model_error
# 创建训练回调
import time
epoch_start_time = [0] # 使用列表以便在闭包中修改
def on_train_start(trainer):
"""训练开始回调 - 只输出到终端,不发送到UI"""
# 记录开始时间
epoch_start_time[0] = time.time()
# 不发送任何格式化消息到UI,让LogCapture直接捕获原生输出
def on_train_batch_end(trainer):
"""训练批次结束回调 - 检查停止标志但不立即停止"""
if not self.is_running:
# 只显示提示信息,不设置stop_training标志
# 让训练继续到epoch结束
if not hasattr(trainer, '_stop_message_shown'):
print("\n用户请求停止训练...")
print("请稍候,等待当前训练轮次完成...")
trainer._stop_message_shown = True
def on_train_epoch_end(trainer):
"""训练周期结束回调 - 检查停止标志,在epoch完成后优雅停止"""
# 获取当前轮次信息
epoch = trainer.epoch + 1
total_epochs = trainer.epochs
# 如果用户请求停止,在当前epoch完成后停止
if not self.is_running:
print(f"\n当前轮次 {epoch}/{total_epochs} 已完成")
print("用户请求停止训练,正在退出...")
trainer.stop_training = True
if hasattr(trainer, 'model'):
trainer.model.training = False
# 抛出异常来终止训练,但此时当前epoch已完成
raise KeyboardInterrupt("用户停止训练")
# 重置计时器
current_time = time.time()
epoch_start_time[0] = current_time
# 只发送进度信号,不发送格式化消息到UI
# 让LogCapture直接捕获原生输出
try:
loss_dict = {}
if hasattr(trainer, 'metrics'):
if hasattr(trainer.metrics, 'box_loss'):
loss_dict['box_loss'] = float(trainer.metrics.box_loss)
if hasattr(trainer.metrics, 'cls_loss'):
loss_dict['cls_loss'] = float(trainer.metrics.cls_loss)
self.training_progress.emit(epoch, loss_dict)
except Exception as e:
pass
# 添加回调
try:
model.add_callback("on_train_start", on_train_start)
model.add_callback("on_train_batch_end", on_train_batch_end)
model.add_callback("on_train_epoch_end", on_train_epoch_end)
except Exception as e:
self.log_output.emit(f"回调添加失败: {str(e)}\n")
# 最后一次检查停止标志
if not self.is_running:
self.log_output.emit("[WARNING] 训练在开始前被停止\n")
return
self.log_output.emit("开始升级模型...\n")
self.log_output.emit("=" * 60 + "\n")
# 检查并调整batch size(防止GPU OOM)
batch_size = self.training_params['batch']
device_str = self.training_params['device']
imgsz = self.training_params['imgsz']
original_batch_size = batch_size # 保存原始batch size
# 如果使用GPU,检查显存和batch size
if device_str.lower() not in ['cpu', '-1']:
self.log_output.emit(f"检测到GPU训练(设备: {device_str})\n")
# 尝试获取GPU信息
try:
import torch
import gc
if torch.cuda.is_available():
gpu_id = int(device_str) if device_str.isdigit() else 0
gpu_name = torch.cuda.get_device_name(gpu_id)
total_memory = torch.cuda.get_device_properties(gpu_id).total_memory / (1024**3) # GB
self.log_output.emit(f"GPU型号: {gpu_name}\n")
self.log_output.emit(f"总显存: {total_memory:.2f} GB\n")
# 彻底清理显存
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
# 获取当前可用显存
try:
allocated = torch.cuda.memory_allocated(gpu_id) / (1024**3)
reserved = torch.cuda.memory_reserved(gpu_id) / (1024**3)
free_memory = total_memory - reserved
self.log_output.emit(f"当前已分配: {allocated:.2f} GB\n")
self.log_output.emit(f"当前保留: {reserved:.2f} GB\n")
self.log_output.emit(f"可用显存: {free_memory:.2f} GB\n\n")
# 根据显存大小和图像尺寸给出batch size建议
if total_memory < 6: # 6GB以下
recommended_batch = 4
recommended_imgsz = 512
elif total_memory < 12: # 6-12GB
recommended_batch = 8
recommended_imgsz = 640
else: # 12GB以上
recommended_batch = 16
recommended_imgsz = 640
# 根据图像尺寸调整建议
if imgsz > 640:
recommended_batch = max(4, recommended_batch // 2)
elif imgsz > 512:
recommended_batch = max(4, int(recommended_batch * 0.75))
# 如果可用显存不足,进一步降低建议
if free_memory < 3.0:
recommended_batch = max(2, recommended_batch // 2)
# 检查当前设置是否合理,如果超出建议值则自动调整
if batch_size > recommended_batch:
self.log_output.emit(f"警告: 当前batch={batch_size}可能超出显存容量\n")
self.log_output.emit(f"自动调整: batch={batch_size} -> {recommended_batch}\n")
batch_size = recommended_batch
self.log_output.emit(f"建议配置: batch≤{recommended_batch}, imgsz≤{recommended_imgsz}\n\n")
elif free_memory < 2.0: # 可用显存少于2GB
self.log_output.emit(f"警告: 可用显存不足 ({free_memory:.2f} GB)\n")
# 自动降低batch size
if batch_size > 4:
new_batch = max(2, batch_size // 2)
self.log_output.emit(f"自动调整: batch={batch_size} -> {new_batch}\n")
batch_size = new_batch
self.log_output.emit(f"建议: 关闭其他程序释放显存,或进一步减小batch size\n\n")
except:
pass
except Exception as e:
self.log_output.emit(f"无法获取GPU详细信息: {str(e)}\n")
# 通用建议和自动调整
if batch_size > 8:
self.log_output.emit(f"警告: batch={batch_size} 可能导致显存不足\n")
new_batch = max(4, batch_size // 2)
self.log_output.emit(f"自动调整: batch={batch_size} -> {new_batch}\n")
batch_size = new_batch
self.log_output.emit(f"建议: 使用batch≤8以避免OOM错误\n\n")
# 开始训练(支持自动重试和batch size调整)
max_retries = 3
retry_count = 0
training_success = False
while retry_count < max_retries and not training_success:
try:
# 从配置文件读取AMP设置,如果没有则默认启用(节省显存)
amp_enabled = True # 默认启用AMP
if self.train_config and 'device_config' in self.train_config:
amp_enabled = self.train_config['device_config'].get('amp', True)
# 如果使用CPU,强制关闭AMP(CPU不支持AMP)
if device_str.lower() in ['cpu', '-1']:
amp_enabled = False
# 如果是重试,清理显存
if retry_count > 0:
self.log_output.emit(f"\n第 {retry_count} 次重试训练...\n")
try:
import torch
import gc
gc.collect()
torch.cuda.empty_cache()
torch.cuda.synchronize()
self.log_output.emit("已清理GPU显存缓存\n")
except:
pass
self.log_output.emit(f"批次大小: {batch_size}\n")
self.log_output.emit(f"训练设备: {device_str}\n")
self.log_output.emit(f"模型名称: {self.training_params['exp_name']}\n\n")
# 优化workers参数,避免多线程死锁
workers = min(self.training_params['workers'], 2) # 限制最大workers数量
if device_str.lower() in ['cpu', '-1']:
workers = 0 # CPU模式下禁用多线程数据加载
# 开始训练
try:
mission_results = model.train(
data=self.training_params['save_liquid_data_path'],
imgsz=self.training_params['imgsz'],
epochs=self.training_params['epochs'],
batch=batch_size,
workers=workers,
device=device_str,
optimizer=self.training_params['optimizer'],
close_mosaic=self.training_params['close_mosaic'],
resume=self.training_params['resume'],
project='database/train/runs/train',
name=self.training_params['exp_name'],
single_cls=self.training_params['single_cls'],
cache=False,
pretrained=self.training_params['pretrained'],
verbose=True, # 启用原生进度条显示
save_period=1, # 每个epoch都保存模型,确保用户停止时有模型文件
amp=amp_enabled,
plots=True,
exist_ok=True,
patience=100
)
except KeyboardInterrupt:
# 用户停止训练,这是正常的停止操作
self.log_output.emit("\n训练已按用户要求停止\n")
# 等待YOLO完成当前epoch并保存模型
import time
self.log_output.emit("等待当前epoch完成并保存模型...\n")
time.sleep(2) # 给YOLO时间完成保存
training_success = True # 标记为成功,因为这是用户主动停止
break # 跳出重试循环
except Exception as e:
# 如果训练失败,尝试备用方法
self.log_output.emit(f"训练启动失败: {str(e)}\n")
self.log_output.emit("尝试备用方法...\n")
try:
mission_results = model.train(
data=self.training_params['save_liquid_data_path'],
epochs=self.training_params['epochs'],
batch=max(1, batch_size // 2),
device=device_str,
workers=0,
verbose=True,
save_period=1 # 每个epoch都保存模型
)
except KeyboardInterrupt:
# 备用方法中用户也停止了训练
self.log_output.emit("\n训练已按用户要求停止\n")
# 等待YOLO完成当前epoch并保存模型
import time
self.log_output.emit("等待当前epoch完成并保存模型...\n")
time.sleep(2) # 给YOLO时间完成保存
training_success = True
break
# 训练成功
training_success = True
# 保存基本结果路径到报告
try:
# Ultralytics 会把保存目录置于 model.trainer.save_dir
save_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
if save_dir:
save_dir_abs = os.path.abspath(str(save_dir))
weights_dir = os.path.abspath(os.path.join(save_dir_abs, "weights"))
self.training_report["weights_dir"] = weights_dir
# 立即转换PT文件为DAT格式并删除PT文件
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
except:
pass
break # 跳出重试循环
except RuntimeError as runtime_error:
error_msg = str(runtime_error)
# 检查是否是CUDA OOM错误
if 'out of memory' in error_msg.lower() or 'cuda' in error_msg.lower():
# 如果是OOM错误且还有重试机会,自动降低batch size重试
if retry_count < max_retries - 1:
retry_count += 1
# 降低batch size
if batch_size > 1:
new_batch = max(1, batch_size // 2)
self.log_output.emit(f"\n" + "="*70 + "\n")
self.log_output.emit(f"GPU显存不足(OOM)错误!\n\n")
self.log_output.emit(f"自动降低batch size: {batch_size} -> {new_batch}\n")
self.log_output.emit(f"准备重试训练(第 {retry_count}/{max_retries-1} 次)...\n")
self.log_output.emit("="*70 + "\n\n")
batch_size = new_batch
continue # 重试
else:
# batch size已经是1,无法再降低
self.log_output.emit(f"\n" + "="*70 + "\n")
self.log_output.emit(f"GPU显存不足(OOM)错误!\n\n")
self.log_output.emit(f"batch size已经是1,无法继续降低\n")
self.log_output.emit(f"请尝试:\n")
self.log_output.emit(f" 1. 减小图像尺寸(当前: {imgsz})\n")
self.log_output.emit(f" 2. 关闭数据缓存\n")
self.log_output.emit(f" 3. 减少workers数量(当前: {self.training_params['workers']})\n")
self.log_output.emit(f" 4. 关闭其他占用GPU的程序\n")
self.log_output.emit("="*70 + "\n")
self.training_finished.emit(False)
raise runtime_error
else:
# 重试次数用完,输出详细错误信息并抛出异常
self.log_output.emit(f"\n" + "="*70 + "\n")
self.log_output.emit(f"GPU显存不足(OOM)错误!\n\n")
self.log_output.emit(f"已重试 {max_retries-1} 次,仍无法解决显存问题\n")
raise runtime_error
else:
# 其他运行时错误,直接抛出
raise runtime_error
except KeyboardInterrupt as kb_error:
# 用户停止训练的异常
self.log_output.emit(f"\n" + "="*60 + "\n")
self.log_output.emit("训练已按用户要求停止\n")
self.log_output.emit("="*60 + "\n")
# 强制保存当前模型
try:
self.log_output.emit("正在保存当前训练进度...\n")
weights_dir = self.training_report.get("weights_dir")
if weights_dir and os.path.exists(weights_dir):
last_pt = os.path.join(weights_dir, "last.pt")
# 方法1:直接保存模型权重(不依赖results.csv)
saved = False
if hasattr(model, 'save'):
try:
model.save(last_pt)
saved = True
self.log_output.emit(f"✓ 模型已保存到: {last_pt}\n")
except Exception as save_error:
self.log_output.emit(f"⚠ model.save()失败: {save_error},尝试备用方法...\n")
# 方法2:备用方法 - 保存checkpoint
if not saved and hasattr(model, 'trainer') and model.trainer:
try:
import torch
ckpt = {
'epoch': model.trainer.epoch if hasattr(model.trainer, 'epoch') else 0,
'model': model.model.state_dict() if hasattr(model, 'model') else model.state_dict(),
}
torch.save(ckpt, last_pt)
saved = True
self.log_output.emit(f"✓ checkpoint已保存到: {last_pt}\n")
except Exception as ckpt_error:
self.log_output.emit(f"⚠ checkpoint保存失败: {ckpt_error}\n")
if not saved:
self.log_output.emit("⚠ 所有保存方法均失败\n")
else:
self.log_output.emit(f"⚠ 权重目录不存在: {weights_dir}\n")
except Exception as save_error:
self.log_output.emit(f"⚠ 保存模型失败: {save_error}\n")
self.training_report["status"] = "stopped_by_user"
# 标记为用户手动停止
self._is_user_stopped = True
# 用户主动停止发送 False,但在 _onTrainingFinished 中会根据 _is_user_stopped 判断是否进入继续模式
self.training_finished.emit(False)
return # 直接返回,不继续执行
except Exception as train_error:
# 其他异常,直接抛出
raise train_error
# 如果训练成功,继续后续处理
if training_success:
# 训练完成
if self.is_running:
self.log_output.emit("\n" + "="*60 + "\n")
self.log_output.emit(" 训练正常完成!\n")
self.log_output.emit("="*60 + "\n")
# 标记报告
self.training_report["status"] = "success"
# 尝试转换pt->dat后,将列表加入报告
try:
if self.training_params.get('exp_name'):
# 这里不能直接访问外层 Handler 的方法,仅标记占位;实际转换在 _onTrainingFinished 中执行
# 因此我们在报告里预留字段,稍后 _onTrainingFinished 会覆盖写入最终报告
self.training_report.setdefault("converted_dat_files", [])
except Exception:
pass
self.training_finished.emit(True)
else:
# 用户停止训练(is_running=False)
self.log_output.emit("\n" + "="*60 + "\n")
self.log_output.emit("训练已按用户要求停止\n")
self.log_output.emit("="*60 + "\n")
# 强制保存当前模型
try:
self.log_output.emit("正在保存当前训练进度...\n")
weights_dir = self.training_report.get("weights_dir")
if weights_dir and os.path.exists(weights_dir):
last_pt = os.path.join(weights_dir, "last.pt")
# 方法1:直接保存模型权重(不依赖results.csv)
saved = False
if hasattr(model, 'save'):
try:
model.save(last_pt)
saved = True
self.log_output.emit(f"✓ 模型已保存到: {last_pt}\n")
except Exception as save_error:
self.log_output.emit(f"⚠ model.save()失败: {save_error},尝试备用方法...\n")
# 方法2:备用方法 - 保存checkpoint
if not saved and hasattr(model, 'trainer') and model.trainer:
try:
import torch
ckpt = {
'epoch': model.trainer.epoch if hasattr(model.trainer, 'epoch') else 0,
'model': model.model.state_dict() if hasattr(model, 'model') else model.state_dict(),
}
torch.save(ckpt, last_pt)
saved = True
self.log_output.emit(f"✓ checkpoint已保存到: {last_pt}\n")
except Exception as ckpt_error:
self.log_output.emit(f"⚠ checkpoint保存失败: {ckpt_error}\n")
if not saved:
self.log_output.emit("⚠ 所有保存方法均失败\n")
else:
self.log_output.emit(f"⚠ 权重目录不存在: {weights_dir}\n")
except Exception as save_error:
self.log_output.emit(f"⚠ 保存模型失败: {save_error}\n")
self.training_report["status"] = "stopped_by_user"
self._is_user_stopped = True
# 用户主动停止发送 False,但在 _onTrainingFinished 中会根据 _is_user_stopped 判断是否进入继续模式
self.training_finished.emit(False)
except KeyboardInterrupt as kb_error:
# 用户停止训练的异常(最外层捕获)
self.log_output.emit(f"\n" + "="*60 + "\n")
self.log_output.emit("训练已按用户要求停止\n")
self.log_output.emit("="*60 + "\n")
# 强制保存当前模型
try:
self.log_output.emit("正在保存当前训练进度...\n")
if 'model' in locals():
weights_dir = self.training_report.get("weights_dir")
if weights_dir and os.path.exists(weights_dir):
last_pt = os.path.join(weights_dir, "last.pt")
# 方法1:直接保存模型权重(不依赖results.csv)
saved = False
if hasattr(model, 'save'):
try:
model.save(last_pt)
saved = True
self.log_output.emit(f"✓ 模型已保存到: {last_pt}\n")
except Exception as save_error:
self.log_output.emit(f"⚠ model.save()失败: {save_error},尝试备用方法...\n")
# 方法2:备用方法 - 保存checkpoint
if not saved and hasattr(model, 'trainer') and model.trainer:
try:
import torch
ckpt = {
'epoch': model.trainer.epoch if hasattr(model.trainer, 'epoch') else 0,
'model': model.model.state_dict() if hasattr(model, 'model') else model.state_dict(),
}
torch.save(ckpt, last_pt)
saved = True
self.log_output.emit(f"✓ checkpoint已保存到: {last_pt}\n")
except Exception as ckpt_error:
self.log_output.emit(f"⚠ checkpoint保存失败: {ckpt_error}\n")
if not saved:
self.log_output.emit("⚠ 所有保存方法均失败\n")
else:
self.log_output.emit(f"⚠ 权重目录不存在: {weights_dir}\n")
else:
self.log_output.emit("⚠ model对象不存在,无法保存\n")
except Exception as save_error:
self.log_output.emit(f"⚠ 保存模型失败: {save_error}\n")
self.training_report["status"] = "stopped_by_user"
# 标记为用户手动停止,确保按钮状态正确切换
self._is_user_stopped = True
# 用户主动停止发送 False,但在 _onTrainingFinished 中会根据 _is_user_stopped 判断是否进入继续模式
self.training_finished.emit(False)
except Exception as e:
error_msg = str(e)
self.log_output.emit(f"\n" + "="*60 + "\n")
self.log_output.emit(f" 升级失败: {error_msg}\n")
self.log_output.emit("="*60 + "\n")
# 检查常见错误
error_lower = error_msg.lower()
if 'dataset' in error_lower or 'images not found' in error_lower or 'missing path' in error_lower:
self.log_output.emit(f"\n 数据集路径错误!\n")
self.log_output.emit(f" 请检查 data.yaml 中的 train 和 val 路径是否正确。\n")
self.log_output.emit(f" 确保路径下存在图片文件。\n")
if 'file not found' in error_lower or 'no such file' in error_lower:
self.log_output.emit(f"\n 文件未找到错误!\n")
self.log_output.emit(f" 请检查数据集路径是否正确。\n")
# 输出详细错误信息
import traceback
full_traceback = traceback.format_exc()
self.log_output.emit(f"\n详细错误信息:\n{full_traceback}\n")
# 标记报告
self.training_report["status"] = "failed"
self.training_report["error"] = error_msg
self.training_finished.emit(False)
finally:
# 记录结束时间并落盘报告
import time as _time_mod2, json as _json_mod2
self.training_report["end_time"] = _time_mod2.time()
# 写入 report 到权重目录上层(若存在)
try:
exp_name_for_report = self.training_params.get('exp_name', 'training_experiment')
train_root_for_report = get_train_dir()
exp_dir_for_report = os.path.join(train_root_for_report, "runs", "train", exp_name_for_report)
os.makedirs(exp_dir_for_report, exist_ok=True)
report_path = os.path.join(exp_dir_for_report, "training_report.json")
with open(report_path, "w", encoding="utf-8") as rf:
_json_mod2.dump(self.training_report, rf, ensure_ascii=False, indent=2)
except Exception:
pass
# 恢复原始stdout/stderr
import sys
if original_stdout is not None and original_stderr is not None:
try:
sys.stdout = original_stdout
sys.stderr = original_stderr
except Exception as e:
pass
# 清理临时文件
if temp_model_path:
import os
if os.path.exists(temp_model_path):
try:
os.remove(temp_model_path)
except Exception as e:
pass
def stop_training(self):
"""停止训练"""
self.is_running = False
......@@ -120,7 +120,7 @@ class ModelSetPage(QtWidgets.QWidget):
# 标题
title = QtWidgets.QLabel("模型集管理")
title.setStyleSheet("font-size: 13pt; font-weight: bold;")
FontManager.applyToWidget(title, weight=FontManager.WEIGHT_BOLD) # 使用全局默认字体,加粗
header_layout.addWidget(title)
header_layout.addStretch()
......@@ -246,10 +246,26 @@ class ModelSetPage(QtWidgets.QWidget):
content_layout.addWidget(left_widget)
# ========== 中间主内容区(左右分栏:模型列表 + 文本显示) ==========
center_widget = QtWidgets.QWidget()
center_main_layout = QtWidgets.QHBoxLayout(center_widget)
center_main_layout.setContentsMargins(0, 0, 0, 0)
center_main_layout.setSpacing(8)
# 使用QSplitter实现可缩放的分隔符
center_splitter = QtWidgets.QSplitter(Qt.Horizontal)
center_splitter.setChildrenCollapsible(False) # 防止面板被完全折叠
# 设置splitter样式
center_splitter.setStyleSheet("""
QSplitter::handle {
background-color: #e0e0e0;
border: 1px solid #c0c0c0;
width: 6px;
margin: 2px;
border-radius: 3px;
}
QSplitter::handle:hover {
background-color: #d0d0d0;
}
QSplitter::handle:pressed {
background-color: #b0b0b0;
}
""")
# ===== 左侧:模型列表区域 =====
left_list_widget = QtWidgets.QWidget()
......@@ -338,17 +354,22 @@ class ModelSetPage(QtWidgets.QWidget):
""")
self.model_info_text.setPlaceholderText("选择模型后将显示模型文件夹内的txt文件内容...")
# 应用全局字体管理
FontManager.applyToWidget(self.model_info_text, size=10)
FontManager.applyToWidget(self.model_info_text) # 使用全局默认字体
right_text_layout.addWidget(self.model_info_text)
# 将左右两部分添加到中间主布局(1:1比例)
center_main_layout.addWidget(left_list_widget, 1)
center_main_layout.addWidget(right_text_widget, 1)
# 将左右两部分添加到splitter
center_splitter.addWidget(left_list_widget)
center_splitter.addWidget(right_text_widget)
# 设置初始比例(1:1)
center_splitter.setSizes([400, 400])
center_splitter.setStretchFactor(0, 1) # 左侧可伸缩
center_splitter.setStretchFactor(1, 1) # 右侧可伸缩
# 🔥 设置中间区域的宽度(主要内容区)- 响应式布局
center_widget.setMinimumWidth(scale_w(800))
center_splitter.setMinimumWidth(scale_w(800))
content_layout.addWidget(center_widget, 3)
content_layout.addWidget(center_splitter, 3)
main_layout.addLayout(content_layout)
......@@ -558,7 +579,7 @@ class ModelSetPage(QtWidgets.QWidget):
if '预训练模型' not in model_type and 'YOLO' not in model_type:
model_type = "PyTorch (YOLOv8)"
elif model_file.endswith('.dat'):
model_type = "加密模型 (.dat)"
model_type = ".dat"
elif model_file.endswith('.onnx'):
model_type = "ONNX"
......@@ -1216,7 +1237,6 @@ class ModelSetPage(QtWidgets.QWidget):
border-radius: 4px;
}
QLabel {
font-size: 10pt;
padding: 3px;
}
""")
......@@ -1294,6 +1314,9 @@ class ModelSetPage(QtWidgets.QWidget):
layout.addLayout(button_layout)
# 应用全局字体管理器到对话框及其所有子控件
FontManager.applyToWidgetRecursive(dialog)
# 显示对话框
dialog.exec_()
......@@ -1381,8 +1404,8 @@ class ModelSetPage(QtWidgets.QWidget):
# 2. 从配置文件提取通道模型
channel_models = self._extractChannelModels(config)
# 3. 扫描模型目录
scanned_models = self._scanModelDirectory()
# 3. 扫描模型目录(使用统一的getDetectionModels方法确保名称一致)
scanned_models = self.getDetectionModels()
# 4. 合并所有模型信息
all_models = self._mergeModelInfo(channel_models, scanned_models)
......@@ -1469,8 +1492,11 @@ class ModelSetPage(QtWidgets.QWidget):
# 3. 检查模型路径是否存在
if model_path and os.path.exists(model_path):
# 使用统一的命名逻辑:从模型路径推导出标准名称
model_name = self._getModelNameFromPath(model_path)
models.append({
'name': f"{channel_name}模型",
'name': model_name, # 使用统一的模型名称
'path': model_path,
'channel': channel_key,
'channel_name': channel_name,
......@@ -1480,19 +1506,64 @@ class ModelSetPage(QtWidgets.QWidget):
return models
def _getModelNameFromPath(self, model_path):
"""从模型路径推导出统一的模型名称"""
try:
from pathlib import Path
path = Path(model_path)
# 检查是否在detection_model目录下
if 'detection_model' in path.parts:
# 找到detection_model目录的索引
parts = path.parts
detection_index = -1
for i, part in enumerate(parts):
if part == 'detection_model':
detection_index = i
break
if detection_index >= 0 and detection_index + 1 < len(parts):
# 获取模型ID目录名
model_id = parts[detection_index + 1]
# 尝试读取config.yaml获取模型名称
model_dir = path.parent
config_locations = [
model_dir / "training_results" / "config.yaml",
model_dir / "config.yaml"
]
for config_file in config_locations:
if config_file.exists():
try:
import yaml
with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
return config_data.get('model_name', f"模型_{model_id}")
except:
continue
# 如果没有配置文件,使用默认格式
return f"模型_{model_id}"
# 如果不在detection_model目录下,使用文件名
return path.stem
except Exception:
# 出错时返回文件名
return Path(model_path).stem
def _scanModelDirectory(self):
"""扫描模型目录获取所有模型文件(优先detection_model)"""
"""扫描模型目录获取所有模型文件(只从detection_model加载)"""
models = []
try:
# 获取模型目录路径
current_dir = Path(__file__).parent.parent.parent
# 扫描多个模型目录(优先detection_model)
# 只扫描detection_model目录,确保数据一致性
model_dirs = [
(current_dir / "database" / "model" / "detection_model", "检测模型", True), # 优先级最高
(current_dir / "database" / "model" / "train_model", "训练模型", False),
(current_dir / "database" / "model" / "test_model", "测试模型", False)
(current_dir / "database" / "model" / "detection_model", "检测模型", True), # 唯一数据源
]
for model_dir, dir_type, is_primary in model_dirs:
......@@ -1604,6 +1675,104 @@ class ModelSetPage(QtWidgets.QWidget):
return models
@staticmethod
def getDetectionModels():
"""静态方法:获取detection_model目录下的所有模型,供其他页面使用"""
models = []
try:
# 获取模型目录路径
current_dir = Path(__file__).parent.parent.parent
detection_model_dir = current_dir / "database" / "model" / "detection_model"
if not detection_model_dir.exists():
return models
# 遍历所有子目录,包含数字和非数字目录
all_subdirs = [d for d in detection_model_dir.iterdir() if d.is_dir()]
# 分离数字目录和非数字目录
digit_subdirs = [d for d in all_subdirs if d.name.isdigit()]
non_digit_subdirs = [d for d in all_subdirs if not d.name.isdigit()]
# 数字目录按数字降序排列,非数字目录按字母排序
sorted_digit_subdirs = sorted(digit_subdirs, key=lambda x: int(x.name), reverse=True)
sorted_non_digit_subdirs = sorted(non_digit_subdirs, key=lambda x: x.name)
# 合并:数字目录在前(最新的在前),非数字目录在后
sorted_subdirs = sorted_digit_subdirs + sorted_non_digit_subdirs
for subdir in sorted_subdirs:
# 新结构:.dat文件直接在根目录下
# 查找.dat文件(优先在根目录,兼容旧的weights子目录)
dat_files = list(subdir.glob("*.dat"))
# 如果根目录没有.dat文件,检查weights子目录(兼容旧结构)
if not dat_files:
weights_dir = subdir / "weights"
if weights_dir.exists():
dat_files = list(weights_dir.glob("*.dat"))
if not dat_files:
continue
# 优先选择best.dat,然后是其他.dat文件
model_file = None
for dat_file in dat_files:
if dat_file.name.startswith('best.'):
model_file = dat_file
break
if not model_file:
model_file = dat_files[0]
# 尝试读取config.yaml获取模型名称
# 新结构:config.yaml在training_results子目录中
model_name = None
config_locations = [
subdir / "training_results" / "config.yaml", # 新结构
subdir / "config.yaml" # 兼容旧结构
]
for config_file in config_locations:
if config_file.exists():
try:
import yaml
with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
model_name = config_data.get('model_name', f"模型_{subdir.name}")
break
except:
continue
if not model_name:
model_name = f"模型_{subdir.name}"
# 获取文件大小
file_size = "未知"
try:
size_bytes = model_file.stat().st_size
if size_bytes < 1024 * 1024:
file_size = f"{size_bytes / 1024:.1f} KB"
else:
file_size = f"{size_bytes / (1024 * 1024):.1f} MB"
except:
pass
models.append({
'name': model_name,
'path': str(model_file),
'size': file_size,
'type': '.dat',
'source': '检测模型',
'model_id': subdir.name,
'is_primary': True
})
except Exception as e:
print(f"[错误] 获取detection_model模型失败: {e}")
return models
def _mergeModelInfo(self, channel_models, scanned_models):
"""合并模型信息,避免重复(优先detection_model)"""
all_models = []
......@@ -1669,7 +1838,7 @@ class ModelSetPage(QtWidgets.QWidget):
if model_path.endswith('.pt'):
model_type = "PyTorch (YOLOv5/v8)"
elif model_path.endswith('.dat'):
model_type = "加密模型 (.dat)"
model_type = ".dat"
else:
model_type = "未知格式"
......@@ -1793,7 +1962,7 @@ class ModelSetPage(QtWidgets.QWidget):
return False
def _loadModelTxtFiles(self, model_name):
"""读取并显示模型文件夹内的txt文件"""
"""显示模型最近一次升级的参数和训练效果信息"""
try:
# 清空文本显示区域
self.model_info_text.clear()
......@@ -1817,38 +1986,139 @@ class ModelSetPage(QtWidgets.QWidget):
self.model_info_text.setPlainText(f"模型目录不存在:\n{model_dir}")
return
# 查找目录中的所有txt文件
txt_files = list(model_dir.glob("*.txt"))
if not txt_files:
self.model_info_text.setPlainText(f"模型目录中没有找到txt文件:\n{model_dir}")
return
# 读取并显示所有txt文件的内容
# 构建模型训练信息
content_parts = []
content_parts.append(f"模型目录: {model_dir}\n")
content_parts.append("=" * 60 + "\n\n")
for txt_file in sorted(txt_files):
content_parts.append(f"【文件: {txt_file.name}】\n")
content_parts.append("-" * 60 + "\n")
# 首先检查是否有txt文件,如果有则优先显示txt文件内容
txt_files = []
# 新结构:搜索training_results目录的txt文件
training_results_dir = model_dir / "training_results"
if training_results_dir.exists():
training_results_files = list(training_results_dir.glob("*.txt"))
txt_files.extend(training_results_files)
# 兼容旧结构:搜索模型同级目录的txt文件
current_dir_files = list(model_dir.glob("*.txt"))
txt_files.extend(current_dir_files)
# 兼容旧结构:搜索上一级目录的txt文件
parent_dir = model_dir.parent
if parent_dir.exists():
parent_dir_files = list(parent_dir.glob("*.txt"))
txt_files.extend(parent_dir_files)
# 如果有txt文件,直接显示txt文件内容(优先显示模型描述文件)
if txt_files:
# 优先显示模型描述文件
description_files = [f for f in txt_files if '模型描述' in f.name]
other_files = [f for f in txt_files if '模型描述' not in f.name]
try:
# 尝试使用UTF-8编码读取
with open(txt_file, 'r', encoding='utf-8') as f:
file_content = f.read()
except UnicodeDecodeError:
# 如果UTF-8失败,尝试GBK编码
# 按优先级排序:模型描述文件在前
sorted_files = description_files + sorted(other_files)
for txt_file in sorted_files:
try:
with open(txt_file, 'r', encoding='gbk') as f:
file_content = f.read()
except Exception as e:
file_content = f"无法读取文件(编码错误): {str(e)}"
except Exception as e:
file_content = f"读取文件时出错: {str(e)}"
with open(txt_file, 'r', encoding='utf-8') as f:
file_content = f.read().strip()
if file_content:
if txt_file.parent == model_dir:
file_label = txt_file.name
else:
file_label = f"{txt_file.parent.name}/{txt_file.name}"
# 如果是模型描述文件,直接显示内容,不加标题
if '模型描述' in txt_file.name:
content_parts.append(f"{file_content}\n\n")
else:
content_parts.append(f"=== {file_label} ===\n")
content_parts.append(f"{file_content}\n\n")
except:
try:
with open(txt_file, 'r', encoding='gbk') as f:
file_content = f.read().strip()
if file_content:
if txt_file.parent == model_dir:
file_label = txt_file.name
else:
file_label = f"{txt_file.parent.name}/{txt_file.name}"
# 如果是模型描述文件,直接显示内容,不加标题
if '模型描述' in txt_file.name:
content_parts.append(f"{file_content}\n\n")
else:
content_parts.append(f"=== {file_label} ===\n")
content_parts.append(f"{file_content}\n\n")
except:
pass
# 如果找到了txt文件,直接显示,不再添加其他自动生成的信息
full_content = "".join(content_parts)
self.model_info_text.setPlainText(full_content)
# 滚动到顶部
cursor = self.model_info_text.textCursor()
cursor.movePosition(QtGui.QTextCursor.Start)
self.model_info_text.setTextCursor(cursor)
return
else:
# 如果没有txt文件,则显示基础信息
content_parts.append(f"【模型升级信息】 - {model_name}\n")
content_parts.append("=" * 60 + "\n\n")
content_parts.append("基础信息\n")
content_parts.append("-" * 30 + "\n")
content_parts.append(f"模型名称: {model_name}\n")
content_parts.append(f"模型类型: {model_info.get('type', '未知')}\n")
content_parts.append(f"模型路径: {model_path}\n")
content_parts.append(f"文件大小: {model_info.get('size', '未知')}\n")
content_parts.append(file_content)
content_parts.append("\n\n" + "=" * 60 + "\n\n")
# 获取文件修改时间
try:
if os.path.exists(model_path):
import time
mtime = os.path.getmtime(model_path)
mod_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime))
content_parts.append(f"最后修改: {mod_time}\n")
except:
pass
content_parts.append("\n")
# 2. 训练参数信息
training_info = self._getTrainingParameters(model_dir)
if training_info:
content_parts.append("训练参数\n")
content_parts.append("-" * 30 + "\n")
for key, value in training_info.items():
content_parts.append(f"{key}: {value}\n")
content_parts.append("\n")
# 3. 训练效果信息
results_info = self._getTrainingResults(model_dir)
if results_info:
content_parts.append("训练效果\n")
content_parts.append("-" * 30 + "\n")
for key, value in results_info.items():
content_parts.append(f"{key}: {value}\n")
content_parts.append("\n")
# 4. 模型评估
evaluation = self._evaluateModelPerformance(results_info)
if evaluation:
content_parts.append("效果评估\n")
content_parts.append("-" * 30 + "\n")
content_parts.append(f"{evaluation}\n\n")
# 如果没有找到训练信息,显示基础信息
if not training_info and not results_info:
content_parts.append("说明\n")
content_parts.append("-" * 30 + "\n")
content_parts.append("未找到详细的训练记录文件。\n")
content_parts.append("这可能是预训练模型或外部导入的模型。\n\n")
content_parts.append("模型仍可正常使用,如需查看训练效果,\n")
content_parts.append("建议重新进行模型升级训练。\n")
# 显示所有内容
full_content = "".join(content_parts)
......@@ -1860,7 +2130,199 @@ class ModelSetPage(QtWidgets.QWidget):
self.model_info_text.setTextCursor(cursor)
except Exception as e:
self.model_info_text.setPlainText(f"读取txt文件时出错:\n{str(e)}")
self.model_info_text.setPlainText(f"读取模型信息时出错:\n{str(e)}")
def _getTrainingParameters(self, model_dir):
"""获取训练参数信息"""
training_params = {}
try:
# 查找 args.yaml 文件(YOLO训练参数)
args_file = model_dir / "args.yaml"
if args_file.exists():
with open(args_file, 'r', encoding='utf-8') as f:
args_data = yaml.safe_load(f)
if args_data:
training_params["训练轮数"] = args_data.get('epochs', '未知')
training_params["批次大小"] = args_data.get('batch', '未知')
training_params["图像尺寸"] = args_data.get('imgsz', '未知')
training_params["学习率"] = args_data.get('lr0', '未知')
training_params["优化器"] = args_data.get('optimizer', '未知')
training_params["数据集"] = args_data.get('data', '未知')
training_params["设备"] = args_data.get('device', '未知')
training_params["工作进程"] = args_data.get('workers', '未知')
# 查找其他可能的配置文件
config_files = list(model_dir.glob("*config*.yaml")) + list(model_dir.glob("*config*.yml"))
for config_file in config_files:
if config_file.name != "args.yaml":
try:
with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
if config_data and isinstance(config_data, dict):
# 提取一些关键参数
if 'epochs' in config_data:
training_params["训练轮数"] = config_data['epochs']
if 'batch_size' in config_data:
training_params["批次大小"] = config_data['batch_size']
except:
continue
except Exception as e:
pass
return training_params
def _getTrainingResults(self, model_dir):
"""获取训练结果信息"""
results = {}
try:
# 查找 results.csv 文件(YOLO训练结果)
results_file = model_dir / "results.csv"
if results_file.exists():
import pandas as pd
try:
df = pd.read_csv(results_file)
if not df.empty:
# 获取最后一行的结果(最终训练结果)
last_row = df.iloc[-1]
# 提取关键指标
if 'metrics/precision(B)' in df.columns:
results["精确率"] = f"{last_row['metrics/precision(B)']:.4f}"
elif 'precision' in df.columns:
results["精确率"] = f"{last_row['precision']:.4f}"
if 'metrics/recall(B)' in df.columns:
results["召回率"] = f"{last_row['metrics/recall(B)']:.4f}"
elif 'recall' in df.columns:
results["召回率"] = f"{last_row['recall']:.4f}"
if 'metrics/mAP50(B)' in df.columns:
results["mAP@0.5"] = f"{last_row['metrics/mAP50(B)']:.4f}"
elif 'mAP_0.5' in df.columns:
results["mAP@0.5"] = f"{last_row['mAP_0.5']:.4f}"
if 'metrics/mAP50-95(B)' in df.columns:
results["mAP@0.5:0.95"] = f"{last_row['metrics/mAP50-95(B)']:.4f}"
elif 'mAP_0.5:0.95' in df.columns:
results["mAP@0.5:0.95"] = f"{last_row['mAP_0.5:0.95']:.4f}"
if 'train/box_loss' in df.columns:
results["训练损失"] = f"{last_row['train/box_loss']:.4f}"
elif 'train_loss' in df.columns:
results["训练损失"] = f"{last_row['train_loss']:.4f}"
if 'val/box_loss' in df.columns:
results["验证损失"] = f"{last_row['val/box_loss']:.4f}"
elif 'val_loss' in df.columns:
results["验证损失"] = f"{last_row['val_loss']:.4f}"
# 训练轮数
if 'epoch' in df.columns:
results["完成轮数"] = f"{int(last_row['epoch']) + 1}"
except ImportError:
# 如果没有pandas,尝试手动解析CSV
with open(results_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
if len(lines) > 1: # 有标题行和数据行
headers = lines[0].strip().split(',')
last_data = lines[-1].strip().split(',')
for i, header in enumerate(headers):
if i < len(last_data):
if 'precision' in header.lower():
results["精确率"] = last_data[i]
elif 'recall' in header.lower():
results["召回率"] = last_data[i]
elif 'map50' in header.lower():
results["mAP@0.5"] = last_data[i]
except Exception as e:
pass
# 查找其他结果文件
log_files = list(model_dir.glob("*.log")) + list(model_dir.glob("*train*.txt"))
for log_file in log_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
content = f.read()
# 简单提取一些关键信息
if 'Best mAP' in content or 'best map' in content.lower():
lines = content.split('\n')
for line in lines:
if 'map' in line.lower() and ('best' in line.lower() or 'final' in line.lower()):
results["最佳mAP"] = line.strip()
break
except:
continue
except Exception as e:
pass
return results
def _evaluateModelPerformance(self, results_info):
"""评估模型性能"""
if not results_info:
return "暂无训练结果数据,无法评估模型性能。"
evaluation_parts = []
try:
# 评估mAP指标
if "mAP@0.5" in results_info:
map50 = float(results_info["mAP@0.5"])
if map50 >= 0.9:
evaluation_parts.append("检测精度: 优秀 (mAP@0.5 ≥ 0.9)")
elif map50 >= 0.8:
evaluation_parts.append("检测精度: 良好 (mAP@0.5 ≥ 0.8)")
elif map50 >= 0.7:
evaluation_parts.append("检测精度: 一般 (mAP@0.5 ≥ 0.7)")
else:
evaluation_parts.append("检测精度: 较差 (mAP@0.5 < 0.7)")
# 评估精确率和召回率
if "精确率" in results_info and "召回率" in results_info:
precision = float(results_info["精确率"])
recall = float(results_info["召回率"])
if precision >= 0.85 and recall >= 0.85:
evaluation_parts.append("平衡性能: 优秀 (精确率和召回率均 ≥ 0.85)")
elif precision >= 0.75 and recall >= 0.75:
evaluation_parts.append("平衡性能: 良好 (精确率和召回率均 ≥ 0.75)")
else:
evaluation_parts.append("平衡性能: 需要改进")
# 评估损失值
if "训练损失" in results_info and "验证损失" in results_info:
train_loss = float(results_info["训练损失"])
val_loss = float(results_info["验证损失"])
if abs(train_loss - val_loss) < 0.1:
evaluation_parts.append("模型稳定性: 良好 (训练和验证损失接近)")
elif abs(train_loss - val_loss) < 0.2:
evaluation_parts.append("模型稳定性: 一般")
else:
evaluation_parts.append("模型稳定性: 可能存在过拟合")
# 综合评估
if len(evaluation_parts) == 0:
return "数据不足,无法进行详细评估。"
# 添加使用建议
if any("优秀" in part for part in evaluation_parts):
evaluation_parts.append("\n建议: 模型表现良好,可以投入使用。")
elif any("良好" in part for part in evaluation_parts):
evaluation_parts.append("\n建议: 模型表现尚可,建议在更多数据上测试。")
else:
evaluation_parts.append("\n建议: 模型需要进一步训练优化。")
except Exception as e:
return "评估过程中出现错误,请检查训练结果数据。"
return "\n".join(evaluation_parts)
if __name__ == "__main__":
......
......@@ -580,21 +580,28 @@ class TrainingPage(QtWidgets.QWidget):
def _clearCurve(self):
"""清空曲线数据并隐藏曲线面板"""
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget'):
# 清空数据
try:
# 总是清空数据(无论PyQtGraph是否可用)
self.curve_data_x = []
self.curve_data_y = []
print("[曲线] 已清空曲线数据")
# 清空曲线
if self.curve_line:
self.curve_plot_widget.removeItem(self.curve_line)
self.curve_line = None
# 如果PyQtGraph可用,清空图表
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget') and self.curve_plot_widget:
# 清空曲线
if hasattr(self, 'curve_line') and self.curve_line:
self.curve_plot_widget.removeItem(self.curve_line)
self.curve_line = None
print("[曲线] 已清空PyQtGraph曲线")
print("[曲线] 已清空曲线数据")
# 自动隐藏曲线面板,返回到初始显示状态
self.hideCurvePanel()
print("[曲线] 曲线面板已隐藏,返回初始显示状态")
# 自动隐藏曲线面板,返回到初始显示状态
self.hideCurvePanel()
print("[曲线] 曲线面板已隐藏,返回初始显示状态")
except Exception as e:
print(f"[曲线] 清空曲线失败: {e}")
import traceback
traceback.print_exc()
def addCurvePoint(self, frame_index, height_mm):
"""添加曲线数据点
......@@ -603,39 +610,48 @@ class TrainingPage(QtWidgets.QWidget):
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_plot_widget'):
return
try:
# 添加数据点
# 确保曲线数据列表存在
if not hasattr(self, 'curve_data_x'):
self.curve_data_x = []
if not hasattr(self, 'curve_data_y'):
self.curve_data_y = []
# 总是添加数据点到列表中(无论PyQtGraph是否可用)
self.curve_data_x.append(frame_index)
self.curve_data_y.append(height_mm)
# 如果曲线不存在,创建曲线
if self.curve_line is None:
self.curve_line = self.curve_plot_widget.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
else:
# 更新曲线数据
self.curve_line.setData(self.curve_data_x, self.curve_data_y)
print(f"[曲线] 添加数据点: 帧{frame_index}, 液位{height_mm:.1f}mm,当前数据点数: {len(self.curve_data_x)}")
# 如果PyQtGraph可用且有curve_plot_widget,更新图表
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget') and self.curve_plot_widget:
# 如果曲线不存在,创建曲线
if not hasattr(self, 'curve_line') or self.curve_line is None:
self.curve_line = self.curve_plot_widget.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
else:
# 更新曲线数据
self.curve_line.setData(self.curve_data_x, self.curve_data_y)
except Exception as e:
print(f"[曲线] 添加数据点失败: {e}")
import traceback
traceback.print_exc()
def showCurvePanel(self):
"""显示曲线面板"""
if hasattr(self, 'display_layout'):
# 切换到曲线面板(索引3:hint, display_panel, video_panel, curve_panel)
self.display_layout.setCurrentIndex(3)
if hasattr(self, 'curve_panel') and hasattr(self, 'stacked_widget'):
self.stacked_widget.setCurrentWidget(self.curve_panel)
print("[曲线] 显示曲线面板")
def hideCurvePanel(self):
"""隐藏曲线面板,返回到显示面板"""
if hasattr(self, 'display_layout'):
self.display_layout.setCurrentIndex(1) # 显示 display_panel
"""隐藏曲线面板,返回到初始显示"""
if hasattr(self, 'display_panel') and hasattr(self, 'stacked_widget'):
self.stacked_widget.setCurrentWidget(self.display_panel)
print("[曲线] 隐藏曲线面板")
def saveCurveData(self, csv_path):
"""保存曲线数据为CSV文件
......@@ -1343,56 +1359,12 @@ class TrainingPage(QtWidgets.QWidget):
# self._loadTestFileList() # 🔥 不再需要刷新测试文件列表(改用浏览方式)
def _loadBaseModelOptions(self):
"""从模型集管理页面加载基础模型选项"""
# 清空现有选项
self.base_model_combo.clear()
try:
# 尝试从父窗口获取模型集页面
if hasattr(self._parent, 'modelSetPage'):
model_set_page = self._parent.modelSetPage
# 获取模型参数字典
if hasattr(model_set_page, '_model_params'):
model_params = model_set_page._model_params
if not model_params:
self.base_model_combo.addItem("未找到模型", None)
return
# 获取默认模型
default_model = None
if hasattr(model_set_page, '_current_default_model'):
default_model = model_set_page._current_default_model
# 添加所有模型到下拉框
default_index = 0
for idx, (model_name, params) in enumerate(model_params.items()):
model_path = params.get('path', '')
# 构建显示名称
display_name = model_name
if model_name == default_model:
display_name = f"{model_name} (默认)"
default_index = idx
# 添加到下拉框,使用模型路径作为数据
self.base_model_combo.addItem(display_name, model_path)
# 设置默认选择
self.base_model_combo.setCurrentIndex(default_index)
else:
self.base_model_combo.addItem("模型集页面未初始化", None)
else:
self.base_model_combo.addItem("未找到模型集页面", None)
except Exception as e:
print(f"[基础模型] 加载失败: {e}")
import traceback
traceback.print_exc()
self.base_model_combo.addItem("加载失败", None)
"""从detection_model目录加载基础模型选项"""
# 使用统一的模型刷新方法
self.refreshModelLists()
def _loadTestModelOptions(self):
<<<<<<< HEAD
"""加载测试模型选项(从 detection_model 文件夹读取,与模型集管理页面同步)"""
# 清空现有选项
......@@ -1432,6 +1404,12 @@ class TrainingPage(QtWidgets.QWidget):
# 设置默认选择
self.test_model_combo.setCurrentIndex(default_index)
=======
"""从detection_model目录加载测试模型选项"""
# 测试模型和基础模型使用相同的数据源,无需单独加载
# refreshModelLists() 已经处理了测试模型下拉菜单
pass
>>>>>>> ed53cf16e35be94b62ee54dddc99b5e7e2dbe200
def _loadModelsFromConfig(self, project_root):
"""从配置文件加载通道模型"""
......@@ -1819,8 +1797,8 @@ class TrainingPage(QtWidgets.QWidget):
self._showNoCurveMessage()
return
# 切换到曲线面板显示
self.showCurvePanel()
# 在当前测试页面中显示曲线,而不是切换面板
self._showCurveInTestPage()
# 显示曲线信息提示
data_count = len(self.curve_data_x)
......@@ -1831,7 +1809,7 @@ class TrainingPage(QtWidgets.QWidget):
self,
"曲线信息",
f"图片测试结果:\n液位高度: {liquid_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
f"曲线已显示在左侧测试页面中。"
)
else:
# 视频测试
......@@ -1845,7 +1823,7 @@ class TrainingPage(QtWidgets.QWidget):
f"数据点数: {data_count} 个\n"
f"液位范围: {min_level:.1f} - {max_level:.1f} mm\n"
f"平均液位: {avg_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
f"曲线已显示在左侧测试页面中。"
)
except Exception as e:
......@@ -1856,6 +1834,356 @@ class TrainingPage(QtWidgets.QWidget):
f"显示曲线时发生错误:\n{str(e)}"
)
def _showCurveInTestPage(self):
"""在测试页面中显示曲线"""
try:
print(f"[曲线显示] 开始显示曲线,PyQtGraph可用: {PYQTGRAPH_AVAILABLE}")
# 检查曲线数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
print("[曲线显示] 错误: 缺少曲线数据属性")
self._showCurveAsText()
return
if len(self.curve_data_x) == 0 or len(self.curve_data_y) == 0:
print(f"[曲线显示] 错误: 曲线数据为空,X数据点: {len(self.curve_data_x)}, Y数据点: {len(self.curve_data_y)}")
self._showCurveAsText()
return
print(f"[曲线显示] 曲线数据检查通过,数据点数: {len(self.curve_data_x)}")
if not PYQTGRAPH_AVAILABLE:
print("[曲线显示] PyQtGraph不可用,使用文本显示")
self._showCurveAsText()
return
# 生成曲线图表的HTML内容
print("[曲线显示] 开始生成HTML内容")
curve_html = self._generateCurveHTML()
if not curve_html:
print("[曲线显示] 错误: HTML内容生成失败")
self._showCurveAsText()
return
# 在显示面板中显示曲线HTML
if hasattr(self, 'display_panel') and hasattr(self, 'display_layout'):
self.display_panel.setHtml(curve_html)
self.display_layout.setCurrentWidget(self.display_panel)
print("[曲线显示] 曲线已显示在测试页面中")
else:
print("[曲线显示] 错误: 缺少display_panel或display_layout属性")
self._showCurveAsText()
except Exception as e:
print(f"[曲线显示] 在测试页面显示曲线失败: {e}")
import traceback
traceback.print_exc()
# 降级到文本显示
self._showCurveAsText()
def _generateCurveHTML(self):
"""生成曲线的HTML内容"""
try:
print("[曲线HTML] 开始生成HTML内容")
# 保存曲线图片到临时文件
import tempfile
import os
temp_dir = tempfile.gettempdir()
curve_image_path = os.path.join(temp_dir, "test_curve_display.png")
print(f"[曲线HTML] 临时图片路径: {curve_image_path}")
# 优先尝试使用matplotlib生成曲线(更可靠)
if self._createMatplotlibCurve(curve_image_path):
print("[曲线HTML] matplotlib曲线生成成功")
# 使用PyQtGraph导出曲线图片
elif hasattr(self, 'curve_plot_widget') and self.curve_plot_widget:
print("[曲线HTML] 使用现有的curve_plot_widget导出图片")
try:
# 使用样式管理器的配置
from widgets.style_manager import CurveDisplayStyleManager
chart_width, chart_height = CurveDisplayStyleManager.getChartSize()
exporter = pg.exporters.ImageExporter(self.curve_plot_widget.plotItem)
exporter.parameters()['width'] = chart_width
exporter.parameters()['height'] = chart_height
exporter.export(curve_image_path)
print("[曲线HTML] 现有widget图片导出成功")
except Exception as e:
print(f"[曲线HTML] 现有widget图片导出失败: {e}")
self._createTempCurvePlot(curve_image_path)
else:
print("[曲线HTML] 没有现有的curve_plot_widget,创建临时plot")
# 如果没有现有的plot widget,创建一个临时的
self._createTempCurvePlot(curve_image_path)
# 检查图片是否成功生成
if not os.path.exists(curve_image_path):
print(f"[曲线HTML] 错误: 图片文件未生成: {curve_image_path}")
return self._getFallbackCurveHTML()
print(f"[曲线HTML] 图片生成成功,文件大小: {os.path.getsize(curve_image_path)} bytes")
# 生成统计信息
data_count = len(self.curve_data_x)
stats_html = ""
if data_count == 1:
# 图片测试
liquid_level = self.curve_data_y[0]
stats_html = f"""
<div style="margin-bottom: 15px; padding: 10px; background: #e8f4fd; border: 1px solid #bee5eb; border-radius: 5px;">
<h4 style="margin: 0 0 8px 0; color: #0c5460;">图片测试结果</h4>
<p style="margin: 0; color: #0c5460;"><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
</div>
"""
else:
# 视频测试
min_level = min(self.curve_data_y)
max_level = max(self.curve_data_y)
avg_level = sum(self.curve_data_y) / len(self.curve_data_y)
stats_html = f"""
<div style="margin-bottom: 15px; padding: 10px; background: #e8f4fd; border: 1px solid #bee5eb; border-radius: 5px;">
<h4 style="margin: 0 0 8px 0; color: #0c5460;">视频测试结果统计</h4>
<p style="margin: 2px 0; color: #0c5460;"><strong>数据点数:</strong> {data_count} 个</p>
<p style="margin: 2px 0; color: #0c5460;"><strong>液位范围:</strong> {min_level:.1f} - {max_level:.1f} mm</p>
<p style="margin: 2px 0; color: #0c5460;"><strong>平均液位:</strong> {avg_level:.1f} mm</p>
</div>
"""
# 使用统一的样式管理器生成HTML内容
from widgets.style_manager import CurveDisplayStyleManager
html_content = CurveDisplayStyleManager.generateCurveHTML(curve_image_path, stats_html)
return html_content
except Exception as e:
print(f"[曲线HTML] 生成HTML失败: {e}")
return self._getFallbackCurveHTML()
def _createMatplotlibCurve(self, output_path):
"""使用matplotlib创建曲线图"""
try:
print("[matplotlib曲线] 开始使用matplotlib生成曲线")
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端
# 验证数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
print("[matplotlib曲线] 错误: 缺少曲线数据")
return False
if len(self.curve_data_x) == 0 or len(self.curve_data_y) == 0:
print(f"[matplotlib曲线] 错误: 曲线数据为空")
return False
print(f"[matplotlib曲线] 数据验证通过,X点数: {len(self.curve_data_x)}, Y点数: {len(self.curve_data_y)}")
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'Arial']
plt.rcParams['axes.unicode_minus'] = False
# 使用样式管理器的配置
from widgets.style_manager import CurveDisplayStyleManager
chart_width, chart_height = CurveDisplayStyleManager.getChartSize()
chart_dpi = CurveDisplayStyleManager.getChartDPI()
bg_color = CurveDisplayStyleManager.getPlotBackgroundColor()
# 创建图形,使用统一的尺寸和DPI
fig, ax = plt.subplots(figsize=(chart_width/100, chart_height/100), dpi=chart_dpi)
# 绘制曲线
ax.plot(self.curve_data_x, self.curve_data_y, 'b-', linewidth=2.5, marker='o', markersize=5,
markerfacecolor='white', markeredgecolor='blue', markeredgewidth=1.5)
# 设置标签和标题
ax.set_xlabel('帧序号', fontsize=12)
ax.set_ylabel('液位高度 (mm)', fontsize=12)
ax.set_title('液位检测曲线', fontsize=14, fontweight='bold', pad=20)
# 设置网格
ax.grid(True, alpha=0.3, linestyle='--')
# 设置背景色,使用样式管理器的颜色
ax.set_facecolor(bg_color)
fig.patch.set_facecolor(bg_color)
# 优化布局,减少边距
plt.tight_layout(pad=1.0)
# 保存图片,优化参数
plt.savefig(output_path,
bbox_inches='tight',
pad_inches=0.2,
facecolor=bg_color,
edgecolor='none',
dpi=100)
plt.close(fig)
print(f"[matplotlib曲线] 曲线图片已保存: {output_path}")
return True
except Exception as e:
print(f"[matplotlib曲线] 创建失败: {e}")
import traceback
traceback.print_exc()
return False
def _createTempCurvePlot(self, output_path):
"""创建临时曲线图并保存"""
try:
print("[临时曲线] 开始创建临时曲线图")
import pyqtgraph as pg
# 验证数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
print("[临时曲线] 错误: 缺少曲线数据")
self._createPlaceholderImage(output_path)
return
if len(self.curve_data_x) == 0 or len(self.curve_data_y) == 0:
print(f"[临时曲线] 错误: 曲线数据为空")
self._createPlaceholderImage(output_path)
return
print(f"[临时曲线] 数据验证通过,X点数: {len(self.curve_data_x)}, Y点数: {len(self.curve_data_y)}")
# 创建临时的plot widget
temp_plot = pg.PlotWidget()
temp_plot.setBackground('#f8f9fa')
temp_plot.showGrid(x=True, y=True, alpha=0.3)
temp_plot.setLabel('left', '液位高度', units='mm')
temp_plot.setLabel('bottom', '帧序号')
temp_plot.setTitle('液位检测曲线', color='#495057', size='12pt')
print("[临时曲线] PlotWidget创建成功,开始绘制曲线")
# 绘制曲线
temp_plot.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
print("[临时曲线] 曲线绘制完成,开始导出图片")
# 使用样式管理器的配置导出图片
from widgets.style_manager import CurveDisplayStyleManager
chart_width, chart_height = CurveDisplayStyleManager.getChartSize()
bg_color = CurveDisplayStyleManager.getPlotBackgroundColor()
# 导出图片(使用统一尺寸)
exporter = pg.exporters.ImageExporter(temp_plot.plotItem)
exporter.parameters()['width'] = chart_width
exporter.parameters()['height'] = chart_height
# 设置背景色
temp_plot.setBackground(bg_color)
exporter.export(output_path)
print(f"[临时曲线] 曲线图片已保存: {output_path}")
except Exception as e:
print(f"[临时曲线] 创建失败: {e}")
import traceback
traceback.print_exc()
# 创建一个简单的占位图片
self._createPlaceholderImage(output_path)
def _createPlaceholderImage(self, output_path):
"""创建占位图片"""
try:
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import numpy as np
# 如果有曲线数据,尝试用matplotlib绘制
if hasattr(self, 'curve_data_x') and hasattr(self, 'curve_data_y') and len(self.curve_data_x) > 0:
print("[占位图片] 尝试使用matplotlib绘制曲线")
try:
plt.figure(figsize=(10, 5), dpi=80)
plt.plot(self.curve_data_x, self.curve_data_y, 'b-', linewidth=2, marker='o', markersize=4)
plt.xlabel('帧序号')
plt.ylabel('液位高度 (mm)')
plt.title('液位检测曲线')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, facecolor='#f8f9fa')
plt.close()
print(f"[占位图片] matplotlib曲线已保存: {output_path}")
return
except Exception as e:
print(f"[占位图片] matplotlib绘制失败: {e}")
# 创建简化的占位图片(更小的尺寸,减少白色背景)
img = Image.new('RGB', (600, 300), '#f8f9fa')
draw = ImageDraw.Draw(img)
# 绘制边框
draw.rectangle([0, 0, 599, 299], outline='#dee2e6', width=1)
# 绘制文本
try:
font = ImageFont.truetype("C:/Windows/Fonts/msyh.ttc", 18)
except:
font = ImageFont.load_default()
text = "曲线生成失败,请检查数据"
text_bbox = draw.textbbox((0, 0), text, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
x = (600 - text_width) // 2
y = (300 - text_height) // 2
draw.text((x, y), text, fill='#666666', font=font)
img.save(output_path)
print(f"[占位图片] 已创建: {output_path}")
except Exception as e:
print(f"[占位图片] 创建失败: {e}")
def _getFallbackCurveHTML(self):
"""获取降级的曲线HTML内容"""
data_count = len(self.curve_data_x)
if data_count == 1:
liquid_level = self.curve_data_y[0]
return f"""
<div style="font-family: Arial, sans-serif; padding: 20px; background: #ffffff; color: #333333;">
<h3>图片测试结果</h3>
<p><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
<p style="color: #666; font-size: 12px;">注: 曲线图表生成失败,显示文本结果。</p>
</div>
"""
else:
min_level = min(self.curve_data_y)
max_level = max(self.curve_data_y)
avg_level = sum(self.curve_data_y) / len(self.curve_data_y)
return f"""
<div style="font-family: Arial, sans-serif; padding: 20px; background: #ffffff; color: #333333;">
<h3>视频测试结果统计</h3>
<p><strong>数据点数:</strong> {data_count} 个</p>
<p><strong>液位范围:</strong> {min_level:.1f} - {max_level:.1f} mm</p>
<p><strong>平均液位:</strong> {avg_level:.1f} mm</p>
<p style="color: #666; font-size: 12px;">注: 曲线图表生成失败,显示文本结果。</p>
</div>
"""
def _showCurveAsText(self):
"""以文本形式显示曲线结果"""
try:
fallback_html = self._getFallbackCurveHTML()
if hasattr(self, 'display_panel') and hasattr(self, 'display_layout'):
self.display_panel.setHtml(fallback_html)
self.display_layout.setCurrentWidget(self.display_panel)
print("[曲线显示] 以文本形式显示曲线结果")
except Exception as e:
print(f"[曲线显示] 文本显示也失败: {e}")
def _showNoCurveMessage(self):
"""显示无曲线数据的提示"""
QtWidgets.QMessageBox.information(
......@@ -1870,6 +2198,25 @@ class TrainingPage(QtWidgets.QWidget):
"测试完成后即可查看曲线结果。"
)
def _testCurveGeneration(self):
"""测试曲线生成功能(调试用)"""
try:
print("[曲线测试] 开始测试曲线生成功能")
# 创建测试数据
self.curve_data_x = [0, 1, 2, 3, 4]
self.curve_data_y = [25.0, 26.5, 24.8, 27.2, 25.9]
print(f"[曲线测试] 测试数据创建完成,X: {self.curve_data_x}, Y: {self.curve_data_y}")
# 测试曲线显示
self._showCurveInTestPage()
except Exception as e:
print(f"[曲线测试] 测试失败: {e}")
import traceback
traceback.print_exc()
def enableViewCurveButton(self):
"""启用查看曲线按钮(测试完成后调用)"""
try:
......@@ -1911,6 +2258,56 @@ class TrainingPage(QtWidgets.QWidget):
elif self.checkbox_template_3.isChecked():
return "template_3"
return None
def refreshModelLists(self):
"""刷新模型下拉菜单列表"""
try:
# 导入模型集页面以获取模型列表
from .modelset_page import ModelSetPage
# 获取detection_model目录下的所有模型
models = ModelSetPage.getDetectionModels()
# 刷新基础模型下拉菜单
if hasattr(self, 'base_model_combo'):
current_base = self.base_model_combo.currentText()
self.base_model_combo.clear()
for model in models:
display_name = model['name'] # 只显示模型名称,不显示文件大小
self.base_model_combo.addItem(display_name, model['path'])
# 尝试恢复之前的选择
if current_base:
index = self.base_model_combo.findText(current_base)
if index >= 0:
self.base_model_combo.setCurrentIndex(index)
# 刷新测试模型下拉菜单
if hasattr(self, 'test_model_combo'):
current_test = self.test_model_combo.currentText()
self.test_model_combo.clear()
for model in models:
display_name = model['name'] # 只显示模型名称,不显示文件大小
self.test_model_combo.addItem(display_name, model['path'])
# 尝试恢复之前的选择
if current_test:
index = self.test_model_combo.findText(current_test)
if index >= 0:
self.test_model_combo.setCurrentIndex(index)
print(f"[模型同步] 已刷新模型列表,共 {len(models)} 个模型")
except Exception as e:
print(f"[错误] 刷新模型列表失败: {e}")
import traceback
traceback.print_exc()
def refreshBaseModelList(self):
"""刷新基础模型下拉菜单(兼容旧接口)"""
self.refreshModelLists()
# 测试代码
......@@ -2069,3 +2466,4 @@ class TrainingNotesDialog(QtWidgets.QDialog):
def getNotesContent(self):
"""获取笔记内容"""
return self.text_edit.toPlainText().strip()
......@@ -941,8 +941,8 @@ class BackgroundStyleManager:
"""全局背景颜色管理器"""
# 全局背景颜色配置
GLOBAL_BACKGROUND_COLOR = "white" # 统一白色背景
GLOBAL_BACKGROUND_STYLE = "background-color: white;"
GLOBAL_BACKGROUND_COLOR = "#f8f9fa" # 统一浅灰色背景,消除纯白色
GLOBAL_BACKGROUND_STYLE = "background-color: #f8f9fa;"
@classmethod
def applyToWidget(cls, widget):
......@@ -1208,6 +1208,112 @@ def setMenuBarHoverColor(color):
# ============================================================================
# 曲线显示样式管理器 (Curve Display Style Manager)
# ============================================================================
class CurveDisplayStyleManager:
"""曲线显示样式管理器"""
# 统一的颜色配置
BACKGROUND_COLOR = "#f8f9fa" # 统一背景色
PLOT_BACKGROUND_COLOR = "#f8f9fa" # 图表背景色
CONTAINER_BACKGROUND_COLOR = "#ffffff" # 容器背景色
BORDER_COLOR = "#dee2e6" # 边框颜色
TEXT_COLOR = "#333333" # 文本颜色
# 图表尺寸配置
CHART_WIDTH = 600
CHART_HEIGHT = 350
CHART_DPI = 100
@classmethod
def getBackgroundColor(cls):
"""获取统一背景色"""
return cls.BACKGROUND_COLOR
@classmethod
def getPlotBackgroundColor(cls):
"""获取图表背景色"""
return cls.PLOT_BACKGROUND_COLOR
@classmethod
def getContainerBackgroundColor(cls):
"""获取容器背景色"""
return cls.CONTAINER_BACKGROUND_COLOR
@classmethod
def getBorderColor(cls):
"""获取边框颜色"""
return cls.BORDER_COLOR
@classmethod
def getTextColor(cls):
"""获取文本颜色"""
return cls.TEXT_COLOR
@classmethod
def getChartSize(cls):
"""获取图表尺寸"""
return cls.CHART_WIDTH, cls.CHART_HEIGHT
@classmethod
def getChartDPI(cls):
"""获取图表DPI"""
return cls.CHART_DPI
@classmethod
def generateCurveHTML(cls, curve_image_path, stats_html=""):
"""生成统一样式的曲线HTML"""
return f"""
<div style="font-family: 'Microsoft YaHei', 'SimHei', Arial, sans-serif; padding: 20px; background: {cls.BACKGROUND_COLOR}; color: {cls.TEXT_COLOR};">
<div style="margin-bottom: 20px;">
<h3 style="margin: 0 0 15px 0; color: {cls.TEXT_COLOR}; font-size: 18px; font-weight: 600;">液位检测曲线结果</h3>
</div>
{stats_html}
<div style="text-align: center; margin-bottom: 15px; background: {cls.CONTAINER_BACKGROUND_COLOR}; padding: 8px; border-radius: 6px; border: 1px solid {cls.BORDER_COLOR};">
<img src="file:///{curve_image_path.replace(chr(92), '/')}" style="max-width: 100%; height: auto; border-radius: 4px;">
</div>
<div style="padding: 10px; background: {cls.CONTAINER_BACKGROUND_COLOR}; border: 1px solid {cls.BORDER_COLOR}; border-radius: 5px; font-size: 12px; color: #666;">
<p style="margin: 0;"><strong>说明:</strong> 曲线显示了液位检测的结果变化趋势。X轴表示帧序号,Y轴表示液位高度(毫米)。</p>
</div>
</div>
"""
@classmethod
def setBackgroundColor(cls, color):
"""设置背景颜色"""
cls.BACKGROUND_COLOR = color
cls.PLOT_BACKGROUND_COLOR = color
@classmethod
def setChartSize(cls, width, height):
"""设置图表尺寸"""
cls.CHART_WIDTH = width
cls.CHART_HEIGHT = height
# 曲线显示样式管理器便捷函数
def getCurveBackgroundColor():
"""获取曲线背景色的便捷函数"""
return CurveDisplayStyleManager.getBackgroundColor()
def getCurvePlotBackgroundColor():
"""获取曲线图表背景色的便捷函数"""
return CurveDisplayStyleManager.getPlotBackgroundColor()
def getCurveChartSize():
"""获取曲线图表尺寸的便捷函数"""
return CurveDisplayStyleManager.getChartSize()
def generateCurveHTML(curve_image_path, stats_html=""):
"""生成曲线HTML的便捷函数"""
return CurveDisplayStyleManager.generateCurveHTML(curve_image_path, stats_html)
# ============================================================================
# 测试代码
# ============================================================================
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment