Commit da148b05 by yhb

123456

parent 7cf1afc7
...@@ -29,10 +29,11 @@ print("[环境变量] OpenMP冲突已修复") ...@@ -29,10 +29,11 @@ print("[环境变量] OpenMP冲突已修复")
from qtpy import QtWidgets from qtpy import QtWidgets
from app import MainWindow # 延迟导入,避免在 QApplication 创建前创建 QWidget
# from app import MainWindow
from database.config import get_config, get_temp_models_dir from database.config import get_config, get_temp_models_dir
from widgets.style_manager import FontManager # from widgets.style_manager import FontManager
from widgets.responsive_layout import ResponsiveLayout # from widgets.responsive_layout import ResponsiveLayout
def setup_logging(level: str = "info"): def setup_logging(level: str = "info"):
...@@ -270,6 +271,11 @@ def _main(): ...@@ -270,6 +271,11 @@ def _main():
app.setApplicationName('Detection') app.setApplicationName('Detection')
app.setOrganizationName('Detection') app.setOrganizationName('Detection')
# 在 QApplication 创建后导入可能创建 QWidget 的模块
from app import MainWindow
from widgets.style_manager import FontManager
from widgets.responsive_layout import ResponsiveLayout
# 初始化响应式布局系统 # 初始化响应式布局系统
ResponsiveLayout.initialize(app) ResponsiveLayout.initialize(app)
......
...@@ -686,6 +686,10 @@ class MainWindow( ...@@ -686,6 +686,10 @@ class MainWindow(
from widgets import ChannelPanel, MissionPanel, CurvePanel from widgets import ChannelPanel, MissionPanel, CurvePanel
# 主页面容器 # 主页面容器
# 修复:确保 QApplication 完全初始化后再创建 QWidget
app = QtWidgets.QApplication.instance()
if app is None:
raise RuntimeError("QApplication 未初始化")
page = QtWidgets.QWidget() page = QtWidgets.QWidget()
page_layout = QtWidgets.QVBoxLayout(page) page_layout = QtWidgets.QVBoxLayout(page)
page_layout.setContentsMargins(0, 0, 0, 0) page_layout.setContentsMargins(0, 0, 0, 0)
...@@ -714,6 +718,10 @@ class MainWindow( ...@@ -714,6 +718,10 @@ class MainWindow(
except ImportError: except ImportError:
from widgets import ChannelPanel, MissionPanel from widgets import ChannelPanel, MissionPanel
# 确保 QApplication 存在
app = QtWidgets.QApplication.instance()
if app is None:
raise RuntimeError("QApplication 未初始化")
layout_widget = QtWidgets.QWidget() layout_widget = QtWidgets.QWidget()
main_layout = QtWidgets.QHBoxLayout(layout_widget) main_layout = QtWidgets.QHBoxLayout(layout_widget)
main_layout.setContentsMargins(10, 10, 10, 10) main_layout.setContentsMargins(10, 10, 10, 10)
......
...@@ -283,15 +283,22 @@ class ModelTestHandler: ...@@ -283,15 +283,22 @@ class ModelTestHandler:
def _handleStartTestExecution(self): def _handleStartTestExecution(self):
"""执行开始测试操作 - 液位检测测试功能""" """执行开始测试操作 - 液位检测测试功能"""
try: try:
# 🔥 清空曲线数据,准备新的测试
self._clearCurve()
# 禁用查看曲线按钮
if hasattr(self.training_panel, 'disableViewCurveButton'):
self.training_panel.disableViewCurveButton()
# 切换按钮状态为“停止测试” # 切换按钮状态为“停止测试”
self.training_panel.setTestButtonState(True) self.training_panel.setTestButtonState(True)
# 获取选择的测试模型和测试文件 # 获取选择的测试模型和测试文件
test_model_display = self.training_panel.test_model_combo.currentText() test_model_display = self.training_panel.test_model_combo.currentText()
test_model_path_raw = self.training_panel.test_model_combo.currentData() test_model_path_raw = self.training_panel.test_model_combo.currentData()
# 改为从QComboBox获取数据 # 从QLineEdit获取测试文件路径(浏览选择的文件)
test_file_path_raw = self.training_panel.test_file_input.currentData() or "" test_file_path_raw = self.training_panel.test_file_input.text().strip()
test_file_display = self.training_panel.test_file_input.currentText() test_file_display = os.path.basename(test_file_path_raw) if test_file_path_raw else ""
# 关键修复:路径规范化处理,确保相对路径转换为绝对路径 # 关键修复:路径规范化处理,确保相对路径转换为绝对路径
project_root = get_project_root() project_root = get_project_root()
...@@ -342,7 +349,7 @@ class ModelTestHandler: ...@@ -342,7 +349,7 @@ class ModelTestHandler:
<p style="margin: 0; font-size: 12px; color: #ffffff;"><strong>解决方法:</strong></p> <p style="margin: 0; font-size: 12px; color: #ffffff;"><strong>解决方法:</strong></p>
<ul style="margin: 5px 0; padding-left: 20px; font-size: 12px; color: #ffffff;"> <ul style="margin: 5px 0; padding-left: 20px; font-size: 12px; color: #ffffff;">
<li>请在上方下拉框中选择测试模型</li> <li>请在上方下拉框中选择测试模型</li>
<li>请在上方下拉框中选择测试文件</li> <li>请点击"浏览..."按钮选择测试图片或视频文件</li>
<li>确保选择的文件存在且可访问</li> <li>确保选择的文件存在且可访问</li>
</ul> </ul>
</div> </div>
...@@ -353,7 +360,7 @@ class ModelTestHandler: ...@@ -353,7 +360,7 @@ class ModelTestHandler:
QtWidgets.QMessageBox.warning( QtWidgets.QMessageBox.warning(
self.training_panel, self.training_panel,
"参数缺失", "参数缺失",
error_msg + "请在上方下拉框中选择测试模型和测试文件" error_msg + '请在上方下拉框中选择测试模型,并点击"浏览..."按钮选择测试文件'
) )
return return
...@@ -542,10 +549,23 @@ class ModelTestHandler: ...@@ -542,10 +549,23 @@ class ModelTestHandler:
# 显示检测结果 # 显示检测结果
self._showDetectionResult(detection_result) self._showDetectionResult(detection_result)
# 🔥 添加曲线数据点
if 'liquid_level_mm' in detection_result:
liquid_level = detection_result['liquid_level_mm']
# 图片测试只有一个数据点,帧序号为0
self._addCurveDataPoint(0, liquid_level)
# 显示曲线面板
if hasattr(self.training_panel, 'showCurvePanel'):
self.training_panel.showCurvePanel()
# 启用查看曲线按钮
if hasattr(self.training_panel, 'enableViewCurveButton'):
self.training_panel.enableViewCurveButton()
QtWidgets.QMessageBox.information( QtWidgets.QMessageBox.information(
self.training_panel, self.training_panel,
"测试完成", "测试完成",
"模型测试已成功完成!" "模型测试已成功完成!可查看曲线分析结果。"
) )
# 恢复按钮状态 # 恢复按钮状态
...@@ -585,6 +605,29 @@ class ModelTestHandler: ...@@ -585,6 +605,29 @@ class ModelTestHandler:
self._test_thread.wait() self._test_thread.wait()
self._test_thread = None self._test_thread = None
def _addCurveDataPoint(self, frame_index, height_mm):
"""添加曲线数据点
Args:
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
try:
if hasattr(self.training_panel, 'addCurvePoint'):
self.training_panel.addCurvePoint(frame_index, height_mm)
print(f"[曲线] 添加数据点: 帧{frame_index}, 液位{height_mm:.1f}mm")
except Exception as e:
print(f"[曲线] 添加数据点失败: {e}")
def _clearCurve(self):
"""清空曲线数据"""
try:
if hasattr(self.training_panel, '_clearCurve'):
self.training_panel._clearCurve()
print(f"[曲线] 已清空曲线")
except Exception as e:
print(f"[曲线] 清空曲线失败: {e}")
def _showDetectionResult(self, detection_result): def _showDetectionResult(self, detection_result):
"""显示检测结果""" """显示检测结果"""
try: try:
...@@ -592,12 +635,16 @@ class ModelTestHandler: ...@@ -592,12 +635,16 @@ class ModelTestHandler:
# 这里可以添加结果显示逻辑 # 这里可以添加结果显示逻辑
# 例如在display_panel中显示检测结果 # 例如在display_panel中显示检测结果
if hasattr(self.training_panel, 'display_panel') and detection_result: if hasattr(self.training_panel, 'display_panel') and detection_result:
# 提取液位高度
liquid_level = detection_result.get('liquid_level_mm', 0)
result_html = f""" result_html = f"""
<div style="padding: 15px; background: #000000; border: 1px solid #28a745; border-radius: 5px; color: #ffffff;"> <div style="padding: 15px; background: #000000; border: 1px solid #28a745; border-radius: 5px; color: #ffffff;">
<h3 style="margin-top: 0; color: #28a745;">液位检测测试成功</h3> <h3 style="margin-top: 0; color: #28a745;">液位检测测试成功</h3>
<p style="color: #ffffff;"><strong>检测结果:</strong> 已完成液位检测</p> <p style="color: #ffffff;"><strong>检测结果:</strong> 已完成液位检测</p>
<p style="color: #ffffff;"><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
<div style="margin-top: 15px; padding: 10px; background: #1a1a1a; border-radius: 3px;"> <div style="margin-top: 15px; padding: 10px; background: #1a1a1a; border-radius: 3px;">
<p style="margin: 0; font-size: 12px; color: #ffffff;">检测结果已完成,可以在结果面板中查看详细信息。</p> <p style="margin: 0; font-size: 12px; color: #ffffff;">检测结果已完成,可以在曲线面板中查看详细分析。</p>
</div> </div>
</div> </div>
""" """
...@@ -1197,6 +1244,9 @@ class ModelTestHandler: ...@@ -1197,6 +1244,9 @@ class ModelTestHandler:
success_count = 0 success_count = 0
fail_count = 0 fail_count = 0
# 🔥 清空曲线数据,准备添加新的视频检测曲线
self._clearCurve()
# 关闭进度对话框 # 关闭进度对话框
if progress_dialog: if progress_dialog:
progress_dialog.setLabelText("正在检测中...") progress_dialog.setLabelText("正在检测中...")
...@@ -1245,6 +1295,13 @@ class ModelTestHandler: ...@@ -1245,6 +1295,13 @@ class ModelTestHandler:
last_detection_result = detection_result last_detection_result = detection_result
detection_count += 1 detection_count += 1
success_count += 1 success_count += 1
# 🔥 添加曲线数据点(取第一个区域的液位高度)
if detection_result and len(detection_result) > 0:
first_area_result = detection_result[0]
if 'liquid_level_mm' in first_area_result:
liquid_level = first_area_result['liquid_level_mm']
self._addCurveDataPoint(frame_index, liquid_level)
except Exception as e: except Exception as e:
print(f"[视频检测] 第 {frame_index} 帧检测失败: {e}") print(f"[视频检测] 第 {frame_index} 帧检测失败: {e}")
fail_count += 1 fail_count += 1
...@@ -1317,6 +1374,15 @@ class ModelTestHandler: ...@@ -1317,6 +1374,15 @@ class ModelTestHandler:
if not self._detection_stopped: if not self._detection_stopped:
print(f"[视频检测] 显示检测结果视频...") print(f"[视频检测] 显示检测结果视频...")
self._showDetectionVideo(output_video_path, frame_index, detection_count, success_count, fail_count) self._showDetectionVideo(output_video_path, frame_index, detection_count, success_count, fail_count)
# 🔥 显示曲线面板
if hasattr(self.training_panel, 'showCurvePanel'):
self.training_panel.showCurvePanel()
print(f"[曲线] 曲线面板已显示,共{len(self.training_panel.curve_data_x if hasattr(self.training_panel, 'curve_data_x') else [])}个数据点")
# 启用查看曲线按钮
if hasattr(self.training_panel, 'enableViewCurveButton'):
self.training_panel.enableViewCurveButton()
else: else:
print(f"[视频检测] 检测被用户停止") print(f"[视频检测] 检测被用户停止")
...@@ -1844,6 +1910,9 @@ class ModelTestHandler: ...@@ -1844,6 +1910,9 @@ class ModelTestHandler:
"生成文件:", "生成文件:",
f" 结果视频: {result_video_filename}", f" 结果视频: {result_video_filename}",
f" 测试报告: {report_filename}", f" 测试报告: {report_filename}",
f" JSON结果: {json_filename}",
f" 曲线数据: {file_prefix}_curve.csv",
f" 曲线图片: {file_prefix}_curve.png",
"", "",
"=" * 60, "=" * 60,
] ]
...@@ -1876,7 +1945,9 @@ class ModelTestHandler: ...@@ -1876,7 +1945,9 @@ class ModelTestHandler:
"files": { "files": {
"result_video": result_video_filename, "result_video": result_video_filename,
"report": report_filename, "report": report_filename,
"json_result": json_filename "json_result": json_filename,
"curve_data_csv": f"{file_prefix}_curve.csv",
"curve_image_png": f"{file_prefix}_curve.png"
} }
} }
...@@ -1885,6 +1956,23 @@ class ModelTestHandler: ...@@ -1885,6 +1956,23 @@ class ModelTestHandler:
print(f"[保存视频结果] JSON结果已保存: {json_path}") print(f"[保存视频结果] JSON结果已保存: {json_path}")
# 4. 保存曲线数据(CSV格式)和曲线图片
try:
if hasattr(self.training_panel, 'saveCurveData') and hasattr(self.training_panel, 'saveCurveImage'):
# 保存曲线CSV数据
curve_csv_filename = f"{file_prefix}_curve.csv"
curve_csv_path = os.path.join(test_results_dir, curve_csv_filename)
if self.training_panel.saveCurveData(curve_csv_path):
print(f"[保存视频结果] 曲线CSV数据已保存: {curve_csv_path}")
# 保存曲线图片
curve_image_filename = f"{file_prefix}_curve.png"
curve_image_path = os.path.join(test_results_dir, curve_image_filename)
if self.training_panel.saveCurveImage(curve_image_path):
print(f"[保存视频结果] 曲线图片已保存: {curve_image_path}")
except Exception as curve_error:
print(f"[保存视频结果] ⚠️ 曲线保存失败(非致命错误): {curve_error}")
print(f"[保存视频结果] ✅ 所有测试结果已成功保存到: {test_results_dir}") print(f"[保存视频结果] ✅ 所有测试结果已成功保存到: {test_results_dir}")
except Exception as e: except Exception as e:
...@@ -1991,6 +2079,9 @@ class ModelTestHandler: ...@@ -1991,6 +2079,9 @@ class ModelTestHandler:
f" 原始图像: {original_filename}", f" 原始图像: {original_filename}",
f" 检测结果: {result_filename}", f" 检测结果: {result_filename}",
f" 测试报告: {report_filename}", f" 测试报告: {report_filename}",
f" JSON结果: {json_filename}",
f" 曲线数据: {file_prefix}_curve.csv",
f" 曲线图片: {file_prefix}_curve.png",
"", "",
"=" * 60, "=" * 60,
] ]
...@@ -2026,7 +2117,9 @@ class ModelTestHandler: ...@@ -2026,7 +2117,9 @@ class ModelTestHandler:
"original_image": original_filename, "original_image": original_filename,
"result_image": result_filename, "result_image": result_filename,
"report": report_filename, "report": report_filename,
"json_result": json_filename "json_result": json_filename,
"curve_data_csv": f"{file_prefix}_curve.csv",
"curve_image_png": f"{file_prefix}_curve.png"
} }
} }
...@@ -2035,6 +2128,23 @@ class ModelTestHandler: ...@@ -2035,6 +2128,23 @@ class ModelTestHandler:
print(f"[保存图片结果] JSON结果已保存: {json_path}") print(f"[保存图片结果] JSON结果已保存: {json_path}")
# 5. 保存曲线数据(CSV格式)和曲线图片
try:
if hasattr(self.training_panel, 'saveCurveData') and hasattr(self.training_panel, 'saveCurveImage'):
# 保存曲线CSV数据
curve_csv_filename = f"{file_prefix}_curve.csv"
curve_csv_path = os.path.join(test_results_dir, curve_csv_filename)
if self.training_panel.saveCurveData(curve_csv_path):
print(f"[保存图片结果] 曲线CSV数据已保存: {curve_csv_path}")
# 保存曲线图片
curve_image_filename = f"{file_prefix}_curve.png"
curve_image_path = os.path.join(test_results_dir, curve_image_filename)
if self.training_panel.saveCurveImage(curve_image_path):
print(f"[保存图片结果] 曲线图片已保存: {curve_image_path}")
except Exception as curve_error:
print(f"[保存图片结果] ⚠️ 曲线保存失败(非致命错误): {curve_error}")
print(f"[保存图片结果] ✅ 所有测试结果已成功保存到: {test_results_dir}") print(f"[保存图片结果] ✅ 所有测试结果已成功保存到: {test_results_dir}")
except Exception as e: except Exception as e:
......
...@@ -165,7 +165,7 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -165,7 +165,7 @@ class ModelTrainingHandler(ModelTestHandler):
return False return False
if not training_params.get('save_liquid_data_path'): if not training_params.get('save_liquid_data_path'):
QtWidgets.QMessageBox.critical(self.main_window, "参数错误", "未找到可用的数据集配置文件") QtWidgets.QMessageBox.critical(self.main_window, "参数错误", "未选择数据集文件夹,请至少添加一个数据集文件夹")
return False return False
if not training_params.get('exp_name'): if not training_params.get('exp_name'):
...@@ -208,25 +208,44 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -208,25 +208,44 @@ class ModelTrainingHandler(ModelTestHandler):
f"基础模型文件不存在\n文件路径: {base_model}\n请检查文件路径是否正确,或重新选择模型文件。") f"基础模型文件不存在\n文件路径: {base_model}\n请检查文件路径是否正确,或重新选择模型文件。")
return False return False
if not os.path.exists(save_liquid_data_path): # 解析数据集文件夹列表(用分号分隔)
QtWidgets.QMessageBox.critical(self.main_window, "文件错误", dataset_folders = [f.strip() for f in save_liquid_data_path.split(';') if f.strip()]
f"数据集配置文件不存在\n文件路径: {save_liquid_data_path}\n请检查文件路径是否正确,或重新选择数据集文件。")
if not dataset_folders:
QtWidgets.QMessageBox.critical(self.main_window, "参数错误", "未选择数据集文件夹,请至少添加一个数据集文件夹")
return False
# 验证每个数据集文件夹是否存在
invalid_folders = []
for folder in dataset_folders:
if not os.path.exists(folder):
invalid_folders.append(folder)
if invalid_folders:
QtWidgets.QMessageBox.critical(
self.main_window,
"文件夹错误",
f"以下数据集文件夹不存在:\n\n" + "\n".join(invalid_folders) +
"\n\n请检查文件夹路径是否正确,或重新选择数据集文件夹。"
)
return False return False
# 验证数据集配置和内容 # 验证数据集文件夹内容
validation_result, validation_msg = self._validateTrainingDataWithDetails(save_liquid_data_path) validation_result, validation_msg = self._validateDatasetFolders(dataset_folders)
if not validation_result: if not validation_result:
QtWidgets.QMessageBox.critical( QtWidgets.QMessageBox.critical(
self.main_window, self.main_window,
"数据集验证失败", "数据集验证失败",
f"数据集验证失败:\n\n{validation_msg}\n\n请检查数据集配置和文件。" f"数据集验证失败:\n\n{validation_msg}\n\n请检查数据集文件夹内容。"
) )
return False return False
# 确认对话框 # 确认对话框
confirm_msg = f"确定要开始升级模型吗?\n\n" confirm_msg = f"确定要开始升级模型吗?\n\n"
confirm_msg += f"基础模型: {os.path.basename(base_model)}\n" confirm_msg += f"基础模型: {os.path.basename(base_model)}\n"
confirm_msg += f"数据集: {os.path.basename(save_liquid_data_path)}\n" confirm_msg += f"数据集文件夹数量: {len(dataset_folders)}\n"
for i, folder in enumerate(dataset_folders, 1):
confirm_msg += f" {i}. {os.path.basename(folder)}\n"
confirm_msg += f"图像尺寸: {training_params['imgsz']}\n" confirm_msg += f"图像尺寸: {training_params['imgsz']}\n"
confirm_msg += f"训练轮数: {training_params['epochs']}\n" confirm_msg += f"训练轮数: {training_params['epochs']}\n"
confirm_msg += f"批次大小: {training_params['batch']}\n" confirm_msg += f"批次大小: {training_params['batch']}\n"
...@@ -271,12 +290,22 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -271,12 +290,22 @@ class ModelTrainingHandler(ModelTestHandler):
def _startTrainingWorker(self, training_params): def _startTrainingWorker(self, training_params):
"""启动训练工作线程""" """启动训练工作线程"""
try: try:
# 检查是否已有训练在进行中
if self.training_active and self.training_worker:
QtWidgets.QMessageBox.warning(
self.main_window,
"提示",
"训练正在进行中,请先停止当前训练"
)
return False
# 禁止自动下载yolo11模型 # 禁止自动下载yolo11模型
os.environ['YOLO_AUTODOWNLOAD'] = '0' os.environ['YOLO_AUTODOWNLOAD'] = '0'
os.environ['YOLO_OFFLINE'] = '1' os.environ['YOLO_OFFLINE'] = '1'
# 重置用户停止标记 # 重置用户停止标记
self._is_user_stopped = False self._is_user_stopped = False
self._is_stopping = False # 标记训练是否正在停止中
# 如果面板处于"继续训练"模式,切换回"停止升级"模式 # 如果面板处于"继续训练"模式,切换回"停止升级"模式
if hasattr(self, 'training_panel'): if hasattr(self, 'training_panel'):
...@@ -288,6 +317,31 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -288,6 +317,31 @@ class ModelTrainingHandler(ModelTestHandler):
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_log_text'): if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_log_text'):
self.training_panel.train_log_text.clear() self.training_panel.train_log_text.clear()
# 处理多数据集文件夹合并
dataset_folders = [f.strip() for f in training_params['save_liquid_data_path'].split(';') if f.strip()]
if len(dataset_folders) > 1:
self._appendLog("检测到多个数据集文件夹,正在合并...\n")
merged_data_yaml = self._mergeMultipleDatasets(dataset_folders, training_params['exp_name'])
if merged_data_yaml:
training_params['save_liquid_data_path'] = merged_data_yaml
self._appendLog(f"数据集合并完成: {merged_data_yaml}\n")
else:
QtWidgets.QMessageBox.critical(self.main_window, "错误", "数据集合并失败")
return False
elif len(dataset_folders) == 1:
# 单个数据集文件夹,需要创建data.yaml文件
single_folder = dataset_folders[0]
data_yaml_path = self._createDataYamlForSingleFolder(single_folder, training_params['exp_name'])
if data_yaml_path:
training_params['save_liquid_data_path'] = data_yaml_path
self._appendLog(f"已为单个数据集创建配置文件: {data_yaml_path}\n")
else:
QtWidgets.QMessageBox.critical(self.main_window, "错误", "创建数据集配置文件失败")
return False
# 禁用笔记保存和提交按钮(训练开始时)
self._disableNotesButtons()
# 更新UI状态 # 更新UI状态
if hasattr(self, 'training_panel'): if hasattr(self, 'training_panel'):
if hasattr(self.training_panel, 'train_status_label'): if hasattr(self.training_panel, 'train_status_label'):
...@@ -345,11 +399,54 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -345,11 +399,54 @@ class ModelTrainingHandler(ModelTestHandler):
return False return False
def _onStopTraining(self): def _onStopTraining(self):
"""停止训练 - 优雅停止,完成当前epoch后停止""" """停止训练 - 根据训练状态采用不同策略"""
# 检查是否已经在停止过程中
if getattr(self, '_is_stopping', False):
self._appendLog("\n[提示] 训练正在停止中,请耐心等待...\n")
return True
if self.training_worker and self.training_active: if self.training_worker and self.training_active:
# 检查训练是否已经真正开始
training_started = self.training_worker.has_training_started()
# 设置停止标记,防止重复触发
self._is_stopping = True
self._is_user_stopped = True # 标记为用户手动停止 self._is_user_stopped = True # 标记为用户手动停止
self.training_worker.stop_training() # 设置 is_running = False,YOLO会在epoch结束时检查 self.training_worker.stop_training() # 设置 is_running = False
if not training_started:
# 训练还未真正开始(仍在初始化阶段),直接取消训练
self._appendLog("\n" + "="*70 + "\n")
self._appendLog("训练尚未开始,正在取消训练...\n")
self._appendLog("="*70 + "\n")
# 更新状态标签
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_status_label'):
self.training_panel.train_status_label.setText("正在取消训练...")
self.training_panel.train_status_label.setStyleSheet("""
QLabel {
color: #ffffff;
background-color: #dc3545;
border: 1px solid #dc3545;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
""")
FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD)
# 强制终止训练线程(因为训练还未开始,可以安全终止)
if self.training_worker:
self.training_worker.terminate()
self.training_worker.wait(3000) # 等待最多3秒
if self.training_worker.isRunning():
self.training_worker.kill() # 强制杀死线程
# 直接调用训练完成回调,恢复UI状态
self._onTrainingFinished(False)
else:
# 训练已经开始,优雅停止(完成当前epoch后停止)
self._appendLog("\n" + "="*70 + "\n") self._appendLog("\n" + "="*70 + "\n")
self._appendLog("用户请求停止训练\n") self._appendLog("用户请求停止训练\n")
self._appendLog("正在完成当前训练轮次...\n") self._appendLog("正在完成当前训练轮次...\n")
...@@ -371,9 +468,12 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -371,9 +468,12 @@ class ModelTrainingHandler(ModelTestHandler):
""") """)
FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD) FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD)
# 禁用停止按钮,防止重复点击 # 禁用所有训练相关按钮,防止重复点击和冲突
if hasattr(self, 'training_panel'): if hasattr(self, 'training_panel'):
if hasattr(self.training_panel, 'stop_train_btn'):
self.training_panel.stop_train_btn.setEnabled(False) self.training_panel.stop_train_btn.setEnabled(False)
if hasattr(self.training_panel, 'start_train_btn'):
self.training_panel.start_train_btn.setEnabled(False)
# 不立刻终止线程,让YOLO在epoch结束时自动停止 # 不立刻终止线程,让YOLO在epoch结束时自动停止
# 线程会在 _onTrainingFinished 中被清理 # 线程会在 _onTrainingFinished 中被清理
...@@ -386,7 +486,8 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -386,7 +486,8 @@ class ModelTrainingHandler(ModelTestHandler):
def _onTrainingFinished(self, success): def _onTrainingFinished(self, success):
"""训练完成回调""" """训练完成回调"""
try: try:
# 重置停止标记
self._is_stopping = False
self.training_active = False self.training_active = False
if success: if success:
...@@ -478,21 +579,30 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -478,21 +579,30 @@ class ModelTrainingHandler(ModelTestHandler):
self._appendLog(" 训练完成通知\n") self._appendLog(" 训练完成通知\n")
self._appendLog("="*70 + "\n") self._appendLog("="*70 + "\n")
self._appendLog("模型升级已完成!\n") self._appendLog("模型升级已完成!\n")
self._appendLog("新模型已保存到detection_model目录\n")
self._appendLog("新模型已自动添加到模型集管理\n") self._appendLog("新模型已自动添加到模型集管理\n")
self._appendLog("请切换到【模型集管理】页面查看新模型\n") self._appendLog("请切换到【模型集管理】页面查看新模型\n")
self._appendLog("="*70 + "\n") self._appendLog("="*70 + "\n")
# 启用笔记保存和提交按钮(训练完成后允许继续编辑笔记)
self._enableNotesButtons()
# 使用定时器延迟显示消息框,避免阻塞训练线程的清理 # 使用定时器延迟显示消息框,避免阻塞训练线程的清理
QtCore.QTimer.singleShot(500, lambda: QtWidgets.QMessageBox.information( QtCore.QTimer.singleShot(500, lambda: QtWidgets.QMessageBox.information(
self.main_window, self.main_window,
"升级完成", "升级完成",
"模型升级已完成!\n新模型已自动添加到模型集管理。" "模型升级已完成!\n新模型已保存到detection_model目录\n自动添加到模型集管理。"
)) ))
else: else:
# 检查是否为用户手动停止 # 检查是否为用户手动停止
is_user_stopped = getattr(self, '_is_user_stopped', False) is_user_stopped = getattr(self, '_is_user_stopped', False)
if is_user_stopped: # 检查训练是否已经真正开始
training_started = False
if self.training_worker:
training_started = self.training_worker.has_training_started()
if is_user_stopped and training_started:
self._appendLog("\n" + "="*70 + "\n") self._appendLog("\n" + "="*70 + "\n")
self._appendLog("训练已暂停\n") self._appendLog("训练已暂停\n")
self._appendLog("="*70 + "\n") self._appendLog("="*70 + "\n")
...@@ -578,6 +688,39 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -578,6 +688,39 @@ class ModelTrainingHandler(ModelTestHandler):
# 重置标记 # 重置标记
self._is_user_stopped = False self._is_user_stopped = False
self._is_stopping = False
elif is_user_stopped and not training_started:
# 训练被取消(未真正开始),恢复初始状态
self._appendLog("\n" + "="*70 + "\n")
self._appendLog("训练已取消\n")
self._appendLog("="*70 + "\n")
# 更新状态标签
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_status_label'):
self.training_panel.train_status_label.setText("训练已取消")
self.training_panel.train_status_label.setStyleSheet("""
QLabel {
color: #ffffff;
background-color: #6c757d;
border: 1px solid #6c757d;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
""")
FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD)
# 恢复按钮状态(允许重新开始训练)
if hasattr(self, 'training_panel'):
if hasattr(self.training_panel, 'start_train_btn'):
self.training_panel.start_train_btn.setEnabled(True)
if hasattr(self.training_panel, 'stop_train_btn'):
self.training_panel.stop_train_btn.setEnabled(False)
self.training_panel.stop_train_btn.setText("停止升级") # 恢复原始文本
# 重置标记
self._is_user_stopped = False
self._is_stopping = False
else: else:
self._appendLog("\n" + "="*70 + "\n") self._appendLog("\n" + "="*70 + "\n")
self._appendLog(" 升级失败\n") self._appendLog(" 升级失败\n")
...@@ -610,6 +753,10 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -610,6 +753,10 @@ class ModelTrainingHandler(ModelTestHandler):
# 重置训练停止标记 # 重置训练停止标记
self.training_panel._is_training_stopped = False self.training_panel._is_training_stopped = False
# 重置停止标记
self._is_user_stopped = False
self._is_stopping = False
# 如果是正常完成(非用户停止),恢复按钮状态 # 如果是正常完成(非用户停止),恢复按钮状态
if success and not self._is_user_stopped: if success and not self._is_user_stopped:
if hasattr(self, 'training_panel'): if hasattr(self, 'training_panel'):
...@@ -891,41 +1038,19 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -891,41 +1038,19 @@ class ModelTrainingHandler(ModelTestHandler):
def _initializeTrainingPanelDefaults(self, training_panel): def _initializeTrainingPanelDefaults(self, training_panel):
"""初始化训练面板的默认值""" """初始化训练面板的默认值"""
try: try:
# 设置默认模型路径 # 基础模型现在通过下拉菜单从模型集管理页面加载,不需要手动设置
if hasattr(training_panel, 'base_model_edit'): # 下拉菜单会在页面显示时自动加载并选择默认模型
project_root = get_project_root()
available_models = [
os.path.join(project_root, 'database', 'model', 'train_model', '2', 'best.pt'),
os.path.join(project_root, 'database', 'model', 'train_model')
]
for model_path in available_models:
if os.path.exists(model_path):
if os.path.isfile(model_path):
training_panel.base_model_edit.setText(model_path)
break
elif os.path.isdir(model_path):
# 查找目录中的第一个模型文件
for root, dirs, files in os.walk(model_path):
for file in files:
if file.endswith('.pt') or file.endswith('.dat'):
full_path = os.path.join(root, file)
training_panel.base_model_edit.setText(full_path)
break
if training_panel.base_model_edit.text():
break
if training_panel.base_model_edit.text():
break
# 设置默认数据集路径 # 数据集现在通过文件夹列表管理,可以手动添加默认数据集文件夹
if hasattr(training_panel, 'save_liquid_data_path_edit'): if hasattr(training_panel, 'dataset_folders_list'):
project_root = get_project_root() project_root = get_project_root()
available_datasets = [ default_dataset_folder = os.path.join(project_root, 'database', 'dataset')
os.path.join(project_root, 'database', 'dataset', 'data.yaml'), if os.path.exists(default_dataset_folder) and os.path.isdir(default_dataset_folder):
] # 添加默认数据集文件夹(如果列表为空)
for dataset_path in available_datasets: if training_panel.dataset_folders_list.count() == 0:
if os.path.exists(dataset_path): training_panel.dataset_folders_list.addItem(default_dataset_folder)
training_panel.save_liquid_data_path_edit.setText(dataset_path) if hasattr(training_panel, '_updateDatasetPath'):
break training_panel._updateDatasetPath()
# 设置默认模型名称 # 设置默认模型名称
if hasattr(training_panel, 'exp_name_edit'): if hasattr(training_panel, 'exp_name_edit'):
...@@ -1174,7 +1299,7 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -1174,7 +1299,7 @@ class ModelTrainingHandler(ModelTestHandler):
device_value = device_text device_value = device_text
training_params = { training_params = {
'base_model': getattr(panel, 'base_model_edit', None) and panel.base_model_edit.text() or '', 'base_model': getattr(panel, 'base_model_combo', None) and panel.base_model_combo.currentData() or '',
'save_liquid_data_path': getattr(panel, 'save_liquid_data_path_edit', None) and panel.save_liquid_data_path_edit.text() or '', 'save_liquid_data_path': getattr(panel, 'save_liquid_data_path_edit', None) and panel.save_liquid_data_path_edit.text() or '',
'imgsz': getattr(panel, 'imgsz_spin', None) and panel.imgsz_spin.value() or 640, 'imgsz': getattr(panel, 'imgsz_spin', None) and panel.imgsz_spin.value() or 640,
'epochs': getattr(panel, 'epochs_spin', None) and panel.epochs_spin.value() or 100, 'epochs': getattr(panel, 'epochs_spin', None) and panel.epochs_spin.value() or 100,
...@@ -1241,34 +1366,39 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -1241,34 +1366,39 @@ class ModelTrainingHandler(ModelTestHandler):
# 路径有效,更新为绝对路径 # 路径有效,更新为绝对路径
training_params['base_model'] = base_model training_params['base_model'] = base_model
# 修正数据集路径 # 修正数据集路径(支持多文件夹,用分号分隔)
save_liquid_data_path = training_params.get('save_liquid_data_path', '') save_liquid_data_path = training_params.get('save_liquid_data_path', '')
if save_liquid_data_path:
# 解析多个文件夹路径
folders = [f.strip() for f in save_liquid_data_path.split(';') if f.strip()]
fixed_folders = []
for folder in folders:
# 如果是相对路径,转换为绝对路径 # 如果是相对路径,转换为绝对路径
if save_liquid_data_path and not os.path.isabs(save_liquid_data_path): if not os.path.isabs(folder):
project_root = get_project_root() project_root = get_project_root()
save_liquid_data_path = os.path.join(project_root, save_liquid_data_path) folder = os.path.join(project_root, folder)
if not save_liquid_data_path or not os.path.exists(save_liquid_data_path): # 只保留存在的文件夹
# 尝试使用配置文件中的默认路径 if os.path.exists(folder) and os.path.isdir(folder):
if self.train_config and 'default_parameters' in self.train_config: fixed_folders.append(folder)
default_data = self.train_config['default_parameters'].get('dataset_path', '')
if default_data and os.path.exists(default_data): # 更新为绝对路径列表
training_params['save_liquid_data_path'] = default_data if fixed_folders:
training_params['save_liquid_data_path'] = ';'.join(fixed_folders)
else: else:
# 使用项目中可用的数据集 # 如果没有有效文件夹,尝试使用默认数据集文件夹
project_root = get_project_root() project_root = get_project_root()
available_datasets = [ default_folder = os.path.join(project_root, 'database', 'dataset')
os.path.join(project_root, 'database', 'dataset', 'data_template_1.yaml'), if os.path.exists(default_folder) and os.path.isdir(default_folder):
os.path.join(project_root, 'database', 'dataset', 'data.yaml'), training_params['save_liquid_data_path'] = default_folder
]
for dataset_path in available_datasets:
if os.path.exists(dataset_path):
training_params['save_liquid_data_path'] = dataset_path
break
else: else:
# 路径有效,更新为绝对路径 # 没有指定数据集,使用默认数据集文件夹
training_params['save_liquid_data_path'] = save_liquid_data_path project_root = get_project_root()
default_folder = os.path.join(project_root, 'database', 'dataset')
if os.path.exists(default_folder) and os.path.isdir(default_folder):
training_params['save_liquid_data_path'] = default_folder
return training_params return training_params
...@@ -1360,6 +1490,78 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -1360,6 +1490,78 @@ class ModelTrainingHandler(ModelTestHandler):
except Exception as e: except Exception as e:
return False, f"验证过程出错: {str(e)}" return False, f"验证过程出错: {str(e)}"
def _validateDatasetFolders(self, dataset_folders):
"""
验证多个数据集文件夹
Args:
dataset_folders: 数据集文件夹路径列表
Returns:
tuple: (是否有效, 错误消息)
"""
try:
if not dataset_folders:
return False, "数据集文件夹列表为空"
# 检查每个文件夹
total_train_images = 0
total_val_images = 0
total_labels = 0
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']
for folder in dataset_folders:
if not os.path.exists(folder):
return False, f"文件夹不存在: {folder}"
if not os.path.isdir(folder):
return False, f"路径不是文件夹: {folder}"
# 检查文件夹结构(YOLO格式)
# 期望结构: folder/images/train, folder/images/val, folder/labels/train, folder/labels/val
images_train_dir = os.path.join(folder, 'images', 'train')
images_val_dir = os.path.join(folder, 'images', 'val')
labels_train_dir = os.path.join(folder, 'labels', 'train')
labels_val_dir = os.path.join(folder, 'labels', 'val')
# 检查是否存在训练图片目录
if os.path.exists(images_train_dir):
train_count = sum(1 for f in os.listdir(images_train_dir)
if any(f.lower().endswith(ext) for ext in image_extensions))
total_train_images += train_count
# 检查是否存在验证图片目录
if os.path.exists(images_val_dir):
val_count = sum(1 for f in os.listdir(images_val_dir)
if any(f.lower().endswith(ext) for ext in image_extensions))
total_val_images += val_count
# 检查标签文件
if os.path.exists(labels_train_dir):
label_count = sum(1 for f in os.listdir(labels_train_dir)
if f.lower().endswith('.txt'))
total_labels += label_count
# 验证是否有足够的数据
if total_train_images == 0 and total_val_images == 0:
return False, f"所有数据集文件夹中都没有找到图片文件\n请确保文件夹包含 images/train 或 images/val 子目录"
if total_train_images == 0:
return False, f"未找到训练图片\n请确保至少一个文件夹包含 images/train 子目录及图片文件"
# 返回验证结果
msg = f"数据集验证通过\n"
msg += f"总训练图片: {total_train_images} 张\n"
if total_val_images > 0:
msg += f"总验证图片: {total_val_images} 张\n"
if total_labels > 0:
msg += f"总标注文件: {total_labels} 个"
return True, msg
except Exception as e:
return False, f"验证过程出错: {str(e)}"
def _validateTrainingData(self, save_liquid_data_path): def _validateTrainingData(self, save_liquid_data_path):
"""验证训练数据(简化版,用于向后兼容)""" """验证训练数据(简化版,用于向后兼容)"""
result, _ = self._validateTrainingDataWithDetails(save_liquid_data_path) result, _ = self._validateTrainingDataWithDetails(save_liquid_data_path)
...@@ -1554,15 +1756,15 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -1554,15 +1756,15 @@ class ModelTrainingHandler(ModelTestHandler):
) )
return return
# 改为从QComboBox获取数据 # 从QLineEdit获取测试文件路径(浏览选择的文件)
test_file_path = self.training_panel.test_file_input.currentData() or "" test_file_path = self.training_panel.test_file_input.text().strip()
test_file_display = self.training_panel.test_file_input.currentText() test_file_display = os.path.basename(test_file_path) if test_file_path else ""
if not test_file_path: if not test_file_path:
QtWidgets.QMessageBox.warning( QtWidgets.QMessageBox.warning(
self.training_panel, self.training_panel,
"提示", "提示",
"请先选择测试文件" '请点击"浏览..."按钮选择测试文件'
) )
return return
...@@ -2183,11 +2385,21 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -2183,11 +2385,21 @@ class ModelTrainingHandler(ModelTestHandler):
'training_date': self._getCurrentTimestamp() 'training_date': self._getCurrentTimestamp()
} }
# 获取训练笔记
training_notes = self._getTrainingNotes()
# 将模型移动到detection_model目录
detection_model_path = self._moveModelToDetectionDir(best_model_path, model_name, weights_dir, training_notes)
if detection_model_path:
# 更新模型参数中的路径
model_params['path'] = detection_model_path
self._appendLog(f"模型已移动到detection_model目录: {os.path.basename(detection_model_path)}\n")
# 保存到模型集配置 # 保存到模型集配置
self._saveModelToConfig(model_name, model_params) self._saveModelToConfig(model_name, model_params)
self._appendLog(f"新模型已添加到模型集: {model_name}\n") self._appendLog(f"新模型已添加到模型集: {model_name}\n")
self._appendLog(f" 路径: {best_model_path}\n") self._appendLog(f" 路径: {model_params['path']}\n")
self._appendLog(f" 大小: {model_size}\n") self._appendLog(f" 大小: {model_size}\n")
except Exception as e: except Exception as e:
...@@ -2195,8 +2407,216 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -2195,8 +2407,216 @@ class ModelTrainingHandler(ModelTestHandler):
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def _getTrainingNotes(self):
"""获取训练页面的笔记内容"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'getTrainingNotes'):
notes = self.training_panel.getTrainingNotes()
if notes:
self._appendLog(f"[笔记] 获取到训练笔记,长度: {len(notes)} 字符\n")
return notes
else:
self._appendLog("[笔记] 未输入训练笔记\n")
return ""
else:
self._appendLog("[笔记] 无法获取训练页面笔记接口\n")
return ""
except Exception as e:
self._appendLog(f"[ERROR] 获取训练笔记失败: {str(e)}\n")
return ""
def _clearTrainingNotes(self):
"""清空训练页面的笔记内容"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'clearTrainingNotes'):
self.training_panel.clearTrainingNotes()
self._appendLog("[笔记] 训练笔记已清空\n")
else:
self._appendLog("[笔记] 无法获取训练页面笔记清空接口\n")
except Exception as e:
self._appendLog(f"[ERROR] 清空训练笔记失败: {str(e)}\n")
def _enableNotesButtons(self):
"""启用训练页面的笔记保存和提交按钮"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'enableNotesButtons'):
self.training_panel.enableNotesButtons()
self._appendLog("[笔记] 笔记保存和提交按钮已启用\n")
else:
self._appendLog("[笔记] 无法获取训练页面笔记按钮接口\n")
except Exception as e:
self._appendLog(f"[ERROR] 启用笔记按钮失败: {str(e)}\n")
def _disableNotesButtons(self):
"""禁用训练页面的笔记保存和提交按钮"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'disableNotesButtons'):
self.training_panel.disableNotesButtons()
self._appendLog("[笔记] 笔记保存和提交按钮已禁用\n")
else:
self._appendLog("[笔记] 无法获取训练页面笔记按钮接口\n")
except Exception as e:
self._appendLog(f"[ERROR] 禁用笔记按钮失败: {str(e)}\n")
def saveNotesToLatestModel(self, notes):
"""保存笔记到最新训练的模型目录"""
try:
if not notes or not notes.strip():
self._appendLog("[笔记] 笔记内容为空,无需保存\n")
return False
# 获取最新的模型目录
latest_model_dir = self._getLatestModelDirectory()
if not latest_model_dir:
self._appendLog("[ERROR] 无法找到最新的模型目录\n")
return False
# 保存笔记文件
notes_file = os.path.join(latest_model_dir, 'training_notes.txt')
try:
with open(notes_file, 'w', encoding='utf-8') as f:
# 添加时间戳和更新信息
f.write(f"训练笔记 - {os.path.basename(latest_model_dir)}\n")
f.write(f"最后更新: {self._getCurrentTimestamp()}\n")
f.write("="*50 + "\n\n")
f.write(notes)
self._appendLog(f"[笔记] 笔记已保存到: {notes_file}\n")
return True
except Exception as e:
self._appendLog(f"[ERROR] 保存笔记文件失败: {str(e)}\n")
return False
except Exception as e:
self._appendLog(f"[ERROR] 保存笔记到最新模型失败: {str(e)}\n")
import traceback
traceback.print_exc()
return False
def _getLatestModelDirectory(self):
"""获取最新的detection_model目录"""
try:
project_root = get_project_root()
detection_model_dir = os.path.join(project_root, 'database', 'model', 'detection_model')
if not os.path.exists(detection_model_dir):
return None
# 获取所有数字目录
digit_dirs = []
for item in os.listdir(detection_model_dir):
item_path = os.path.join(detection_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
digit_dirs.append(int(item))
if not digit_dirs:
return None
# 返回最大数字的目录
latest_id = max(digit_dirs)
latest_dir = os.path.join(detection_model_dir, str(latest_id))
self._appendLog(f"[笔记] 找到最新模型目录: detection_model/{latest_id}\n")
return latest_dir
except Exception as e:
self._appendLog(f"[ERROR] 获取最新模型目录失败: {str(e)}\n")
return None
def _moveModelToDetectionDir(self, model_path, model_name, weights_dir, training_notes=""):
"""将训练完成的模型移动到detection_model目录"""
try:
import shutil
from pathlib import Path
self._appendLog(f"\n开始移动模型到detection_model目录...\n")
# 获取项目根目录
project_root = get_project_root()
detection_model_dir = os.path.join(project_root, 'database', 'model', 'detection_model')
# 确保detection_model目录存在
os.makedirs(detection_model_dir, exist_ok=True)
# 获取下一个可用的数字ID
existing_dirs = []
for item in os.listdir(detection_model_dir):
item_path = os.path.join(detection_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
existing_dirs.append(int(item))
next_id = max(existing_dirs) + 1 if existing_dirs else 1
target_model_dir = os.path.join(detection_model_dir, str(next_id))
self._appendLog(f" 目标目录: {target_model_dir}\n")
# 创建目标目录结构
os.makedirs(target_model_dir, exist_ok=True)
target_weights_dir = os.path.join(target_model_dir, 'weights')
os.makedirs(target_weights_dir, exist_ok=True)
# 移动整个weights目录的内容
if weights_dir and os.path.exists(weights_dir):
self._appendLog(f" 复制weights目录内容...\n")
# 复制所有文件到目标weights目录
for filename in os.listdir(weights_dir):
source_file = os.path.join(weights_dir, filename)
target_file = os.path.join(target_weights_dir, filename)
if os.path.isfile(source_file):
shutil.copy2(source_file, target_file)
self._appendLog(f" 复制: {filename}\n")
# 复制训练目录的其他文件(如config.yaml, results.csv等)
train_exp_dir = os.path.dirname(weights_dir)
if os.path.exists(train_exp_dir):
for filename in os.listdir(train_exp_dir):
if filename != 'weights': # 跳过weights目录
source_file = os.path.join(train_exp_dir, filename)
target_file = os.path.join(target_model_dir, filename)
if os.path.isfile(source_file):
shutil.copy2(source_file, target_file)
self._appendLog(f" 复制配置: {filename}\n")
# 保存训练笔记(如果有)
if training_notes:
notes_file = os.path.join(target_model_dir, 'training_notes.txt')
try:
with open(notes_file, 'w', encoding='utf-8') as f:
# 添加时间戳和模型信息
f.write(f"训练笔记 - {model_name}\n")
f.write(f"训练时间: {self._getCurrentTimestamp()}\n")
f.write(f"模型ID: {next_id}\n")
f.write("="*50 + "\n\n")
f.write(training_notes)
self._appendLog(f" 保存训练笔记: training_notes.txt\n")
except Exception as e:
self._appendLog(f" 保存训练笔记失败: {str(e)}\n")
# 确定最终的模型文件路径
model_filename = os.path.basename(model_path)
final_model_path = os.path.join(target_weights_dir, model_filename)
if os.path.exists(final_model_path):
self._appendLog(f"✅ 模型已成功移动到detection_model/{next_id}/weights/{model_filename}\n")
if training_notes:
self._appendLog(f"✅ 训练笔记已保存到detection_model/{next_id}/training_notes.txt\n")
return final_model_path
else:
self._appendLog(f"❌ 模型移动失败,目标文件不存在: {final_model_path}\n")
return None
except Exception as e:
self._appendLog(f"❌ [ERROR] 移动模型到detection_model失败: {str(e)}\n")
import traceback
traceback.print_exc()
return None
def _saveModelToConfig(self, model_name, model_params): def _saveModelToConfig(self, model_name, model_params):
"""保存模型配置文件(模型已经在train_model目录中)""" """保存模型配置文件(模型已经在detection_model目录中)"""
try: try:
from pathlib import Path from pathlib import Path
import yaml import yaml
...@@ -2205,7 +2625,7 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -2205,7 +2625,7 @@ class ModelTrainingHandler(ModelTestHandler):
self._appendLog(f" 模型名称: {model_name}\n") self._appendLog(f" 模型名称: {model_name}\n")
self._appendLog(f" 模型路径: {model_params['path']}\n") self._appendLog(f" 模型路径: {model_params['path']}\n")
# 获取模型所在目录(应该已经在train_model/{数字ID}/weights/中) # 获取模型所在目录(现在应该在detection_model/{数字ID}/weights/中)
model_path = model_params['path'] model_path = model_params['path']
if not os.path.exists(model_path): if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}") raise FileNotFoundError(f"模型文件不存在: {model_path}")
...@@ -2254,7 +2674,8 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -2254,7 +2674,8 @@ class ModelTrainingHandler(ModelTestHandler):
yaml.dump(model_params, f, allow_unicode=True, default_flow_style=False) yaml.dump(model_params, f, allow_unicode=True, default_flow_style=False)
# 输出总结信息 # 输出总结信息
self._appendLog(f"\n✅ 模型配置已保存: {model_dir}\n") model_id = os.path.basename(model_dir)
self._appendLog(f"\n✅ 模型配置已保存到detection_model/{model_id}/\n")
self._appendLog(f" 找到的文件:\n") self._appendLog(f" 找到的文件:\n")
for info in model_files_found: for info in model_files_found:
self._appendLog(f" - {info}\n") self._appendLog(f" - {info}\n")
...@@ -2262,6 +2683,7 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -2262,6 +2683,7 @@ class ModelTrainingHandler(ModelTestHandler):
for log_file in log_files_found: for log_file in log_files_found:
self._appendLog(f" - 训练日志: {log_file}\n") self._appendLog(f" - 训练日志: {log_file}\n")
self._appendLog(f" - 配置文件: config.yaml\n") self._appendLog(f" - 配置文件: config.yaml\n")
self._appendLog(f"\n🎉 模型已成功保存到detection_model目录,可在模型集管理中查看!\n")
except Exception as e: except Exception as e:
self._appendLog(f"❌ [ERROR] 保存模型配置失败: {str(e)}\n") self._appendLog(f"❌ [ERROR] 保存模型配置失败: {str(e)}\n")
...@@ -2652,3 +3074,178 @@ class ModelTrainingHandler(ModelTestHandler): ...@@ -2652,3 +3074,178 @@ class ModelTrainingHandler(ModelTestHandler):
return "未知" return "未知"
except: except:
return "未知" return "未知"
def _mergeMultipleDatasets(self, dataset_folders, exp_name):
"""
合并多个数据集文件夹为统一的训练配置
Args:
dataset_folders: 数据集文件夹路径列表
exp_name: 实验名称
Returns:
str: 合并后的data.yaml文件路径,失败返回None
"""
try:
import shutil
import tempfile
from pathlib import Path
# 创建临时合并目录
project_root = get_project_root()
temp_dataset_dir = os.path.join(project_root, 'database', 'temp_datasets', exp_name)
# 如果目录已存在,先删除
if os.path.exists(temp_dataset_dir):
shutil.rmtree(temp_dataset_dir)
# 创建合并后的目录结构
merged_images_train = os.path.join(temp_dataset_dir, 'images', 'train')
merged_images_val = os.path.join(temp_dataset_dir, 'images', 'val')
merged_labels_train = os.path.join(temp_dataset_dir, 'labels', 'train')
merged_labels_val = os.path.join(temp_dataset_dir, 'labels', 'val')
os.makedirs(merged_images_train, exist_ok=True)
os.makedirs(merged_images_val, exist_ok=True)
os.makedirs(merged_labels_train, exist_ok=True)
os.makedirs(merged_labels_val, exist_ok=True)
self._appendLog(f"创建合并目录: {temp_dataset_dir}\n")
# 合并所有数据集
total_train_images = 0
total_val_images = 0
total_train_labels = 0
total_val_labels = 0
for i, folder in enumerate(dataset_folders):
self._appendLog(f"正在合并数据集 {i+1}/{len(dataset_folders)}: {os.path.basename(folder)}\n")
# 检查源目录结构
src_images_train = os.path.join(folder, 'images', 'train')
src_images_val = os.path.join(folder, 'images', 'val')
src_labels_train = os.path.join(folder, 'labels', 'train')
src_labels_val = os.path.join(folder, 'labels', 'val')
# 复制训练图片
if os.path.exists(src_images_train):
for filename in os.listdir(src_images_train):
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
src_file = os.path.join(src_images_train, filename)
# 添加前缀避免文件名冲突
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_images_train, dst_filename)
shutil.copy2(src_file, dst_file)
total_train_images += 1
# 复制验证图片
if os.path.exists(src_images_val):
for filename in os.listdir(src_images_val):
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
src_file = os.path.join(src_images_val, filename)
# 添加前缀避免文件名冲突
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_images_val, dst_filename)
shutil.copy2(src_file, dst_file)
total_val_images += 1
# 复制训练标签
if os.path.exists(src_labels_train):
for filename in os.listdir(src_labels_train):
if filename.lower().endswith('.txt'):
src_file = os.path.join(src_labels_train, filename)
# 添加前缀避免文件名冲突,保持与图片文件名对应
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_labels_train, dst_filename)
shutil.copy2(src_file, dst_file)
total_train_labels += 1
# 复制验证标签
if os.path.exists(src_labels_val):
for filename in os.listdir(src_labels_val):
if filename.lower().endswith('.txt'):
src_file = os.path.join(src_labels_val, filename)
# 添加前缀避免文件名冲突,保持与图片文件名对应
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_labels_val, dst_filename)
shutil.copy2(src_file, dst_file)
total_val_labels += 1
self._appendLog(f"合并完成: 训练图片 {total_train_images} 张, 验证图片 {total_val_images} 张\n")
self._appendLog(f"合并完成: 训练标签 {total_train_labels} 个, 验证标签 {total_val_labels} 个\n")
# 创建data.yaml配置文件
data_yaml_path = os.path.join(temp_dataset_dir, 'data.yaml')
data_config = {
'train': os.path.join(temp_dataset_dir, 'images', 'train'),
'val': os.path.join(temp_dataset_dir, 'images', 'val'),
'nc': 1, # 类别数量,液位检测通常是单类别
'names': ['liquid_level'] # 类别名称
}
with open(data_yaml_path, 'w', encoding='utf-8') as f:
yaml.dump(data_config, f, default_flow_style=False, allow_unicode=True)
self._appendLog(f"创建配置文件: {data_yaml_path}\n")
return data_yaml_path
except Exception as e:
self._appendLog(f"数据集合并失败: {str(e)}\n")
import traceback
self._appendLog(traceback.format_exc())
return None
def _createDataYamlForSingleFolder(self, dataset_folder, exp_name):
"""
为单个数据集文件夹创建data.yaml配置文件
Args:
dataset_folder: 数据集文件夹路径
exp_name: 实验名称
Returns:
str: data.yaml文件路径,失败返回None
"""
try:
# 检查是否已经有data.yaml文件
existing_yaml = os.path.join(dataset_folder, 'data.yaml')
if os.path.exists(existing_yaml):
self._appendLog(f"使用现有配置文件: {existing_yaml}\n")
return existing_yaml
# 创建临时配置文件
project_root = get_project_root()
temp_config_dir = os.path.join(project_root, 'database', 'temp_configs')
os.makedirs(temp_config_dir, exist_ok=True)
data_yaml_path = os.path.join(temp_config_dir, f"{exp_name}_data.yaml")
# 检查数据集目录结构
images_train_dir = os.path.join(dataset_folder, 'images', 'train')
images_val_dir = os.path.join(dataset_folder, 'images', 'val')
if not os.path.exists(images_train_dir):
self._appendLog(f"错误: 训练图片目录不存在: {images_train_dir}\n")
return None
# 创建data.yaml配置
data_config = {
'train': images_train_dir,
'val': images_val_dir if os.path.exists(images_val_dir) else images_train_dir,
'nc': 1, # 类别数量,液位检测通常是单类别
'names': ['liquid_level'] # 类别名称
}
with open(data_yaml_path, 'w', encoding='utf-8') as f:
yaml.dump(data_config, f, default_flow_style=False, allow_unicode=True)
self._appendLog(f"创建配置文件: {data_yaml_path}\n")
return data_yaml_path
except Exception as e:
self._appendLog(f"创建配置文件失败: {str(e)}\n")
import traceback
self._appendLog(traceback.format_exc())
return None
...@@ -69,6 +69,7 @@ class TrainingWorker(QThread): ...@@ -69,6 +69,7 @@ class TrainingWorker(QThread):
super().__init__() super().__init__()
self.training_params = training_params self.training_params = training_params
self.is_running = True self.is_running = True
self.training_actually_started = False # 标记训练是否已经真正开始(第一个epoch开始)
self.train_config = None self.train_config = None
# 调试信息:显示传入的训练参数 # 调试信息:显示传入的训练参数
...@@ -522,6 +523,8 @@ class TrainingWorker(QThread): ...@@ -522,6 +523,8 @@ class TrainingWorker(QThread):
def on_train_start(trainer): def on_train_start(trainer):
"""训练开始回调 - 只输出到终端,不发送到UI""" """训练开始回调 - 只输出到终端,不发送到UI"""
# 标记训练已经真正开始
self.training_actually_started = True
# 记录开始时间 # 记录开始时间
epoch_start_time[0] = time.time() epoch_start_time[0] = time.time()
# 不发送任何格式化消息到UI,让LogCapture直接捕获原生输出 # 不发送任何格式化消息到UI,让LogCapture直接捕获原生输出
...@@ -1264,3 +1267,7 @@ class TrainingWorker(QThread): ...@@ -1264,3 +1267,7 @@ class TrainingWorker(QThread):
def stop_training(self): def stop_training(self):
"""停止训练""" """停止训练"""
self.is_running = False self.is_running = False
def has_training_started(self):
"""检查训练是否已经真正开始"""
return self.training_actually_started
...@@ -1900,6 +1900,10 @@ class ChannelPanelHandler: ...@@ -1900,6 +1900,10 @@ class ChannelPanelHandler:
def _initConfigFileWatcher(self): def _initConfigFileWatcher(self):
"""初始化配置文件监控器""" """初始化配置文件监控器"""
try: try:
# 临时禁用配置文件监控器以解决 QWidget 创建顺序问题
print(f"[ConfigWatcher] 配置文件监控器已禁用(避免 QWidget 创建顺序问题)")
return
# 获取配置文件路径 # 获取配置文件路径
project_root = get_project_root() project_root = get_project_root()
config_path = os.path.join(project_root, 'database', 'config', 'default_config.yaml') config_path = os.path.join(project_root, 'database', 'config', 'default_config.yaml')
...@@ -1928,7 +1932,12 @@ class ChannelPanelHandler: ...@@ -1928,7 +1932,12 @@ class ChannelPanelHandler:
print(f"🔄 [ConfigWatcher] 检测到配置文件变化: {path}") print(f"🔄 [ConfigWatcher] 检测到配置文件变化: {path}")
# 延迟一小段时间,确保文件写入完成 # 延迟一小段时间,确保文件写入完成
# 修复:检查 QApplication 是否存在
if QtWidgets.QApplication.instance() is not None:
QtCore.QTimer.singleShot(100, self._reloadChannelConfig) QtCore.QTimer.singleShot(100, self._reloadChannelConfig)
else:
# 如果没有 QApplication,直接调用重载函数
self._reloadChannelConfig()
except Exception as e: except Exception as e:
print(f"[ConfigWatcher] 处理配置文件变化失败: {e}") print(f"[ConfigWatcher] 处理配置文件变化失败: {e}")
......
...@@ -17,7 +17,7 @@ import csv ...@@ -17,7 +17,7 @@ import csv
import datetime import datetime
import numpy as np import numpy as np
from qtpy import QtWidgets, QtCore from qtpy import QtWidgets, QtCore
from PyQt5.QtCore import QThread, pyqtSignal from qtpy.QtCore import QThread, Signal as pyqtSignal
# 导入图标工具 # 导入图标工具
try: try:
...@@ -930,7 +930,7 @@ class CurvePanelHandler: ...@@ -930,7 +930,7 @@ class CurvePanelHandler:
QtWidgets.QApplication.processEvents() QtWidgets.QApplication.processEvents()
# 🔥 延迟关闭进度条,确保用户能看到(至少显示500ms) # 🔥 延迟关闭进度条,确保用户能看到(至少显示500ms)
from PyQt5.QtCore import QTimer from qtpy.QtCore import QTimer
QTimer.singleShot(500, progress_dialog.close) QTimer.singleShot(500, progress_dialog.close)
print(f"✅ [进度条] 将在500ms后关闭") print(f"✅ [进度条] 将在500ms后关闭")
......
...@@ -11,7 +11,7 @@ widgets/videopage/historypanel.py (HistoryPanel) ...@@ -11,7 +11,7 @@ widgets/videopage/historypanel.py (HistoryPanel)
- -
""" """
from PyQt5 import QtWidgets, QtCore from qtpy import QtWidgets, QtCore
class HistoryPanelHandler: class HistoryPanelHandler:
...@@ -80,8 +80,8 @@ class HistoryPanelHandler: ...@@ -80,8 +80,8 @@ class HistoryPanelHandler:
return return
import os import os
from PyQt5.QtMultimedia import QMediaPlayer from qtpy.QtMultimedia import QMediaPlayer
from PyQt5 import QtWidgets from qtpy import QtWidgets
player = self.history_panel._media_player player = self.history_panel._media_player
print(f"[HistoryPanelHandler] 播放器: {player}") print(f"[HistoryPanelHandler] 播放器: {player}")
...@@ -158,7 +158,7 @@ class HistoryPanelHandler: ...@@ -158,7 +158,7 @@ class HistoryPanelHandler:
if not self.history_panel or not hasattr(self.history_panel, 'play_pause_button'): if not self.history_panel or not hasattr(self.history_panel, 'play_pause_button'):
return return
from PyQt5.QtMultimedia import QMediaPlayer from qtpy.QtMultimedia import QMediaPlayer
from widgets.style_manager import newIcon from widgets.style_manager import newIcon
button = self.history_panel.play_pause_button button = self.history_panel.play_pause_button
......
...@@ -699,7 +699,13 @@ class ModelSetPage(QtWidgets.QWidget): ...@@ -699,7 +699,13 @@ class ModelSetPage(QtWidgets.QWidget):
menu.addSeparator() menu.addSeparator()
# 5. 删除模型(有实际功能) # 5. 查看模型信息(新功能)
action_view_info = menu.addAction("查看模型信息")
action_view_info.triggered.connect(lambda: self.viewModelInfo(model_name))
menu.addSeparator()
# 6. 删除模型(有实际功能)
action_delete = menu.addAction("删除模型") action_delete = menu.addAction("删除模型")
action_delete.triggered.connect(lambda: self.deleteModel(model_name)) action_delete.triggered.connect(lambda: self.deleteModel(model_name))
...@@ -735,6 +741,9 @@ class ModelSetPage(QtWidgets.QWidget): ...@@ -735,6 +741,9 @@ class ModelSetPage(QtWidgets.QWidget):
self.setDefaultRequested.emit(model_name) self.setDefaultRequested.emit(model_name)
self._updateStats() self._updateStats()
# 通知训练页面刷新基础模型列表
self._notifyTrainingPageRefresh()
def updateModelParams(self, model_name): def updateModelParams(self, model_name):
"""更新模型参数显示(已移除右侧参数显示,保留方法以保持兼容性)""" """更新模型参数显示(已移除右侧参数显示,保留方法以保持兼容性)"""
# 方法保留以保持API兼容性,但不再执行任何操作 # 方法保留以保持API兼容性,但不再执行任何操作
...@@ -1007,7 +1016,8 @@ class ModelSetPage(QtWidgets.QWidget): ...@@ -1007,7 +1016,8 @@ class ModelSetPage(QtWidgets.QWidget):
self._updateModelOrder() self._updateModelOrder()
# 通知训练页面刷新基础模型列表
self._notifyTrainingPageRefresh()
except Exception as e: except Exception as e:
showCritical(self, "删除失败", f"删除模型时发生错误: {e}") showCritical(self, "删除失败", f"删除模型时发生错误: {e}")
...@@ -1039,6 +1049,216 @@ class ModelSetPage(QtWidgets.QWidget): ...@@ -1039,6 +1049,216 @@ class ModelSetPage(QtWidgets.QWidget):
except Exception as e: except Exception as e:
showCritical(self, "操作失败", f"添加至检测模型时发生错误: {e}") showCritical(self, "操作失败", f"添加至检测模型时发生错误: {e}")
def viewModelInfo(self, model_name):
"""查看模型信息(来源和训练指标)"""
try:
# 获取模型参数
if model_name not in self._model_params:
showWarning(self, "错误", f"未找到模型 '{model_name}' 的信息")
return
model_params = self._model_params[model_name]
model_path = model_params.get('path', '')
if not model_path or not os.path.exists(model_path):
showWarning(self, "错误", f"模型文件不存在: {model_path}")
return
# 读取模型配置和训练指标
model_info = self._readModelTrainingInfo(model_path, model_name)
# 创建并显示信息对话框
self._showModelInfoDialog(model_name, model_params, model_info)
except Exception as e:
import traceback
traceback.print_exc()
showCritical(self, "错误", f"查看模型信息时发生错误: {e}")
def _readModelTrainingInfo(self, model_path, model_name):
"""读取模型的训练配置和指标信息"""
info = {
'config': {},
'metrics': {},
'training_date': None,
'source': '未知'
}
try:
from pathlib import Path
model_file = Path(model_path)
# 1. 确定模型所在目录
if model_file.is_file():
model_dir = model_file.parent
else:
model_dir = model_file
# 2. 尝试读取config.yaml(训练配置)
config_file = model_dir / 'config.yaml'
if not config_file.exists():
# 尝试上一级目录
config_file = model_dir.parent / 'config.yaml'
if config_file.exists():
with open(config_file, 'r', encoding='utf-8') as f:
info['config'] = yaml.safe_load(f) or {}
info['training_date'] = info['config'].get('training_date', '未知')
info['source'] = '训练生成'
# 3. 尝试读取results.csv(训练指标)
results_file = model_dir / 'results.csv'
if not results_file.exists():
# 尝试train子目录
results_file = model_dir.parent / 'train' / 'results.csv'
if results_file.exists():
import pandas as pd
try:
df = pd.read_csv(results_file)
if len(df) > 0:
# 获取最后一行(最终指标)
last_row = df.iloc[-1]
info['metrics'] = {
'epoch': int(last_row.get('epoch', 0)) if 'epoch' in last_row else len(df),
'train_loss': float(last_row.get('train/box_loss', 0)) if 'train/box_loss' in last_row else None,
'val_loss': float(last_row.get('val/box_loss', 0)) if 'val/box_loss' in last_row else None,
'precision': float(last_row.get('metrics/precision(B)', 0)) if 'metrics/precision(B)' in last_row else None,
'recall': float(last_row.get('metrics/recall(B)', 0)) if 'metrics/recall(B)' in last_row else None,
'mAP50': float(last_row.get('metrics/mAP50(B)', 0)) if 'metrics/mAP50(B)' in last_row else None,
'mAP50-95': float(last_row.get('metrics/mAP50-95(B)', 0)) if 'metrics/mAP50-95(B)' in last_row else None,
}
except Exception as e:
print(f"读取results.csv失败: {e}")
# 4. 如果没有配置文件,从文件路径判断来源
if not info['config']:
if 'train_model' in str(model_path):
info['source'] = '本地训练'
elif 'detection_model' in str(model_path):
info['source'] = '检测模型'
else:
info['source'] = '导入模型'
except Exception as e:
import traceback
traceback.print_exc()
print(f"读取模型训练信息失败: {e}")
return info
def _showModelInfoDialog(self, model_name, model_params, model_info):
"""显示模型信息对话框(简化版 - 只显示训练配置)"""
# 创建对话框
dialog = QtWidgets.QDialog(self)
dialog.setWindowTitle(f"模型升级配置 - {model_name}")
dialog.setMinimumWidth(450)
dialog.setMinimumHeight(400)
layout = QtWidgets.QVBoxLayout(dialog)
layout.setSpacing(15)
layout.setContentsMargins(20, 20, 20, 20)
# 标题
title_label = QtWidgets.QLabel(f"<h3>{model_name}</h3>")
layout.addWidget(title_label)
# 创建表单布局显示训练配置
form_widget = QtWidgets.QWidget()
form_layout = QtWidgets.QFormLayout(form_widget)
form_layout.setSpacing(10)
form_layout.setContentsMargins(10, 10, 10, 10)
form_layout.setLabelAlignment(Qt.AlignRight | Qt.AlignVCenter)
# 设置表单样式
form_widget.setStyleSheet("""
QWidget {
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 4px;
}
QLabel {
font-size: 10pt;
padding: 3px;
}
""")
# 从配置中提取训练参数
config = model_info.get('config', {})
# 基础信息
info_label = QtWidgets.QLabel(f"<b>模型来源:</b>{model_info['source']}")
form_layout.addRow("", info_label)
if model_info.get('training_date'):
date_label = QtWidgets.QLabel(f"<b>训练日期:</b>{model_info['training_date']}")
form_layout.addRow("", date_label)
# 添加分隔线
separator1 = QtWidgets.QFrame()
separator1.setFrameShape(QtWidgets.QFrame.HLine)
separator1.setFrameShadow(QtWidgets.QFrame.Sunken)
form_layout.addRow(separator1)
# 训练配置参数(仿照训练页面的格式)
epochs = config.get('epochs', '未知')
batch_size = config.get('batch', config.get('batch_size', '未知'))
imgsz = config.get('imgsz', '未知')
workers = config.get('workers', '未知')
device = config.get('device', '未知')
optimizer = config.get('optimizer', '未知')
form_layout.addRow("训练轮数:", QtWidgets.QLabel(f"<b>{epochs} 轮</b>"))
form_layout.addRow("批次大小:", QtWidgets.QLabel(f"<b>{batch_size}</b>"))
form_layout.addRow("图像尺寸:", QtWidgets.QLabel(f"<b>{imgsz} px</b>"))
form_layout.addRow("Workers:", QtWidgets.QLabel(f"<b>{workers} 线程</b>"))
form_layout.addRow("训练设备:", QtWidgets.QLabel(f"<b>{device}</b>"))
form_layout.addRow("优化器:", QtWidgets.QLabel(f"<b>{optimizer}</b>"))
# 如果有训练指标,显示最终性能
metrics = model_info.get('metrics', {})
if metrics:
separator2 = QtWidgets.QFrame()
separator2.setFrameShape(QtWidgets.QFrame.HLine)
separator2.setFrameShadow(QtWidgets.QFrame.Sunken)
form_layout.addRow(separator2)
mAP50 = metrics.get('mAP50')
mAP5095 = metrics.get('mAP50-95')
if mAP50 is not None:
form_layout.addRow("mAP@0.5:", QtWidgets.QLabel(f"<b style='color: #28a745;'>{mAP50:.4f}</b>"))
if mAP5095 is not None:
form_layout.addRow("mAP@0.5:0.95:", QtWidgets.QLabel(f"<b style='color: #28a745;'>{mAP5095:.4f}</b>"))
layout.addWidget(form_widget)
# 如果没有训练配置,显示提示
if not config:
no_config_label = QtWidgets.QLabel(
"<i>该模型没有训练配置信息<br>"
"可能是外部导入的模型</i>"
)
no_config_label.setAlignment(Qt.AlignCenter)
no_config_label.setStyleSheet("color: #999; padding: 20px;")
layout.addWidget(no_config_label)
layout.addStretch()
# 按钮
button_layout = QtWidgets.QHBoxLayout()
button_layout.addStretch()
close_btn = QtWidgets.QPushButton("关闭")
close_btn.setMinimumWidth(100)
close_btn.clicked.connect(dialog.accept)
button_layout.addWidget(close_btn)
layout.addLayout(button_layout)
# 显示对话框
dialog.exec_()
def _moveModelToDetection(self, model_name, source_path): def _moveModelToDetection(self, model_name, source_path):
"""执行模型移动操作""" """执行模型移动操作"""
try: try:
...@@ -1149,10 +1369,25 @@ class ModelSetPage(QtWidgets.QWidget): ...@@ -1149,10 +1369,25 @@ class ModelSetPage(QtWidgets.QWidget):
self._updateModelOrder() self._updateModelOrder()
# 通知训练页面刷新基础模型列表
self._notifyTrainingPageRefresh()
except Exception as e: except Exception as e:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def _notifyTrainingPageRefresh(self):
"""通知训练页面刷新基础模型列表"""
try:
# 尝试获取训练页面并刷新其基础模型列表
if self._parent and hasattr(self._parent, 'trainingPage'):
training_page = self._parent.trainingPage
if hasattr(training_page, '_loadBaseModelOptions'):
training_page._loadBaseModelOptions()
print("[模型集管理] 已通知训练页面刷新基础模型列表")
except Exception as e:
print(f"[模型集管理] 通知训练页面刷新失败: {e}")
def _loadConfigFile(self): def _loadConfigFile(self):
"""加载配置文件""" """加载配置文件"""
try: try:
...@@ -1208,49 +1443,121 @@ class ModelSetPage(QtWidgets.QWidget): ...@@ -1208,49 +1443,121 @@ class ModelSetPage(QtWidgets.QWidget):
return models return models
def _scanModelDirectory(self): def _scanModelDirectory(self):
"""扫描模型目录获取所有模型文件""" """扫描模型目录获取所有模型文件(优先detection_model)"""
models = [] models = []
try: try:
# 获取模型目录路径 # 获取模型目录路径
current_dir = Path(__file__).parent.parent.parent current_dir = Path(__file__).parent.parent.parent
# 扫描多个模型目录 # 扫描多个模型目录(优先detection_model)
model_dirs = [ model_dirs = [
(current_dir / "database" / "model" / "detection_model", "检测模型"), (current_dir / "database" / "model" / "detection_model", "检测模型", True), # 优先级最高
(current_dir / "database" / "model" / "train_model", "训练模型"), (current_dir / "database" / "model" / "train_model", "训练模型", False),
(current_dir / "database" / "model" / "test_model", "测试模型") (current_dir / "database" / "model" / "test_model", "测试模型", False)
] ]
for model_dir, dir_type in model_dirs: for model_dir, dir_type, is_primary in model_dirs:
if not model_dir.exists(): if not model_dir.exists():
continue continue
# 遍历所有子目录,并按目录名排序(确保模型1最先) # 遍历所有子目录,按数字排序(降序,最新的在前)
sorted_subdirs = sorted(model_dir.iterdir(), key=lambda x: x.name if x.is_dir() else '') all_subdirs = [d for d in model_dir.iterdir() if d.is_dir()]
digit_subdirs = [d for d in all_subdirs if d.name.isdigit()]
sorted_subdirs = sorted(digit_subdirs, key=lambda x: int(x.name), reverse=True)
for subdir in sorted_subdirs: for subdir in sorted_subdirs:
if subdir.is_dir(): # 检查是否有weights子目录
# 查找 .dat 文件(优先) weights_dir = subdir / "weights"
for model_file in sorted(subdir.glob("*.dat")): search_dir = weights_dir if weights_dir.exists() else subdir
models.append({
'name': f"{dir_type}-{subdir.name}-{model_file.stem}", # 尝试读取config.yaml获取模型名称
'path': str(model_file), config_file = subdir / "config.yaml"
'subdir': subdir.name, model_display_name = None
'source': 'scan', if config_file.exists():
'format': 'dat', try:
'model_type': dir_type import yaml
}) with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
if config_data and 'name' in config_data:
model_display_name = config_data['name']
except Exception:
pass
# 如果没有配置文件中的名称,使用默认命名
if not model_display_name:
if is_primary:
model_display_name = f"模型-{subdir.name}"
else:
model_display_name = f"{dir_type}-{subdir.name}"
# 按优先级查找模型文件:best > last > epoch1
selected_model = None
# 优先级1: best模型(.dat优先)
for ext in ['.dat', '']: # 无扩展名的也考虑
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('best.'):
if ext == '' and '.' in file.name[5:]: # 有其他扩展名
continue
if ext != '' and not file.name.endswith(ext):
continue
if ext == '' or file.name.endswith(ext):
selected_model = file
break
if selected_model:
break
# 优先级2: last模型
if not selected_model:
for ext in ['.dat', '']:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('last.'):
if ext == '' and '.' in file.name[5:]:
continue
if ext != '' and not file.name.endswith(ext):
continue
if ext == '' or file.name.endswith(ext):
selected_model = file
break
if selected_model:
break
# 优先级3: epoch1模型
if not selected_model:
for ext in ['.dat', '']:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('epoch1.'):
if ext == '' and '.' in file.name[7:]:
continue
if ext != '' and not file.name.endswith(ext):
continue
if ext == '' or file.name.endswith(ext):
selected_model = file
break
if selected_model:
break
# 如果找到了模型文件,添加到列表
if selected_model:
# 获取文件格式
file_ext = selected_model.suffix.lstrip('.')
if not file_ext:
# 处理无扩展名的情况
if '.' in selected_model.name:
file_ext = selected_model.name.split('.')[-1]
else:
file_ext = 'dat' # 默认为dat格式
# 然后查找 .pt 文件
for model_file in sorted(subdir.glob("*.pt")):
models.append({ models.append({
'name': f"{dir_type}-{subdir.name}-{model_file.stem}", 'name': model_display_name,
'path': str(model_file), 'path': str(selected_model),
'subdir': subdir.name, 'subdir': subdir.name,
'source': 'scan', 'source': 'scan',
'format': 'pt', 'format': file_ext,
'model_type': dir_type 'model_type': dir_type,
'is_primary': is_primary, # 标记是否为主要模型目录
'file_name': selected_model.name
}) })
except Exception as e: except Exception as e:
...@@ -1260,29 +1567,47 @@ class ModelSetPage(QtWidgets.QWidget): ...@@ -1260,29 +1567,47 @@ class ModelSetPage(QtWidgets.QWidget):
return models return models
def _mergeModelInfo(self, channel_models, scanned_models): def _mergeModelInfo(self, channel_models, scanned_models):
"""合并模型信息,避免重复""" """合并模型信息,避免重复(优先detection_model)"""
all_models = [] all_models = []
seen_paths = set() seen_paths = set()
# 优先添加配置文件中的通道模型 # 首先添加detection_model中的主要模型(优先级最高)
primary_models = [m for m in scanned_models if m.get('is_primary', False)]
for model in primary_models:
path = model['path']
if path not in seen_paths:
all_models.append(model)
seen_paths.add(path)
# 然后添加配置文件中的通道模型
for model in channel_models: for model in channel_models:
path = model['path'] path = model['path']
if path not in seen_paths: if path not in seen_paths:
all_models.append(model) all_models.append(model)
seen_paths.add(path) seen_paths.add(path)
# 再添加扫描到的模型(跳过已存在的) # 最后添加其他扫描到的模型(跳过已存在的)
for model in scanned_models: other_models = [m for m in scanned_models if not m.get('is_primary', False)]
for model in other_models:
path = model['path'] path = model['path']
if path not in seen_paths: if path not in seen_paths:
all_models.append(model) all_models.append(model)
seen_paths.add(path) seen_paths.add(path)
# 确保有一个默认模型:如果没有默认模型,将第一个模型设为默认 # 确保有一个默认模型:优先选择detection_model中的第一个模型
has_default = any(model.get('is_default', False) for model in all_models) has_default = any(model.get('is_default', False) for model in all_models)
if not has_default and len(all_models) > 0: if not has_default and len(all_models) > 0:
# 优先选择detection_model中的模型作为默认
primary_model_found = False
for model in all_models:
if model.get('is_primary', False):
model['is_default'] = True
primary_model_found = True
break
# 如果没有detection_model中的模型,选择第一个
if not primary_model_found:
all_models[0]['is_default'] = True all_models[0]['is_default'] = True
pass
return all_models return all_models
......
...@@ -11,6 +11,14 @@ from pathlib import Path ...@@ -11,6 +11,14 @@ from pathlib import Path
from qtpy import QtWidgets, QtCore, QtGui from qtpy import QtWidgets, QtCore, QtGui
from qtpy.QtCore import Qt from qtpy.QtCore import Qt
# 尝试导入 PyQtGraph 用于曲线显示
try:
import pyqtgraph as pg
PYQTGRAPH_AVAILABLE = True
except ImportError:
pg = None
PYQTGRAPH_AVAILABLE = False
# 导入图标工具函数 # 导入图标工具函数
try: try:
from ..icons import newIcon, newButton from ..icons import newIcon, newButton
...@@ -36,11 +44,11 @@ except (ImportError, ValueError): ...@@ -36,11 +44,11 @@ except (ImportError, ValueError):
# 导入样式管理器和响应式布局 # 导入样式管理器和响应式布局
try: try:
from ..style_manager import FontManager, BackgroundStyleManager from ..style_manager import FontManager, BackgroundStyleManager, TextButtonStyleManager
from ..responsive_layout import ResponsiveLayout, scale_w, scale_h from ..responsive_layout import ResponsiveLayout, scale_w, scale_h
except (ImportError, ValueError): except (ImportError, ValueError):
try: try:
from widgets.style_manager import FontManager, BackgroundStyleManager from widgets.style_manager import FontManager, BackgroundStyleManager, TextButtonStyleManager
from widgets.responsive_layout import ResponsiveLayout, scale_w, scale_h from widgets.responsive_layout import ResponsiveLayout, scale_w, scale_h
except ImportError: except ImportError:
try: try:
...@@ -48,7 +56,7 @@ except (ImportError, ValueError): ...@@ -48,7 +56,7 @@ except (ImportError, ValueError):
from pathlib import Path from pathlib import Path
project_root = Path(__file__).parent.parent.parent project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
from widgets.style_manager import FontManager, BackgroundStyleManager from widgets.style_manager import FontManager, BackgroundStyleManager, TextButtonStyleManager
from widgets.responsive_layout import ResponsiveLayout, scale_w, scale_h from widgets.responsive_layout import ResponsiveLayout, scale_w, scale_h
except ImportError: except ImportError:
# 如果导入失败,创建一个简单的替代类 # 如果导入失败,创建一个简单的替代类
...@@ -96,8 +104,9 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -96,8 +104,9 @@ class TrainingPage(QtWidgets.QWidget):
# 🔥 连接模板按钮组信号 # 🔥 连接模板按钮组信号
self.template_button_group.buttonClicked.connect(self._onTemplateChecked) self.template_button_group.buttonClicked.connect(self._onTemplateChecked)
self._loadBaseModelOptions() # 🔥 加载基础模型选项
self._loadTestModelOptions() # 加载测试模型选项 self._loadTestModelOptions() # 加载测试模型选项
self._loadTestFileList() # 🔥 加载测试文件列表 # self._loadTestFileList() # 🔥 不再需要加载测试文件列表(改用浏览方式)
def _increaseFontSize(self): def _increaseFontSize(self):
"""增加日志字体大小""" """增加日志字体大小"""
...@@ -251,6 +260,10 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -251,6 +260,10 @@ class TrainingPage(QtWidgets.QWidget):
# 将视频面板添加到显示布局中 # 将视频面板添加到显示布局中
display_layout.addWidget(self.video_panel) display_layout.addWidget(self.video_panel)
# === 添加曲线显示面板 ===
self._createCurvePanel()
display_layout.addWidget(self.curve_panel)
# 🔥 设置整体窗口的最小尺寸 - 使用响应式布局 # 🔥 设置整体窗口的最小尺寸 - 使用响应式布局
min_w, min_h = scale_w(1000), scale_h(700) min_w, min_h = scale_w(1000), scale_h(700)
self.setMinimumSize(min_w, min_h) self.setMinimumSize(min_w, min_h)
...@@ -311,24 +324,38 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -311,24 +324,38 @@ class TrainingPage(QtWidgets.QWidget):
test_file_label.setStyleSheet("color: #495057; font-weight: bold;") test_file_label.setStyleSheet("color: #495057; font-weight: bold;")
test_file_layout.addWidget(test_file_label) test_file_layout.addWidget(test_file_label)
self.test_file_input = QtWidgets.QComboBox() # 测试文件路径输入框和浏览按钮
test_file_input_layout = QtWidgets.QHBoxLayout()
test_file_input_layout.setSpacing(scale_spacing(5))
# 文件路径输入框(可编辑)
self.test_file_input = QtWidgets.QLineEdit()
self.test_file_input.setPlaceholderText("选择测试图片或视频文件...")
FontManager.applyToWidget(self.test_file_input) FontManager.applyToWidget(self.test_file_input)
self.test_file_input.setStyleSheet(""" self.test_file_input.setStyleSheet("""
QComboBox { QLineEdit {
padding: 6px 10px; padding: 6px 10px;
border: 1px solid #ced4da; border: 1px solid #ced4da;
border-radius: 4px; border-radius: 4px;
background-color: white; background-color: white;
} }
QComboBox:focus { QLineEdit:focus {
border-color: #0078d7; border-color: #0078d7;
outline: none; outline: none;
} }
QComboBox::drop-down {
border: none;
}
""") """)
test_file_layout.addWidget(self.test_file_input) test_file_input_layout.addWidget(self.test_file_input)
# 浏览按钮(使用全局样式管理器)
self.test_file_browse_btn = TextButtonStyleManager.createStandardButton(
"浏览...",
parent=self,
slot=self._browseTestFile
)
self.test_file_browse_btn.setMinimumWidth(scale_w(60))
test_file_input_layout.addWidget(self.test_file_browse_btn)
test_file_layout.addLayout(test_file_input_layout)
right_control_layout.addLayout(test_file_layout) right_control_layout.addLayout(test_file_layout)
# 添加垂直间距 # 添加垂直间距
...@@ -337,14 +364,32 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -337,14 +364,32 @@ class TrainingPage(QtWidgets.QWidget):
test_button_layout = QtWidgets.QHBoxLayout() test_button_layout = QtWidgets.QHBoxLayout()
ResponsiveLayout.apply_to_layout(test_button_layout, base_spacing=10, base_margins=0) ResponsiveLayout.apply_to_layout(test_button_layout, base_spacing=10, base_margins=0)
self.start_annotation_btn = QtWidgets.QPushButton("开始标注") # 使用全局样式管理器创建测试按钮
self.start_annotation_btn = TextButtonStyleManager.createStandardButton(
"开始标注",
parent=self
)
self.start_annotation_btn.setMinimumWidth(scale_w(80)) self.start_annotation_btn.setMinimumWidth(scale_w(80))
test_button_layout.addWidget(self.start_annotation_btn) test_button_layout.addWidget(self.start_annotation_btn)
self.start_test_btn = QtWidgets.QPushButton("开始测试") self.start_test_btn = TextButtonStyleManager.createStandardButton(
"开始测试",
parent=self
)
self.start_test_btn.setMinimumWidth(scale_w(80)) self.start_test_btn.setMinimumWidth(scale_w(80))
test_button_layout.addWidget(self.start_test_btn) test_button_layout.addWidget(self.start_test_btn)
# 查看曲线按钮(使用全局样式管理器)
self.view_curve_btn = TextButtonStyleManager.createStandardButton(
"查看曲线",
parent=self,
slot=self._onViewCurveClicked
)
self.view_curve_btn.setMinimumWidth(scale_w(80))
self.view_curve_btn.setEnabled(False) # 初始状态禁用
self.view_curve_btn.setToolTip("测试完成后可查看曲线结果")
test_button_layout.addWidget(self.view_curve_btn)
# 将按钮布局添加到主布局 # 将按钮布局添加到主布局
right_control_layout.addLayout(test_button_layout) right_control_layout.addLayout(test_button_layout)
...@@ -389,29 +434,37 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -389,29 +434,37 @@ class TrainingPage(QtWidgets.QWidget):
FontManager.applyToWidget(self.auto_scroll_check) FontManager.applyToWidget(self.auto_scroll_check)
log_header.addWidget(self.auto_scroll_check) log_header.addWidget(self.auto_scroll_check)
# 清空日志按钮(使用系统默认样式 + 响应式布局) # 清空日志按钮(使用全局样式管理器)
clear_log_btn = QtWidgets.QPushButton("清空日志") clear_log_btn = TextButtonStyleManager.createStandardButton(
"清空日志",
parent=self,
slot=self.clearLog
)
clear_log_btn.setMinimumWidth(scale_w(60)) clear_log_btn.setMinimumWidth(scale_w(60))
# 🔥 添加字体大小调整按钮 # 🔥 添加字体大小调整按钮
font_size_label = QtWidgets.QLabel("字体:") font_size_label = QtWidgets.QLabel("字体:")
# 使用系统默认样式
FontManager.applyToWidget(font_size_label) FontManager.applyToWidget(font_size_label)
log_header.addWidget(font_size_label) log_header.addWidget(font_size_label)
self.font_decrease_btn = QtWidgets.QPushButton("-") # 字体调整按钮(使用全局样式管理器)
self.font_decrease_btn = TextButtonStyleManager.createStandardButton(
"-",
parent=self,
slot=self._decreaseFontSize
)
btn_size = scale_w(20) # 响应式按钮尺寸 btn_size = scale_w(20) # 响应式按钮尺寸
self.font_decrease_btn.setFixedSize(btn_size, btn_size) self.font_decrease_btn.setFixedSize(btn_size, btn_size)
# 使用系统默认样式
self.font_decrease_btn.clicked.connect(self._decreaseFontSize)
log_header.addWidget(self.font_decrease_btn) log_header.addWidget(self.font_decrease_btn)
self.font_increase_btn = QtWidgets.QPushButton("+") self.font_increase_btn = TextButtonStyleManager.createStandardButton(
"+",
parent=self,
slot=self._increaseFontSize
)
self.font_increase_btn.setFixedSize(btn_size, btn_size) self.font_increase_btn.setFixedSize(btn_size, btn_size)
# 使用系统默认样式
self.font_increase_btn.clicked.connect(self._increaseFontSize)
log_header.addWidget(self.font_increase_btn) log_header.addWidget(self.font_increase_btn)
clear_log_btn.clicked.connect(self.clearLog)
log_header.addWidget(clear_log_btn) log_header.addWidget(clear_log_btn)
right_layout.addLayout(log_header) right_layout.addLayout(log_header)
...@@ -467,6 +520,199 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -467,6 +520,199 @@ class TrainingPage(QtWidgets.QWidget):
pass pass
def _createCurvePanel(self):
"""创建曲线显示面板"""
self.curve_panel = QtWidgets.QWidget()
curve_layout = QtWidgets.QVBoxLayout(self.curve_panel)
curve_layout.setContentsMargins(5, 5, 5, 5)
curve_layout.setSpacing(5)
# 曲线面板标题
curve_title_layout = QtWidgets.QHBoxLayout()
curve_title = QtWidgets.QLabel("测试结果曲线")
curve_title.setStyleSheet("color: #495057; font-weight: bold; font-size: 12pt;")
FontManager.applyToWidget(curve_title, weight=FontManager.WEIGHT_BOLD)
curve_title_layout.addWidget(curve_title)
# 添加清空曲线按钮(使用全局样式管理器)
self.clear_curve_btn = TextButtonStyleManager.createStandardButton(
"清空曲线",
parent=self,
slot=self._clearCurve
)
self.clear_curve_btn.setMinimumWidth(scale_w(80))
curve_title_layout.addStretch()
curve_title_layout.addWidget(self.clear_curve_btn)
curve_layout.addLayout(curve_title_layout)
# 检查 PyQtGraph 是否可用
if PYQTGRAPH_AVAILABLE:
# 创建 PyQtGraph 绘图控件
self.curve_plot_widget = pg.PlotWidget()
self.curve_plot_widget.setBackground('#ffffff')
self.curve_plot_widget.showGrid(x=True, y=True, alpha=0.3)
# 设置坐标轴标签
self.curve_plot_widget.setLabel('left', '液位高度', units='mm')
self.curve_plot_widget.setLabel('bottom', '帧序号')
self.curve_plot_widget.setTitle('液位检测曲线', color='#495057', size='12pt')
# 添加图例
self.curve_plot_widget.addLegend()
# 存储曲线数据
self.curve_data_x = [] # X轴数据(帧序号)
self.curve_data_y = [] # Y轴数据(液位高度)
self.curve_line = None # 曲线对象
curve_layout.addWidget(self.curve_plot_widget)
else:
# PyQtGraph 不可用,显示提示信息
placeholder = QtWidgets.QLabel(
"曲线显示功能需要 PyQtGraph 库\n\n"
"请安装: pip install pyqtgraph"
)
placeholder.setAlignment(Qt.AlignCenter)
placeholder.setStyleSheet("color: #999; font-size: 11pt; padding: 50px;")
FontManager.applyToWidget(placeholder)
curve_layout.addWidget(placeholder)
def _clearCurve(self):
"""清空曲线数据并隐藏曲线面板"""
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget'):
# 清空数据
self.curve_data_x = []
self.curve_data_y = []
# 清空曲线
if self.curve_line:
self.curve_plot_widget.removeItem(self.curve_line)
self.curve_line = None
print("[曲线] 已清空曲线数据")
# 自动隐藏曲线面板,返回到初始显示状态
self.hideCurvePanel()
print("[曲线] 曲线面板已隐藏,返回初始显示状态")
def addCurvePoint(self, frame_index, height_mm):
"""添加曲线数据点
Args:
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_plot_widget'):
return
try:
# 添加数据点
self.curve_data_x.append(frame_index)
self.curve_data_y.append(height_mm)
# 如果曲线不存在,创建曲线
if self.curve_line is None:
self.curve_line = self.curve_plot_widget.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
else:
# 更新曲线数据
self.curve_line.setData(self.curve_data_x, self.curve_data_y)
except Exception as e:
print(f"[曲线] 添加数据点失败: {e}")
def showCurvePanel(self):
"""显示曲线面板"""
if hasattr(self, 'display_layout'):
# 切换到曲线面板(索引3:hint, display_panel, video_panel, curve_panel)
self.display_layout.setCurrentIndex(3)
def hideCurvePanel(self):
"""隐藏曲线面板,返回到显示面板"""
if hasattr(self, 'display_layout'):
self.display_layout.setCurrentIndex(1) # 显示 display_panel
def saveCurveData(self, csv_path):
"""保存曲线数据为CSV文件
Args:
csv_path: CSV文件保存路径
Returns:
bool: 是否成功保存
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_data_x'):
return False
try:
import csv
# 检查是否有数据
if len(self.curve_data_x) == 0:
print("[曲线保存] 没有曲线数据可保存")
return False
# 写入CSV文件
with open(csv_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
# 写入表头
writer.writerow(['帧序号', '液位高度(mm)'])
# 写入数据
for x, y in zip(self.curve_data_x, self.curve_data_y):
writer.writerow([x, y])
print(f"[曲线保存] CSV数据已保存: {csv_path}")
print(f"[曲线保存] 共保存 {len(self.curve_data_x)} 个数据点")
return True
except Exception as e:
print(f"[曲线保存] 保存CSV失败: {e}")
import traceback
traceback.print_exc()
return False
def saveCurveImage(self, image_path):
"""保存曲线为图片文件
Args:
image_path: 图片文件保存路径(支持 .png, .jpg, .svg等格式)
Returns:
bool: 是否成功保存
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_plot_widget'):
return False
try:
# 检查是否有数据
if len(self.curve_data_x) == 0:
print("[曲线保存] 没有曲线数据可保存")
return False
# 使用PyQtGraph的导出功能
exporter = pg.exporters.ImageExporter(self.curve_plot_widget.plotItem)
# 设置导出参数
exporter.parameters()['width'] = 1200 # 设置宽度
exporter.parameters()['height'] = 600 # 设置高度
# 导出图片
exporter.export(image_path)
print(f"[曲线保存] 曲线图片已保存: {image_path}")
return True
except Exception as e:
print(f"[曲线保存] 保存图片失败: {e}")
import traceback
traceback.print_exc()
return False
def _createParametersGroup(self): def _createParametersGroup(self):
"""创建参数配置组""" """创建参数配置组"""
group = QtWidgets.QGroupBox("") group = QtWidgets.QGroupBox("")
...@@ -502,38 +748,53 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -502,38 +748,53 @@ class TrainingPage(QtWidgets.QWidget):
# 创建表单布局 # 创建表单布局
layout = QtWidgets.QFormLayout() layout = QtWidgets.QFormLayout()
from ..responsive_layout import scale_spacing from ..responsive_layout import scale_spacing
layout.setSpacing(scale_spacing(8)) layout.setSpacing(8)
layout.setContentsMargins(0, 0, 0, 0) layout.setContentsMargins(0, 0, 0, 0)
layout.setLabelAlignment(Qt.AlignRight | Qt.AlignVCenter) layout.setLabelAlignment(Qt.AlignRight | Qt.AlignVCenter)
# 🔥 删除CSS样式表,改为纯控件方式,由字体管理器统一管理 # 🔥 删除CSS样式表,改为纯控件方式,由字体管理器统一管理
# 基础模型路径 # 基础模型选择(下拉菜单)
model_layout = QtWidgets.QHBoxLayout() self.base_model_combo = QtWidgets.QComboBox()
from ..responsive_layout import scale_spacing self.base_model_combo.setPlaceholderText("请选择基础模型")
model_layout.setSpacing(scale_spacing(6)) self.base_model_combo.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
self.base_model_edit = QtWidgets.QLineEdit() layout.addRow("基础模型:", self.base_model_combo)
self.base_model_edit.setPlaceholderText("选择基础模型文件 (.pt)")
self.browse_model_btn = QtWidgets.QPushButton("浏览...") # 数据集文件夹选择(统一格式,支持多个文件夹)
self.browse_model_btn.setFixedWidth(scale_w(80)) dataset_layout = QtWidgets.QHBoxLayout()
self.browse_model_btn.clicked.connect(self._browseModel) dataset_layout.setSpacing(8)
model_layout.addWidget(self.base_model_edit, 1) dataset_layout.setContentsMargins(0, 0, 0, 0)
model_layout.addWidget(self.browse_model_btn)
layout.addRow("基础模型:", model_layout) # 数据集路径显示文本框(可编辑,支持多个路径用分号分隔)
self.dataset_paths_edit = QtWidgets.QLineEdit()
# 数据集路径 (字段名必须是 save_liquid_data_path_edit 以匹配训练处理器) self.dataset_paths_edit.setPlaceholderText("")
data_layout = QtWidgets.QHBoxLayout() self.dataset_paths_edit.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
from ..responsive_layout import scale_spacing # 移除自定义样式,使用全局字体管理器统一管理
data_layout.setSpacing(scale_spacing(6)) self.dataset_paths_edit.setStyleSheet("")
# 连接文本变化信号,实时更新内部数据集列表
self.dataset_paths_edit.textChanged.connect(self._onDatasetPathsChanged)
dataset_layout.addWidget(self.dataset_paths_edit)
# 浏览按钮(使用全局样式管理器)
self.btn_browse_datasets = TextButtonStyleManager.createStandardButton(
"浏览...",
parent=self,
slot=self._onBrowseDatasets
)
self.btn_browse_datasets.setFixedWidth(100)
self.btn_browse_datasets.setSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
dataset_layout.addWidget(self.btn_browse_datasets)
layout.addRow("数据集:", dataset_layout)
# 保持内部数据集文件夹列表(隐藏,用于兼容现有逻辑)
self.dataset_folders_list = QtWidgets.QListWidget()
self.dataset_folders_list.setVisible(False) # 隐藏,仅用于内部数据管理
# 保留旧的字段名以保持向后兼容(用于获取数据集路径)
# 现在它将存储用分号分隔的多个文件夹路径
self.save_liquid_data_path_edit = QtWidgets.QLineEdit() self.save_liquid_data_path_edit = QtWidgets.QLineEdit()
self.save_liquid_data_path_edit.setPlaceholderText("选择数据集配置文件 (data.yaml)") self.save_liquid_data_path_edit.setVisible(False) # 隐藏,仅用于数据传递
self.save_liquid_data_path_edit.setText("database/dataset/data.yaml") # 默认路径
self.browse_data_btn = QtWidgets.QPushButton("浏览...")
self.browse_data_btn.setFixedWidth(scale_w(80))
self.browse_data_btn.clicked.connect(self._browseDataset)
data_layout.addWidget(self.save_liquid_data_path_edit, 1)
data_layout.addWidget(self.browse_data_btn)
layout.addRow("数据集:", data_layout)
# 实验名称 # 实验名称
self.exp_name_edit = QtWidgets.QLineEdit() self.exp_name_edit = QtWidgets.QLineEdit()
...@@ -586,6 +847,18 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -586,6 +847,18 @@ class TrainingPage(QtWidgets.QWidget):
self.optimizer_combo.addItems(["SGD", "Adam", "AdamW"]) self.optimizer_combo.addItems(["SGD", "Adam", "AdamW"])
layout.addRow("优化器:", self.optimizer_combo) layout.addRow("优化器:", self.optimizer_combo)
# 训练笔记按钮(使用全局样式管理器)
self.training_notes_btn = TextButtonStyleManager.createStandardButton(
"训练笔记",
parent=self,
slot=self._openNotesDialog
)
self.training_notes_btn.setToolTip("点击打开训练笔记编辑窗口")
layout.addRow("训练笔记:", self.training_notes_btn)
# 内部存储笔记内容的变量
self._training_notes_content = ""
# 高级选项分隔 # 高级选项分隔
separator = QtWidgets.QLabel() separator = QtWidgets.QLabel()
separator.setStyleSheet("border-top: 1px solid #dee2e6; margin: 5px 0;") separator.setStyleSheet("border-top: 1px solid #dee2e6; margin: 5px 0;")
...@@ -646,12 +919,18 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -646,12 +919,18 @@ class TrainingPage(QtWidgets.QWidget):
# 创建别名以匹配训练处理器期望的名称 # 创建别名以匹配训练处理器期望的名称
self.train_status_label = self.status_label self.train_status_label = self.status_label
# 控制按钮(使用Qt默认样式 + 响应式布局) # 控制按钮(使用全局样式管理器)
self.start_train_btn = QtWidgets.QPushButton("开始升级") self.start_train_btn = TextButtonStyleManager.createStandardButton(
"开始升级",
parent=self
)
self.start_train_btn.setMinimumWidth(scale_w(80)) self.start_train_btn.setMinimumWidth(scale_w(80))
control_layout.addWidget(self.start_train_btn) control_layout.addWidget(self.start_train_btn)
self.stop_train_btn = QtWidgets.QPushButton("停止升级") self.stop_train_btn = TextButtonStyleManager.createStandardButton(
"停止升级",
parent=self
)
self.stop_train_btn.setMinimumWidth(scale_w(80)) self.stop_train_btn.setMinimumWidth(scale_w(80))
self.stop_train_btn.setEnabled(False) self.stop_train_btn.setEnabled(False)
control_layout.addWidget(self.stop_train_btn) control_layout.addWidget(self.stop_train_btn)
...@@ -663,7 +942,7 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -663,7 +942,7 @@ class TrainingPage(QtWidgets.QWidget):
# 🔥 应用全局字体管理器到所有文本框和控件 # 🔥 应用全局字体管理器到所有文本框和控件
if FontManager: if FontManager:
# 应用到所有QLineEdit # 应用到所有QLineEdit
FontManager.applyToWidget(self.base_model_edit) FontManager.applyToWidget(self.dataset_paths_edit)
FontManager.applyToWidget(self.save_liquid_data_path_edit) FontManager.applyToWidget(self.save_liquid_data_path_edit)
FontManager.applyToWidget(self.exp_name_edit) FontManager.applyToWidget(self.exp_name_edit)
...@@ -674,6 +953,7 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -674,6 +953,7 @@ class TrainingPage(QtWidgets.QWidget):
FontManager.applyToWidget(self.workers_spin) FontManager.applyToWidget(self.workers_spin)
# 应用到所有QComboBox # 应用到所有QComboBox
FontManager.applyToWidget(self.base_model_combo)
FontManager.applyToWidget(self.device_combo) FontManager.applyToWidget(self.device_combo)
FontManager.applyToWidget(self.optimizer_combo) FontManager.applyToWidget(self.optimizer_combo)
...@@ -686,39 +966,181 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -686,39 +966,181 @@ class TrainingPage(QtWidgets.QWidget):
FontManager.applyToWidget(self.checkbox_template_2) FontManager.applyToWidget(self.checkbox_template_2)
FontManager.applyToWidget(self.checkbox_template_3) FontManager.applyToWidget(self.checkbox_template_3)
# 应用到所有QPushButton # 应用到状态标签(非按钮样式管理器创建的控件)
FontManager.applyToWidget(self.browse_model_btn)
FontManager.applyToWidget(self.browse_data_btn)
FontManager.applyToWidget(self.status_label) FontManager.applyToWidget(self.status_label)
FontManager.applyToWidget(self.start_train_btn)
FontManager.applyToWidget(self.stop_train_btn) # 应用到数据集列表
FontManager.applyToWidget(self.dataset_folders_list)
# 应用到笔记按钮
FontManager.applyToWidget(self.training_notes_btn)
# 应用到整个group(包括标签) # 应用到整个group(包括标签)
FontManager.applyToWidgetRecursive(group) FontManager.applyToWidgetRecursive(group)
return group return group
def _browseModel(self): def _onBrowseDatasets(self):
"""浏览模型文件""" """浏览数据集文件夹(支持多选)"""
file_path, _ = QtWidgets.QFileDialog.getOpenFileName( # 使用简单的单选文件夹对话框,多次选择来实现多选效果
folder_path = QtWidgets.QFileDialog.getExistingDirectory(
self, self,
"选择基础模型", "选择数据集文件夹",
"database/model", "database/dataset" if os.path.exists("database/dataset") else ""
"模型文件 (*.pt *.dat);;所有文件 (*.*)"
) )
if file_path:
self.base_model_edit.setText(file_path)
def _browseDataset(self): if folder_path:
"""浏览数据集文件""" # 获取当前已有的文件夹路径
file_path, _ = QtWidgets.QFileDialog.getOpenFileName( current_paths = self.dataset_paths_edit.text().strip()
existing_folders = [p.strip() for p in current_paths.split(';') if p.strip()] if current_paths else []
# 检查是否已经添加过这个文件夹
if folder_path not in existing_folders:
# 添加新文件夹
if existing_folders:
# 如果已有文件夹,用分号连接
new_paths = current_paths + ';' + folder_path
else:
# 如果是第一个文件夹
new_paths = folder_path
self.dataset_paths_edit.setText(new_paths)
print(f"[TrainingPage] 添加数据集文件夹: {folder_path}")
else:
# 使用style_manager中的对话框管理器显示提示
try:
from ..style_manager import DialogManager
DialogManager.show_information(
self, "提示",
f"文件夹已存在:\n{folder_path}"
)
except ImportError:
QtWidgets.QMessageBox.information(
self, "提示",
f"文件夹已存在:\n{folder_path}"
)
def _onDatasetPathsChanged(self):
"""数据集路径文本变化时的处理"""
# 更新内部数据集列表和隐藏的路径字段
paths_text = self.dataset_paths_edit.text().strip()
folders = [p.strip() for p in paths_text.split(';') if p.strip()]
# 更新内部列表(用于兼容现有逻辑)
self.dataset_folders_list.clear()
for folder in folders:
self.dataset_folders_list.addItem(folder)
# 更新隐藏的路径字段
self.save_liquid_data_path_edit.setText(paths_text)
# 调试信息
print(f"[TrainingPage] 数据集路径更新: {len(folders)} 个文件夹")
for i, folder in enumerate(folders):
print(f" [{i+1}] {folder}")
def _addDatasetFolder(self):
"""添加数据集文件夹(保留兼容性)"""
self._onBrowseDatasets()
def _removeSelectedDatasets(self):
"""删除选中的数据集文件夹"""
selected_items = self.dataset_folders_list.selectedItems()
if not selected_items:
QtWidgets.QMessageBox.information(self, "提示", "请先选择要删除的文件夹")
return
for item in selected_items:
row = self.dataset_folders_list.row(item)
self.dataset_folders_list.takeItem(row)
self._updateDatasetPath()
def _clearAllDatasets(self):
"""清空所有数据集文件夹"""
if self.dataset_folders_list.count() == 0:
return
reply = QtWidgets.QMessageBox.question(
self, self,
"选择数据集配置", "确认清空",
"database/dataset", f"确定要清空所有 {self.dataset_folders_list.count()} 个数据集文件夹吗?",
"YAML文件 (*.yaml *.yml);;所有文件 (*.*)" QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
QtWidgets.QMessageBox.No
) )
if file_path:
self.save_liquid_data_path_edit.setText(file_path) if reply == QtWidgets.QMessageBox.Yes:
self.dataset_folders_list.clear()
self._updateDatasetPath()
def _updateDatasetPath(self):
"""更新隐藏的数据集路径字段(用分号分隔多个文件夹)"""
folders = [self.dataset_folders_list.item(i).text()
for i in range(self.dataset_folders_list.count())]
# 使用分号分隔多个文件夹路径
self.save_liquid_data_path_edit.setText(';'.join(folders))
def getDatasetFolders(self):
"""获取所有数据集文件夹路径列表"""
paths_text = self.dataset_paths_edit.text().strip()
return [p.strip() for p in paths_text.split(';') if p.strip()]
def getTrainingNotes(self):
"""获取训练笔记内容"""
return self._training_notes_content.strip()
def setTrainingNotes(self, notes):
"""设置训练笔记内容"""
self._training_notes_content = notes if notes else ""
self._updateNotesButtonText()
def clearTrainingNotes(self):
"""清空训练笔记(不弹确认框)"""
self._training_notes_content = ""
self._updateNotesButtonText()
def _openNotesDialog(self):
"""打开训练笔记编辑对话框"""
dialog = TrainingNotesDialog(self._training_notes_content, self)
if dialog.exec_() == QtWidgets.QDialog.Accepted:
self._training_notes_content = dialog.getNotesContent()
self._updateNotesButtonText()
def _updateNotesButtonText(self):
"""更新笔记按钮的显示文本"""
if self._training_notes_content.strip():
# 显示笔记的前几个字符
preview = self._training_notes_content.strip()[:8] # 减少字符数以适应按钮宽度
if len(self._training_notes_content.strip()) > 8:
preview += "..."
new_text = f"训练笔记 ({preview})"
# 使用全局样式管理器更新按钮文本和大小
TextButtonStyleManager.updateButtonText(self.training_notes_btn, new_text)
# 添加有内容的视觉提示(保持全局样式基础上的微调)
current_style = self.training_notes_btn.styleSheet()
self.training_notes_btn.setStyleSheet(current_style + """
QPushButton {
background-color: #e3f2fd;
border: 1px solid #2196f3;
}
""")
else:
# 使用全局样式管理器重置按钮
TextButtonStyleManager.updateButtonText(self.training_notes_btn, "训练笔记")
def enableNotesButtons(self):
"""启用笔记按钮(训练完成后调用)"""
# 训练笔记按钮始终可用,无需特殊处理
pass
def disableNotesButtons(self):
"""禁用笔记按钮(训练开始前调用)"""
# 训练笔记按钮始终可用,无需特殊处理
pass
@QtCore.Slot(str) @QtCore.Slot(str)
def appendLog(self, text): def appendLog(self, text):
...@@ -898,10 +1320,11 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -898,10 +1320,11 @@ class TrainingPage(QtWidgets.QWidget):
self.stop_train_btn.setEnabled(is_training) self.stop_train_btn.setEnabled(is_training)
# 禁用参数输入 # 禁用参数输入
self.base_model_edit.setEnabled(not is_training) self.base_model_combo.setEnabled(not is_training)
self.browse_model_btn.setEnabled(not is_training) self.dataset_folders_list.setEnabled(not is_training)
self.save_liquid_data_path_edit.setEnabled(not is_training) self.add_dataset_btn.setEnabled(not is_training)
self.browse_data_btn.setEnabled(not is_training) self.remove_dataset_btn.setEnabled(not is_training)
self.clear_dataset_btn.setEnabled(not is_training)
self.exp_name_edit.setEnabled(not is_training) self.exp_name_edit.setEnabled(not is_training)
self.epochs_spin.setEnabled(not is_training) self.epochs_spin.setEnabled(not is_training)
self.batch_spin.setEnabled(not is_training) self.batch_spin.setEnabled(not is_training)
...@@ -915,8 +1338,59 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -915,8 +1338,59 @@ class TrainingPage(QtWidgets.QWidget):
def showEvent(self, event): def showEvent(self, event):
"""页面显示时刷新模型列表和测试文件列表(确保与模型集管理页面同步)""" """页面显示时刷新模型列表和测试文件列表(确保与模型集管理页面同步)"""
super(TrainingPage, self).showEvent(event) super(TrainingPage, self).showEvent(event)
self._loadBaseModelOptions() # 🔥 加载基础模型列表
self._loadTestModelOptions() self._loadTestModelOptions()
self._loadTestFileList() # 🔥 刷新测试文件列表 # self._loadTestFileList() # 🔥 不再需要刷新测试文件列表(改用浏览方式)
def _loadBaseModelOptions(self):
"""从模型集管理页面加载基础模型选项"""
# 清空现有选项
self.base_model_combo.clear()
try:
# 尝试从父窗口获取模型集页面
if hasattr(self._parent, 'modelSetPage'):
model_set_page = self._parent.modelSetPage
# 获取模型参数字典
if hasattr(model_set_page, '_model_params'):
model_params = model_set_page._model_params
if not model_params:
self.base_model_combo.addItem("未找到模型", None)
return
# 获取默认模型
default_model = None
if hasattr(model_set_page, '_current_default_model'):
default_model = model_set_page._current_default_model
# 添加所有模型到下拉框
default_index = 0
for idx, (model_name, params) in enumerate(model_params.items()):
model_path = params.get('path', '')
# 构建显示名称
display_name = model_name
if model_name == default_model:
display_name = f"{model_name} (默认)"
default_index = idx
# 添加到下拉框,使用模型路径作为数据
self.base_model_combo.addItem(display_name, model_path)
# 设置默认选择
self.base_model_combo.setCurrentIndex(default_index)
else:
self.base_model_combo.addItem("模型集页面未初始化", None)
else:
self.base_model_combo.addItem("未找到模型集页面", None)
except Exception as e:
print(f"[基础模型] 加载失败: {e}")
import traceback
traceback.print_exc()
self.base_model_combo.addItem("加载失败", None)
def _loadTestModelOptions(self): def _loadTestModelOptions(self):
"""加载测试模型选项(从 train_model 文件夹读取)""" """加载测试模型选项(从 train_model 文件夹读取)"""
...@@ -1118,92 +1592,11 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -1118,92 +1592,11 @@ class TrainingPage(QtWidgets.QWidget):
return models return models
def _loadTestFileList(self): def _loadTestFileList(self):
"""加载测试文件列表到下拉框""" """加载测试文件列表(保留方法以保持向后兼容,但现在使用浏览方式)"""
try: # 该方法现已改为使用浏览按钮选择文件,不再从固定目录加载
import os # 保留此方法以防其他代码调用
from pathlib import Path
# 🔥 改进:使用多种方式获取project_root
project_root = None
try:
from ...database.config import get_project_root
project_root = get_project_root()
# print(f"[测试文件] 使用相对导入获取project_root: {project_root}")
except (ImportError, ValueError):
# 相对导入失败是预期行为(作为主程序运行时),静默处理
try:
from database.config import get_project_root
project_root = get_project_root()
except (ImportError, ValueError):
# 备选方案:使用当前文件的父目录
current_file = Path(__file__).resolve()
project_root = current_file.parent.parent.parent
if not project_root:
raise RuntimeError("无法获取project_root")
test_file_dir = os.path.join(str(project_root), 'database', 'model_test_file')
# 清空下拉框
self.test_file_input.clear()
if not os.path.exists(test_file_dir):
self.test_file_input.addItem("(目录不存在)")
return
# 扫描目录
items = []
try:
dir_items = os.listdir(test_file_dir)
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv']
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
for item in dir_items:
item_path = os.path.join(test_file_dir, item)
# 检查是否为视频文件
if os.path.isfile(item_path):
if any(item.lower().endswith(ext) for ext in video_extensions):
items.append((item, item_path))
# 检查是否为文件夹
elif os.path.isdir(item_path):
# 检查文件夹内是否有图片或视频
has_images = False
has_videos = False
try:
folder_items = os.listdir(item_path)
for file in folder_items:
file_lower = file.lower()
if any(file_lower.endswith(ext) for ext in image_extensions):
has_images = True
break
elif any(file_lower.endswith(ext) for ext in video_extensions):
has_videos = True
except Exception as folder_e:
pass pass
# 🔥 改进:文件夹包含图片或视频都添加
if has_images or has_videos:
items.append((f"📁 {item}", item_path))
except Exception as e:
import traceback
traceback.print_exc()
# 添加到下拉框
if items:
for display_name, full_path in sorted(items):
self.test_file_input.addItem(display_name, full_path)
else:
self.test_file_input.addItem("(未找到测试文件)")
except Exception as e:
import traceback
traceback.print_exc()
self.test_file_input.addItem("(加载失败)")
def _getRememberedTestModel(self, project_root): def _getRememberedTestModel(self, project_root):
"""从配置文件获取记忆的测试模型路径""" """从配置文件获取记忆的测试模型路径"""
try: try:
...@@ -1254,8 +1647,8 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -1254,8 +1647,8 @@ class TrainingPage(QtWidgets.QWidget):
def getTestFilePath(self): def getTestFilePath(self):
"""获取选中的测试文件路径""" """获取选中的测试文件路径"""
# 🔥 改为从QComboBox获取数据 # 从 QLineEdit 获取文件路径
return self.test_file_input.currentData() or "" return self.test_file_input.text().strip()
def isTestingInProgress(self): def isTestingInProgress(self):
"""检查是否正在测试中""" """检查是否正在测试中"""
...@@ -1267,27 +1660,14 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -1267,27 +1660,14 @@ class TrainingPage(QtWidgets.QWidget):
if is_testing: if is_testing:
# 切换为"停止测试"状态 # 切换为"停止测试"状态
self.start_test_btn.setText("停止测试") TextButtonStyleManager.updateButtonText(self.start_test_btn, "停止测试")
# 使用红色样式表示危险操作 # 使用全局样式管理器的危险按钮样式
self.start_test_btn.setStyleSheet(""" TextButtonStyleManager.applyDangerStyle(self.start_test_btn)
QPushButton {
background-color: #dc3545;
color: white;
border: none;
padding: 8px 16px;
border-radius: 4px;
font-weight: bold;
min-width: 100px;
}
QPushButton:hover {
background-color: #c82333;
}
""")
else: else:
# 切换为"开始测试"状态 # 切换为"开始测试"状态
self.start_test_btn.setText("开始测试") TextButtonStyleManager.updateButtonText(self.start_test_btn, "开始测试")
# 恢复默认样式 # 恢复标准样式
self.start_test_btn.setStyleSheet("") TextButtonStyleManager.applyStandardStyle(self.start_test_btn)
def _onTemplateChecked(self, button): def _onTemplateChecked(self, button):
"""处理模板复选框选中事件""" """处理模板复选框选中事件"""
...@@ -1326,8 +1706,14 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -1326,8 +1706,14 @@ class TrainingPage(QtWidgets.QWidget):
try: try:
# 基础模型 # 基础模型
if 'model' in config: if 'model' in config:
self.base_model_edit.setText(str(config['model'])) model_path = str(config['model'])
print(f"[模板] 设置基础模型: {config['model']}") # 在下拉菜单中查找匹配的模型
for i in range(self.base_model_combo.count()):
item_data = self.base_model_combo.itemData(i)
if item_data and item_data == model_path:
self.base_model_combo.setCurrentIndex(i)
print(f"[模板] 设置基础模型: {model_path}")
break
# 数据集配置 # 数据集配置
if 'data' in config: if 'data' in config:
...@@ -1392,6 +1778,134 @@ class TrainingPage(QtWidgets.QWidget): ...@@ -1392,6 +1778,134 @@ class TrainingPage(QtWidgets.QWidget):
import traceback import traceback
traceback.print_exc() traceback.print_exc()
def _browseTestFile(self):
"""浏览选择测试文件(支持图片和视频)"""
try:
# 定义支持的文件类型
image_formats = "图片文件 (*.jpg *.jpeg *.png *.bmp *.tiff *.webp)"
video_formats = "视频文件 (*.mp4 *.avi *.mov *.mkv *.flv *.wmv)"
all_formats = "所有支持的文件 (*.jpg *.jpeg *.png *.bmp *.tiff *.webp *.mp4 *.avi *.mov *.mkv *.flv *.wmv)"
# 构建文件过滤器
file_filter = f"{all_formats};;{image_formats};;{video_formats};;所有文件 (*.*)"
# 打开文件选择对话框
file_path, _ = QtWidgets.QFileDialog.getOpenFileName(
self,
"选择测试图片或视频文件",
"", # 默认目录为空,使用系统默认
file_filter
)
# 如果用户选择了文件,则设置到输入框
if file_path:
self.test_file_input.setText(file_path)
print(f"[测试文件] 已选择: {file_path}")
except Exception as e:
import traceback
traceback.print_exc()
QtWidgets.QMessageBox.warning(
self,
"文件选择失败",
f"选择测试文件时发生错误:\n{str(e)}"
)
def _onViewCurveClicked(self):
"""查看曲线按钮点击处理"""
try:
# 检查是否有曲线数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
self._showNoCurveMessage()
return
if len(self.curve_data_x) == 0:
self._showNoCurveMessage()
return
# 切换到曲线面板显示
self.showCurvePanel()
# 显示曲线信息提示
data_count = len(self.curve_data_x)
if data_count == 1:
# 图片测试
liquid_level = self.curve_data_y[0]
QtWidgets.QMessageBox.information(
self,
"曲线信息",
f"图片测试结果:\n液位高度: {liquid_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
)
else:
# 视频测试
min_level = min(self.curve_data_y)
max_level = max(self.curve_data_y)
avg_level = sum(self.curve_data_y) / len(self.curve_data_y)
QtWidgets.QMessageBox.information(
self,
"曲线信息",
f"视频测试结果:\n"
f"数据点数: {data_count} 个\n"
f"液位范围: {min_level:.1f} - {max_level:.1f} mm\n"
f"平均液位: {avg_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
)
except Exception as e:
print(f"[查看曲线] 显示曲线失败: {e}")
QtWidgets.QMessageBox.warning(
self,
"显示失败",
f"显示曲线时发生错误:\n{str(e)}"
)
def _showNoCurveMessage(self):
"""显示无曲线数据的提示"""
QtWidgets.QMessageBox.information(
self,
"无曲线数据",
"当前没有可显示的曲线数据。\n\n"
"请先进行模型测试:\n"
"1. 选择测试模型\n"
"2. 选择测试文件\n"
"3. 点击\"开始标注\"\n"
"4. 点击\"开始测试\"\n\n"
"测试完成后即可查看曲线结果。"
)
def enableViewCurveButton(self):
"""启用查看曲线按钮(测试完成后调用)"""
try:
self.view_curve_btn.setEnabled(True)
self.view_curve_btn.setToolTip("点击查看测试结果曲线")
# 检查曲线数据类型并更新按钮文本
if hasattr(self, 'curve_data_x') and len(self.curve_data_x) > 0:
data_count = len(self.curve_data_x)
if data_count == 1:
self.view_curve_btn.setText("查看曲线(图)")
else:
self.view_curve_btn.setText("查看曲线(视)")
else:
self.view_curve_btn.setText("查看曲线")
print(f"[查看曲线] 按钮已启用,数据点数: {len(self.curve_data_x) if hasattr(self, 'curve_data_x') else 0}")
except Exception as e:
print(f"[查看曲线] 启用按钮失败: {e}")
def disableViewCurveButton(self):
"""禁用查看曲线按钮(测试开始前调用)"""
try:
self.view_curve_btn.setEnabled(False)
self.view_curve_btn.setText("查看曲线")
self.view_curve_btn.setToolTip("测试完成后可查看曲线结果")
print(f"[查看曲线] 按钮已禁用")
except Exception as e:
print(f"[查看曲线] 禁用按钮失败: {e}")
def getTemplateConfig(self): def getTemplateConfig(self):
"""获取当前选中的模板配置名称""" """获取当前选中的模板配置名称"""
if self.checkbox_template_1.isChecked(): if self.checkbox_template_1.isChecked():
...@@ -1414,3 +1928,148 @@ if __name__ == '__main__': ...@@ -1414,3 +1928,148 @@ if __name__ == '__main__':
page.show() page.show()
sys.exit(app.exec_()) sys.exit(app.exec_())
class TrainingNotesDialog(QtWidgets.QDialog):
"""训练笔记编辑对话框"""
def __init__(self, initial_content="", parent=None):
super().__init__(parent)
self.setWindowTitle("训练笔记编辑")
self.setMinimumSize(500, 400)
self.resize(600, 500)
# 设置窗口图标(如果有的话)
self.setWindowFlags(self.windowFlags() & ~QtCore.Qt.WindowContextHelpButtonHint)
# 应用全局样式管理器
from ..style_manager import FontManager, BackgroundStyleManager
BackgroundStyleManager.applyToWidget(self)
self._setupUI()
self.text_edit.setPlainText(initial_content)
# 应用全局字体管理器到整个对话框
FontManager.applyToDialog(self)
def _setupUI(self):
"""设置UI界面"""
layout = QtWidgets.QVBoxLayout(self)
layout.setSpacing(10)
layout.setContentsMargins(15, 15, 15, 15)
# 标题标签(使用全局字体管理器)
from ..style_manager import FontManager
title_label = QtWidgets.QLabel("训练笔记")
title_label.setFont(FontManager.getTitleFont())
title_label.setStyleSheet("""
QLabel {
color: #333;
margin-bottom: 5px;
}
""")
layout.addWidget(title_label)
# 说明文字(使用全局字体管理器)
info_label = QtWidgets.QLabel("在此记录本次训练的相关信息,如数据集变化、参数调整原因、预期效果等...")
info_label.setFont(FontManager.getSmallFont())
info_label.setStyleSheet("""
QLabel {
color: #666;
margin-bottom: 10px;
}
""")
info_label.setWordWrap(True)
layout.addWidget(info_label)
# 文本编辑区域(使用全局字体管理器)
self.text_edit = QtWidgets.QTextEdit()
self.text_edit.setFont(FontManager.getMediumFont())
self.text_edit.setStyleSheet("""
QTextEdit {
border: 1px solid #ccc;
background-color: white;
padding: 10px;
line-height: 1.4;
}
""")
self.text_edit.setPlaceholderText(
"示例内容:\n"
"• 数据集:新增了100张液位图片\n"
"• 参数调整:学习率从0.01调整为0.005\n"
"• 预期效果:提高小液位目标的检测精度\n"
"• 其他备注:..."
)
layout.addWidget(self.text_edit)
# 字符计数标签(使用全局字体管理器)
self.char_count_label = QtWidgets.QLabel("字符数: 0")
self.char_count_label.setFont(FontManager.getSmallFont())
self.char_count_label.setStyleSheet("color: #666;")
self.char_count_label.setAlignment(QtCore.Qt.AlignRight)
layout.addWidget(self.char_count_label)
# 按钮区域(使用全局样式管理器)
button_layout = QtWidgets.QHBoxLayout()
button_layout.addStretch()
# 清空按钮
clear_btn = TextButtonStyleManager.createStandardButton("清空", self, self._clearText)
button_layout.addWidget(clear_btn)
# 取消按钮
cancel_btn = TextButtonStyleManager.createStandardButton("取消", self, self.reject)
button_layout.addWidget(cancel_btn)
# 保存按钮(使用全局样式管理器创建,然后添加特殊样式)
save_btn = TextButtonStyleManager.createStandardButton("保存", self, self.accept)
save_btn.setDefault(True)
save_btn.setStyleSheet("""
QPushButton {
background-color: #2196f3;
color: white;
border: none;
padding: 8px;
font-weight: bold;
}
QPushButton:hover {
background-color: #1976d2;
}
""")
button_layout.addWidget(save_btn)
layout.addLayout(button_layout)
# 连接信号
self.text_edit.textChanged.connect(self._updateCharCount)
self._updateCharCount()
def _clearText(self):
"""清空文本"""
if self.text_edit.toPlainText().strip():
from ..style_manager import DialogManager
if DialogManager.show_question_warning(
self,
"确认清空",
"确定要清空所有笔记内容吗?",
"是", "否"
):
self.text_edit.clear()
def _updateCharCount(self):
"""更新字符计数"""
text = self.text_edit.toPlainText()
char_count = len(text)
self.char_count_label.setText(f"字符数: {char_count}")
# 字符数过多时显示警告颜色(保持全局字体设置)
if char_count > 1000:
self.char_count_label.setStyleSheet("color: #f44336;")
elif char_count > 500:
self.char_count_label.setStyleSheet("color: #ff9800;")
else:
self.char_count_label.setStyleSheet("color: #666;")
def getNotesContent(self):
"""获取笔记内容"""
return self.text_edit.toPlainText().strip()
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
根据屏幕分辨率自动调整UI尺寸 根据屏幕分辨率自动调整UI尺寸
""" """
from PyQt5 import QtWidgets, QtCore from qtpy import QtWidgets, QtCore
class ResponsiveLayout: class ResponsiveLayout:
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
统一样式管理器 统一样式管理器
...@@ -40,7 +40,7 @@ import os.path as osp ...@@ -40,7 +40,7 @@ import os.path as osp
import sys import sys
try: try:
from PyQt5 import QtGui, QtWidgets, QtCore from qtpy import QtGui, QtWidgets, QtCore
except ImportError as e: except ImportError as e:
raise raise
...@@ -78,12 +78,11 @@ class FontManager: ...@@ -78,12 +78,11 @@ class FontManager:
@staticmethod @staticmethod
def getDefaultFont(): def getDefaultFont():
"""获取默认字体""" """获取默认字体"""
font = QtGui.QFont( return FontManager.getFont(
FontManager.DEFAULT_FONT_FAMILY, size=FontManager.DEFAULT_FONT_SIZE,
FontManager.DEFAULT_FONT_SIZE weight=FontManager.DEFAULT_FONT_WEIGHT,
family=FontManager.DEFAULT_FONT_FAMILY
) )
font.setWeight(FontManager.DEFAULT_FONT_WEIGHT)
return font
@staticmethod @staticmethod
def getFont(size=None, weight=None, italic=False, underline=False, family=None): def getFont(size=None, weight=None, italic=False, underline=False, family=None):
...@@ -96,7 +95,31 @@ class FontManager: ...@@ -96,7 +95,31 @@ class FontManager:
weight = FontManager.DEFAULT_FONT_WEIGHT weight = FontManager.DEFAULT_FONT_WEIGHT
font = QtGui.QFont(family, size) font = QtGui.QFont(family, size)
# PySide6兼容性:将整数权重转换为QFont.Weight枚举
try:
# 尝试使用PySide6的Weight枚举
if hasattr(QtGui.QFont, 'Weight'):
if weight <= 25:
font.setWeight(QtGui.QFont.Weight.Light)
elif weight <= 50:
font.setWeight(QtGui.QFont.Weight.Normal)
elif weight <= 63:
font.setWeight(QtGui.QFont.Weight.DemiBold)
elif weight <= 75:
font.setWeight(QtGui.QFont.Weight.Bold)
else:
font.setWeight(QtGui.QFont.Weight.Black)
else:
# 回退到旧的整数方式(PyQt5)
font.setWeight(weight) font.setWeight(weight)
except (AttributeError, TypeError):
# 如果出现任何错误,使用默认权重
try:
font.setWeight(QtGui.QFont.Weight.Normal)
except:
pass
font.setItalic(italic) font.setItalic(italic)
font.setUnderline(underline) font.setUnderline(underline)
return font return font
...@@ -157,8 +180,17 @@ class FontManager: ...@@ -157,8 +180,17 @@ class FontManager:
@staticmethod @staticmethod
def applyToApplication(app): def applyToApplication(app):
"""应用默认字体到整个应用程序""" """应用默认字体到整个应用程序"""
font = FontManager.getDefaultFont() try:
# 确保使用当前 Qt 后端创建字体对象
from qtpy import QtGui
font = QtGui.QFont(
FontManager.DEFAULT_FONT_FAMILY,
FontManager.DEFAULT_FONT_SIZE
)
font.setWeight(FontManager.DEFAULT_FONT_WEIGHT)
app.setFont(font) app.setFont(font)
except Exception as e:
print(f"字体设置失败,使用默认字体: {e}")
@staticmethod @staticmethod
def applyToWidgetRecursive(widget, size=None, weight=None): def applyToWidgetRecursive(widget, size=None, weight=None):
...@@ -827,7 +859,7 @@ class TextButtonStyleManager: ...@@ -827,7 +859,7 @@ class TextButtonStyleManager:
# 计算文本宽度:字符数 * 每字符宽度 + 内边距 # 计算文本宽度:字符数 * 每字符宽度 + 内边距
text_width = len(text) * cls.CHAR_WIDTH + cls.MIN_PADDING text_width = len(text) * cls.CHAR_WIDTH + cls.MIN_PADDING
# 返回基础宽度和计算宽度的较大值
return max(cls.BASE_WIDTH, text_width) return max(cls.BASE_WIDTH, text_width)
@classmethod @classmethod
...@@ -850,7 +882,14 @@ class TextButtonStyleManager: ...@@ -850,7 +882,14 @@ class TextButtonStyleManager:
@classmethod @classmethod
def createStandardButton(cls, text, parent=None, slot=None): def createStandardButton(cls, text, parent=None, slot=None):
"""创建标准样式的文本按钮""" """创建标准样式的文本按钮"""
button = QtWidgets.QPushButton(text, parent) # 修复 PySide6 兼容性问题 - 分步创建
try:
button = QtWidgets.QPushButton(text)
if parent is not None:
button.setParent(parent)
except Exception as e:
print(f"按钮创建失败: {e}")
button = QtWidgets.QPushButton("按钮")
# 应用标准样式 # 应用标准样式
cls.applyToButton(button, text) cls.applyToButton(button, text)
...@@ -867,6 +906,37 @@ class TextButtonStyleManager: ...@@ -867,6 +906,37 @@ class TextButtonStyleManager:
button.setText(new_text) button.setText(new_text)
cls.applyToButton(button, new_text) cls.applyToButton(button, new_text)
@classmethod
def applyDangerStyle(cls, button):
"""应用危险按钮样式(红色)"""
button.setStyleSheet("""
QPushButton {
background-color: #dc3545;
color: white;
border: none;
padding: 8px 16px;
border-radius: 4px;
font-weight: bold;
min-width: 100px;
}
QPushButton:hover {
background-color: #c82333;
}
QPushButton:pressed {
background-color: #bd2130;
}
QPushButton:disabled {
background-color: #6c757d;
color: #ffffff;
}
""")
@classmethod
def applyStandardStyle(cls, button):
"""应用标准按钮样式"""
# 重新应用标准样式
cls.applyToButton(button)
class BackgroundStyleManager: class BackgroundStyleManager:
"""全局背景颜色管理器""" """全局背景颜色管理器"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment