Commit da148b05 by yhb

123456

parent 7cf1afc7
......@@ -29,10 +29,11 @@ print("[环境变量] OpenMP冲突已修复")
from qtpy import QtWidgets
from app import MainWindow
# 延迟导入,避免在 QApplication 创建前创建 QWidget
# from app import MainWindow
from database.config import get_config, get_temp_models_dir
from widgets.style_manager import FontManager
from widgets.responsive_layout import ResponsiveLayout
# from widgets.style_manager import FontManager
# from widgets.responsive_layout import ResponsiveLayout
def setup_logging(level: str = "info"):
......@@ -270,6 +271,11 @@ def _main():
app.setApplicationName('Detection')
app.setOrganizationName('Detection')
# 在 QApplication 创建后导入可能创建 QWidget 的模块
from app import MainWindow
from widgets.style_manager import FontManager
from widgets.responsive_layout import ResponsiveLayout
# 初始化响应式布局系统
ResponsiveLayout.initialize(app)
......
......@@ -686,6 +686,10 @@ class MainWindow(
from widgets import ChannelPanel, MissionPanel, CurvePanel
# 主页面容器
# 修复:确保 QApplication 完全初始化后再创建 QWidget
app = QtWidgets.QApplication.instance()
if app is None:
raise RuntimeError("QApplication 未初始化")
page = QtWidgets.QWidget()
page_layout = QtWidgets.QVBoxLayout(page)
page_layout.setContentsMargins(0, 0, 0, 0)
......@@ -714,6 +718,10 @@ class MainWindow(
except ImportError:
from widgets import ChannelPanel, MissionPanel
# 确保 QApplication 存在
app = QtWidgets.QApplication.instance()
if app is None:
raise RuntimeError("QApplication 未初始化")
layout_widget = QtWidgets.QWidget()
main_layout = QtWidgets.QHBoxLayout(layout_widget)
main_layout.setContentsMargins(10, 10, 10, 10)
......
......@@ -283,15 +283,22 @@ class ModelTestHandler:
def _handleStartTestExecution(self):
"""执行开始测试操作 - 液位检测测试功能"""
try:
# 🔥 清空曲线数据,准备新的测试
self._clearCurve()
# 禁用查看曲线按钮
if hasattr(self.training_panel, 'disableViewCurveButton'):
self.training_panel.disableViewCurveButton()
# 切换按钮状态为“停止测试”
self.training_panel.setTestButtonState(True)
# 获取选择的测试模型和测试文件
test_model_display = self.training_panel.test_model_combo.currentText()
test_model_path_raw = self.training_panel.test_model_combo.currentData()
# 改为从QComboBox获取数据
test_file_path_raw = self.training_panel.test_file_input.currentData() or ""
test_file_display = self.training_panel.test_file_input.currentText()
# 从QLineEdit获取测试文件路径(浏览选择的文件)
test_file_path_raw = self.training_panel.test_file_input.text().strip()
test_file_display = os.path.basename(test_file_path_raw) if test_file_path_raw else ""
# 关键修复:路径规范化处理,确保相对路径转换为绝对路径
project_root = get_project_root()
......@@ -342,7 +349,7 @@ class ModelTestHandler:
<p style="margin: 0; font-size: 12px; color: #ffffff;"><strong>解决方法:</strong></p>
<ul style="margin: 5px 0; padding-left: 20px; font-size: 12px; color: #ffffff;">
<li>请在上方下拉框中选择测试模型</li>
<li>请在上方下拉框中选择测试文件</li>
<li>请点击"浏览..."按钮选择测试图片或视频文件</li>
<li>确保选择的文件存在且可访问</li>
</ul>
</div>
......@@ -353,7 +360,7 @@ class ModelTestHandler:
QtWidgets.QMessageBox.warning(
self.training_panel,
"参数缺失",
error_msg + "请在上方下拉框中选择测试模型和测试文件"
error_msg + '请在上方下拉框中选择测试模型,并点击"浏览..."按钮选择测试文件'
)
return
......@@ -541,11 +548,24 @@ class ModelTestHandler:
if detection_result:
# 显示检测结果
self._showDetectionResult(detection_result)
# 🔥 添加曲线数据点
if 'liquid_level_mm' in detection_result:
liquid_level = detection_result['liquid_level_mm']
# 图片测试只有一个数据点,帧序号为0
self._addCurveDataPoint(0, liquid_level)
# 显示曲线面板
if hasattr(self.training_panel, 'showCurvePanel'):
self.training_panel.showCurvePanel()
# 启用查看曲线按钮
if hasattr(self.training_panel, 'enableViewCurveButton'):
self.training_panel.enableViewCurveButton()
QtWidgets.QMessageBox.information(
self.training_panel,
"测试完成",
"模型测试已成功完成!"
"模型测试已成功完成!可查看曲线分析结果。"
)
# 恢复按钮状态
......@@ -585,6 +605,29 @@ class ModelTestHandler:
self._test_thread.wait()
self._test_thread = None
def _addCurveDataPoint(self, frame_index, height_mm):
"""添加曲线数据点
Args:
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
try:
if hasattr(self.training_panel, 'addCurvePoint'):
self.training_panel.addCurvePoint(frame_index, height_mm)
print(f"[曲线] 添加数据点: 帧{frame_index}, 液位{height_mm:.1f}mm")
except Exception as e:
print(f"[曲线] 添加数据点失败: {e}")
def _clearCurve(self):
"""清空曲线数据"""
try:
if hasattr(self.training_panel, '_clearCurve'):
self.training_panel._clearCurve()
print(f"[曲线] 已清空曲线")
except Exception as e:
print(f"[曲线] 清空曲线失败: {e}")
def _showDetectionResult(self, detection_result):
"""显示检测结果"""
try:
......@@ -592,12 +635,16 @@ class ModelTestHandler:
# 这里可以添加结果显示逻辑
# 例如在display_panel中显示检测结果
if hasattr(self.training_panel, 'display_panel') and detection_result:
# 提取液位高度
liquid_level = detection_result.get('liquid_level_mm', 0)
result_html = f"""
<div style="padding: 15px; background: #000000; border: 1px solid #28a745; border-radius: 5px; color: #ffffff;">
<h3 style="margin-top: 0; color: #28a745;">液位检测测试成功</h3>
<p style="color: #ffffff;"><strong>检测结果:</strong> 已完成液位检测</p>
<p style="color: #ffffff;"><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
<div style="margin-top: 15px; padding: 10px; background: #1a1a1a; border-radius: 3px;">
<p style="margin: 0; font-size: 12px; color: #ffffff;">检测结果已完成,可以在结果面板中查看详细信息。</p>
<p style="margin: 0; font-size: 12px; color: #ffffff;">检测结果已完成,可以在曲线面板中查看详细分析。</p>
</div>
</div>
"""
......@@ -1197,6 +1244,9 @@ class ModelTestHandler:
success_count = 0
fail_count = 0
# 🔥 清空曲线数据,准备添加新的视频检测曲线
self._clearCurve()
# 关闭进度对话框
if progress_dialog:
progress_dialog.setLabelText("正在检测中...")
......@@ -1245,6 +1295,13 @@ class ModelTestHandler:
last_detection_result = detection_result
detection_count += 1
success_count += 1
# 🔥 添加曲线数据点(取第一个区域的液位高度)
if detection_result and len(detection_result) > 0:
first_area_result = detection_result[0]
if 'liquid_level_mm' in first_area_result:
liquid_level = first_area_result['liquid_level_mm']
self._addCurveDataPoint(frame_index, liquid_level)
except Exception as e:
print(f"[视频检测] 第 {frame_index} 帧检测失败: {e}")
fail_count += 1
......@@ -1317,6 +1374,15 @@ class ModelTestHandler:
if not self._detection_stopped:
print(f"[视频检测] 显示检测结果视频...")
self._showDetectionVideo(output_video_path, frame_index, detection_count, success_count, fail_count)
# 🔥 显示曲线面板
if hasattr(self.training_panel, 'showCurvePanel'):
self.training_panel.showCurvePanel()
print(f"[曲线] 曲线面板已显示,共{len(self.training_panel.curve_data_x if hasattr(self.training_panel, 'curve_data_x') else [])}个数据点")
# 启用查看曲线按钮
if hasattr(self.training_panel, 'enableViewCurveButton'):
self.training_panel.enableViewCurveButton()
else:
print(f"[视频检测] 检测被用户停止")
......@@ -1844,6 +1910,9 @@ class ModelTestHandler:
"生成文件:",
f" 结果视频: {result_video_filename}",
f" 测试报告: {report_filename}",
f" JSON结果: {json_filename}",
f" 曲线数据: {file_prefix}_curve.csv",
f" 曲线图片: {file_prefix}_curve.png",
"",
"=" * 60,
]
......@@ -1876,7 +1945,9 @@ class ModelTestHandler:
"files": {
"result_video": result_video_filename,
"report": report_filename,
"json_result": json_filename
"json_result": json_filename,
"curve_data_csv": f"{file_prefix}_curve.csv",
"curve_image_png": f"{file_prefix}_curve.png"
}
}
......@@ -1885,6 +1956,23 @@ class ModelTestHandler:
print(f"[保存视频结果] JSON结果已保存: {json_path}")
# 4. 保存曲线数据(CSV格式)和曲线图片
try:
if hasattr(self.training_panel, 'saveCurveData') and hasattr(self.training_panel, 'saveCurveImage'):
# 保存曲线CSV数据
curve_csv_filename = f"{file_prefix}_curve.csv"
curve_csv_path = os.path.join(test_results_dir, curve_csv_filename)
if self.training_panel.saveCurveData(curve_csv_path):
print(f"[保存视频结果] 曲线CSV数据已保存: {curve_csv_path}")
# 保存曲线图片
curve_image_filename = f"{file_prefix}_curve.png"
curve_image_path = os.path.join(test_results_dir, curve_image_filename)
if self.training_panel.saveCurveImage(curve_image_path):
print(f"[保存视频结果] 曲线图片已保存: {curve_image_path}")
except Exception as curve_error:
print(f"[保存视频结果] ⚠️ 曲线保存失败(非致命错误): {curve_error}")
print(f"[保存视频结果] ✅ 所有测试结果已成功保存到: {test_results_dir}")
except Exception as e:
......@@ -1991,6 +2079,9 @@ class ModelTestHandler:
f" 原始图像: {original_filename}",
f" 检测结果: {result_filename}",
f" 测试报告: {report_filename}",
f" JSON结果: {json_filename}",
f" 曲线数据: {file_prefix}_curve.csv",
f" 曲线图片: {file_prefix}_curve.png",
"",
"=" * 60,
]
......@@ -2026,7 +2117,9 @@ class ModelTestHandler:
"original_image": original_filename,
"result_image": result_filename,
"report": report_filename,
"json_result": json_filename
"json_result": json_filename,
"curve_data_csv": f"{file_prefix}_curve.csv",
"curve_image_png": f"{file_prefix}_curve.png"
}
}
......@@ -2035,6 +2128,23 @@ class ModelTestHandler:
print(f"[保存图片结果] JSON结果已保存: {json_path}")
# 5. 保存曲线数据(CSV格式)和曲线图片
try:
if hasattr(self.training_panel, 'saveCurveData') and hasattr(self.training_panel, 'saveCurveImage'):
# 保存曲线CSV数据
curve_csv_filename = f"{file_prefix}_curve.csv"
curve_csv_path = os.path.join(test_results_dir, curve_csv_filename)
if self.training_panel.saveCurveData(curve_csv_path):
print(f"[保存图片结果] 曲线CSV数据已保存: {curve_csv_path}")
# 保存曲线图片
curve_image_filename = f"{file_prefix}_curve.png"
curve_image_path = os.path.join(test_results_dir, curve_image_filename)
if self.training_panel.saveCurveImage(curve_image_path):
print(f"[保存图片结果] 曲线图片已保存: {curve_image_path}")
except Exception as curve_error:
print(f"[保存图片结果] ⚠️ 曲线保存失败(非致命错误): {curve_error}")
print(f"[保存图片结果] ✅ 所有测试结果已成功保存到: {test_results_dir}")
except Exception as e:
......
......@@ -165,7 +165,7 @@ class ModelTrainingHandler(ModelTestHandler):
return False
if not training_params.get('save_liquid_data_path'):
QtWidgets.QMessageBox.critical(self.main_window, "参数错误", "未找到可用的数据集配置文件")
QtWidgets.QMessageBox.critical(self.main_window, "参数错误", "未选择数据集文件夹,请至少添加一个数据集文件夹")
return False
if not training_params.get('exp_name'):
......@@ -208,25 +208,44 @@ class ModelTrainingHandler(ModelTestHandler):
f"基础模型文件不存在\n文件路径: {base_model}\n请检查文件路径是否正确,或重新选择模型文件。")
return False
if not os.path.exists(save_liquid_data_path):
QtWidgets.QMessageBox.critical(self.main_window, "文件错误",
f"数据集配置文件不存在\n文件路径: {save_liquid_data_path}\n请检查文件路径是否正确,或重新选择数据集文件。")
# 解析数据集文件夹列表(用分号分隔)
dataset_folders = [f.strip() for f in save_liquid_data_path.split(';') if f.strip()]
if not dataset_folders:
QtWidgets.QMessageBox.critical(self.main_window, "参数错误", "未选择数据集文件夹,请至少添加一个数据集文件夹")
return False
# 验证每个数据集文件夹是否存在
invalid_folders = []
for folder in dataset_folders:
if not os.path.exists(folder):
invalid_folders.append(folder)
if invalid_folders:
QtWidgets.QMessageBox.critical(
self.main_window,
"文件夹错误",
f"以下数据集文件夹不存在:\n\n" + "\n".join(invalid_folders) +
"\n\n请检查文件夹路径是否正确,或重新选择数据集文件夹。"
)
return False
# 验证数据集配置和内容
validation_result, validation_msg = self._validateTrainingDataWithDetails(save_liquid_data_path)
# 验证数据集文件夹内容
validation_result, validation_msg = self._validateDatasetFolders(dataset_folders)
if not validation_result:
QtWidgets.QMessageBox.critical(
self.main_window,
"数据集验证失败",
f"数据集验证失败:\n\n{validation_msg}\n\n请检查数据集配置和文件。"
f"数据集验证失败:\n\n{validation_msg}\n\n请检查数据集文件夹内容。"
)
return False
# 确认对话框
confirm_msg = f"确定要开始升级模型吗?\n\n"
confirm_msg += f"基础模型: {os.path.basename(base_model)}\n"
confirm_msg += f"数据集: {os.path.basename(save_liquid_data_path)}\n"
confirm_msg += f"数据集文件夹数量: {len(dataset_folders)}\n"
for i, folder in enumerate(dataset_folders, 1):
confirm_msg += f" {i}. {os.path.basename(folder)}\n"
confirm_msg += f"图像尺寸: {training_params['imgsz']}\n"
confirm_msg += f"训练轮数: {training_params['epochs']}\n"
confirm_msg += f"批次大小: {training_params['batch']}\n"
......@@ -271,12 +290,22 @@ class ModelTrainingHandler(ModelTestHandler):
def _startTrainingWorker(self, training_params):
"""启动训练工作线程"""
try:
# 检查是否已有训练在进行中
if self.training_active and self.training_worker:
QtWidgets.QMessageBox.warning(
self.main_window,
"提示",
"训练正在进行中,请先停止当前训练"
)
return False
# 禁止自动下载yolo11模型
os.environ['YOLO_AUTODOWNLOAD'] = '0'
os.environ['YOLO_OFFLINE'] = '1'
# 重置用户停止标记
self._is_user_stopped = False
self._is_stopping = False # 标记训练是否正在停止中
# 如果面板处于"继续训练"模式,切换回"停止升级"模式
if hasattr(self, 'training_panel'):
......@@ -288,6 +317,31 @@ class ModelTrainingHandler(ModelTestHandler):
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_log_text'):
self.training_panel.train_log_text.clear()
# 处理多数据集文件夹合并
dataset_folders = [f.strip() for f in training_params['save_liquid_data_path'].split(';') if f.strip()]
if len(dataset_folders) > 1:
self._appendLog("检测到多个数据集文件夹,正在合并...\n")
merged_data_yaml = self._mergeMultipleDatasets(dataset_folders, training_params['exp_name'])
if merged_data_yaml:
training_params['save_liquid_data_path'] = merged_data_yaml
self._appendLog(f"数据集合并完成: {merged_data_yaml}\n")
else:
QtWidgets.QMessageBox.critical(self.main_window, "错误", "数据集合并失败")
return False
elif len(dataset_folders) == 1:
# 单个数据集文件夹,需要创建data.yaml文件
single_folder = dataset_folders[0]
data_yaml_path = self._createDataYamlForSingleFolder(single_folder, training_params['exp_name'])
if data_yaml_path:
training_params['save_liquid_data_path'] = data_yaml_path
self._appendLog(f"已为单个数据集创建配置文件: {data_yaml_path}\n")
else:
QtWidgets.QMessageBox.critical(self.main_window, "错误", "创建数据集配置文件失败")
return False
# 禁用笔记保存和提交按钮(训练开始时)
self._disableNotesButtons()
# 更新UI状态
if hasattr(self, 'training_panel'):
if hasattr(self.training_panel, 'train_status_label'):
......@@ -345,38 +399,84 @@ class ModelTrainingHandler(ModelTestHandler):
return False
def _onStopTraining(self):
"""停止训练 - 优雅停止,完成当前epoch后停止"""
"""停止训练 - 根据训练状态采用不同策略"""
# 检查是否已经在停止过程中
if getattr(self, '_is_stopping', False):
self._appendLog("\n[提示] 训练正在停止中,请耐心等待...\n")
return True
if self.training_worker and self.training_active:
self._is_user_stopped = True # 标记为用户手动停止
self.training_worker.stop_training() # 设置 is_running = False,YOLO会在epoch结束时检查
self._appendLog("\n" + "="*70 + "\n")
self._appendLog("用户请求停止训练\n")
self._appendLog("正在完成当前训练轮次...\n")
self._appendLog("(请勿关闭程序,等待当前epoch完成)\n")
self._appendLog("="*70 + "\n")
# 检查训练是否已经真正开始
training_started = self.training_worker.has_training_started()
# 更新状态标签
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_status_label'):
self.training_panel.train_status_label.setText("正在停止训练...")
self.training_panel.train_status_label.setStyleSheet("""
QLabel {
color: #ffffff;
background-color: #ffc107;
border: 1px solid #ffc107;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
""")
FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD)
# 禁用停止按钮,防止重复点击
if hasattr(self, 'training_panel'):
self.training_panel.stop_train_btn.setEnabled(False)
# 设置停止标记,防止重复触发
self._is_stopping = True
self._is_user_stopped = True # 标记为用户手动停止
self.training_worker.stop_training() # 设置 is_running = False
# 不立刻终止线程,让YOLO在epoch结束时自动停止
# 线程会在 _onTrainingFinished 中被清理
if not training_started:
# 训练还未真正开始(仍在初始化阶段),直接取消训练
self._appendLog("\n" + "="*70 + "\n")
self._appendLog("训练尚未开始,正在取消训练...\n")
self._appendLog("="*70 + "\n")
# 更新状态标签
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_status_label'):
self.training_panel.train_status_label.setText("正在取消训练...")
self.training_panel.train_status_label.setStyleSheet("""
QLabel {
color: #ffffff;
background-color: #dc3545;
border: 1px solid #dc3545;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
""")
FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD)
# 强制终止训练线程(因为训练还未开始,可以安全终止)
if self.training_worker:
self.training_worker.terminate()
self.training_worker.wait(3000) # 等待最多3秒
if self.training_worker.isRunning():
self.training_worker.kill() # 强制杀死线程
# 直接调用训练完成回调,恢复UI状态
self._onTrainingFinished(False)
else:
# 训练已经开始,优雅停止(完成当前epoch后停止)
self._appendLog("\n" + "="*70 + "\n")
self._appendLog("用户请求停止训练\n")
self._appendLog("正在完成当前训练轮次...\n")
self._appendLog("(请勿关闭程序,等待当前epoch完成)\n")
self._appendLog("="*70 + "\n")
# 更新状态标签
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_status_label'):
self.training_panel.train_status_label.setText("正在停止训练...")
self.training_panel.train_status_label.setStyleSheet("""
QLabel {
color: #ffffff;
background-color: #ffc107;
border: 1px solid #ffc107;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
""")
FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD)
# 禁用所有训练相关按钮,防止重复点击和冲突
if hasattr(self, 'training_panel'):
if hasattr(self.training_panel, 'stop_train_btn'):
self.training_panel.stop_train_btn.setEnabled(False)
if hasattr(self.training_panel, 'start_train_btn'):
self.training_panel.start_train_btn.setEnabled(False)
# 不立刻终止线程,让YOLO在epoch结束时自动停止
# 线程会在 _onTrainingFinished 中被清理
return True
else:
......@@ -386,7 +486,8 @@ class ModelTrainingHandler(ModelTestHandler):
def _onTrainingFinished(self, success):
"""训练完成回调"""
try:
# 重置停止标记
self._is_stopping = False
self.training_active = False
if success:
......@@ -478,21 +579,30 @@ class ModelTrainingHandler(ModelTestHandler):
self._appendLog(" 训练完成通知\n")
self._appendLog("="*70 + "\n")
self._appendLog("模型升级已完成!\n")
self._appendLog("新模型已保存到detection_model目录\n")
self._appendLog("新模型已自动添加到模型集管理\n")
self._appendLog("请切换到【模型集管理】页面查看新模型\n")
self._appendLog("="*70 + "\n")
# 启用笔记保存和提交按钮(训练完成后允许继续编辑笔记)
self._enableNotesButtons()
# 使用定时器延迟显示消息框,避免阻塞训练线程的清理
QtCore.QTimer.singleShot(500, lambda: QtWidgets.QMessageBox.information(
self.main_window,
"升级完成",
"模型升级已完成!\n新模型已自动添加到模型集管理。"
"模型升级已完成!\n新模型已保存到detection_model目录\n自动添加到模型集管理。"
))
else:
# 检查是否为用户手动停止
is_user_stopped = getattr(self, '_is_user_stopped', False)
if is_user_stopped:
# 检查训练是否已经真正开始
training_started = False
if self.training_worker:
training_started = self.training_worker.has_training_started()
if is_user_stopped and training_started:
self._appendLog("\n" + "="*70 + "\n")
self._appendLog("训练已暂停\n")
self._appendLog("="*70 + "\n")
......@@ -578,6 +688,39 @@ class ModelTrainingHandler(ModelTestHandler):
# 重置标记
self._is_user_stopped = False
self._is_stopping = False
elif is_user_stopped and not training_started:
# 训练被取消(未真正开始),恢复初始状态
self._appendLog("\n" + "="*70 + "\n")
self._appendLog("训练已取消\n")
self._appendLog("="*70 + "\n")
# 更新状态标签
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'train_status_label'):
self.training_panel.train_status_label.setText("训练已取消")
self.training_panel.train_status_label.setStyleSheet("""
QLabel {
color: #ffffff;
background-color: #6c757d;
border: 1px solid #6c757d;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
""")
FontManager.applyToWidget(self.training_panel.train_status_label, weight=FontManager.WEIGHT_BOLD)
# 恢复按钮状态(允许重新开始训练)
if hasattr(self, 'training_panel'):
if hasattr(self.training_panel, 'start_train_btn'):
self.training_panel.start_train_btn.setEnabled(True)
if hasattr(self.training_panel, 'stop_train_btn'):
self.training_panel.stop_train_btn.setEnabled(False)
self.training_panel.stop_train_btn.setText("停止升级") # 恢复原始文本
# 重置标记
self._is_user_stopped = False
self._is_stopping = False
else:
self._appendLog("\n" + "="*70 + "\n")
self._appendLog(" 升级失败\n")
......@@ -609,6 +752,10 @@ class ModelTrainingHandler(ModelTestHandler):
self.training_panel.stop_train_btn.setText("停止升级")
# 重置训练停止标记
self.training_panel._is_training_stopped = False
# 重置停止标记
self._is_user_stopped = False
self._is_stopping = False
# 如果是正常完成(非用户停止),恢复按钮状态
if success and not self._is_user_stopped:
......@@ -891,41 +1038,19 @@ class ModelTrainingHandler(ModelTestHandler):
def _initializeTrainingPanelDefaults(self, training_panel):
"""初始化训练面板的默认值"""
try:
# 设置默认模型路径
if hasattr(training_panel, 'base_model_edit'):
project_root = get_project_root()
available_models = [
os.path.join(project_root, 'database', 'model', 'train_model', '2', 'best.pt'),
os.path.join(project_root, 'database', 'model', 'train_model')
]
for model_path in available_models:
if os.path.exists(model_path):
if os.path.isfile(model_path):
training_panel.base_model_edit.setText(model_path)
break
elif os.path.isdir(model_path):
# 查找目录中的第一个模型文件
for root, dirs, files in os.walk(model_path):
for file in files:
if file.endswith('.pt') or file.endswith('.dat'):
full_path = os.path.join(root, file)
training_panel.base_model_edit.setText(full_path)
break
if training_panel.base_model_edit.text():
break
if training_panel.base_model_edit.text():
break
# 基础模型现在通过下拉菜单从模型集管理页面加载,不需要手动设置
# 下拉菜单会在页面显示时自动加载并选择默认模型
# 设置默认数据集路径
if hasattr(training_panel, 'save_liquid_data_path_edit'):
# 数据集现在通过文件夹列表管理,可以手动添加默认数据集文件夹
if hasattr(training_panel, 'dataset_folders_list'):
project_root = get_project_root()
available_datasets = [
os.path.join(project_root, 'database', 'dataset', 'data.yaml'),
]
for dataset_path in available_datasets:
if os.path.exists(dataset_path):
training_panel.save_liquid_data_path_edit.setText(dataset_path)
break
default_dataset_folder = os.path.join(project_root, 'database', 'dataset')
if os.path.exists(default_dataset_folder) and os.path.isdir(default_dataset_folder):
# 添加默认数据集文件夹(如果列表为空)
if training_panel.dataset_folders_list.count() == 0:
training_panel.dataset_folders_list.addItem(default_dataset_folder)
if hasattr(training_panel, '_updateDatasetPath'):
training_panel._updateDatasetPath()
# 设置默认模型名称
if hasattr(training_panel, 'exp_name_edit'):
......@@ -1174,7 +1299,7 @@ class ModelTrainingHandler(ModelTestHandler):
device_value = device_text
training_params = {
'base_model': getattr(panel, 'base_model_edit', None) and panel.base_model_edit.text() or '',
'base_model': getattr(panel, 'base_model_combo', None) and panel.base_model_combo.currentData() or '',
'save_liquid_data_path': getattr(panel, 'save_liquid_data_path_edit', None) and panel.save_liquid_data_path_edit.text() or '',
'imgsz': getattr(panel, 'imgsz_spin', None) and panel.imgsz_spin.value() or 640,
'epochs': getattr(panel, 'epochs_spin', None) and panel.epochs_spin.value() or 100,
......@@ -1241,34 +1366,39 @@ class ModelTrainingHandler(ModelTestHandler):
# 路径有效,更新为绝对路径
training_params['base_model'] = base_model
# 修正数据集路径
# 修正数据集路径(支持多文件夹,用分号分隔)
save_liquid_data_path = training_params.get('save_liquid_data_path', '')
# 如果是相对路径,转换为绝对路径
if save_liquid_data_path and not os.path.isabs(save_liquid_data_path):
project_root = get_project_root()
save_liquid_data_path = os.path.join(project_root, save_liquid_data_path)
if not save_liquid_data_path or not os.path.exists(save_liquid_data_path):
# 尝试使用配置文件中的默认路径
if self.train_config and 'default_parameters' in self.train_config:
default_data = self.train_config['default_parameters'].get('dataset_path', '')
if default_data and os.path.exists(default_data):
training_params['save_liquid_data_path'] = default_data
else:
# 使用项目中可用的数据集
if save_liquid_data_path:
# 解析多个文件夹路径
folders = [f.strip() for f in save_liquid_data_path.split(';') if f.strip()]
fixed_folders = []
for folder in folders:
# 如果是相对路径,转换为绝对路径
if not os.path.isabs(folder):
project_root = get_project_root()
available_datasets = [
os.path.join(project_root, 'database', 'dataset', 'data_template_1.yaml'),
os.path.join(project_root, 'database', 'dataset', 'data.yaml'),
]
for dataset_path in available_datasets:
if os.path.exists(dataset_path):
training_params['save_liquid_data_path'] = dataset_path
break
folder = os.path.join(project_root, folder)
# 只保留存在的文件夹
if os.path.exists(folder) and os.path.isdir(folder):
fixed_folders.append(folder)
# 更新为绝对路径列表
if fixed_folders:
training_params['save_liquid_data_path'] = ';'.join(fixed_folders)
else:
# 如果没有有效文件夹,尝试使用默认数据集文件夹
project_root = get_project_root()
default_folder = os.path.join(project_root, 'database', 'dataset')
if os.path.exists(default_folder) and os.path.isdir(default_folder):
training_params['save_liquid_data_path'] = default_folder
else:
# 路径有效,更新为绝对路径
training_params['save_liquid_data_path'] = save_liquid_data_path
# 没有指定数据集,使用默认数据集文件夹
project_root = get_project_root()
default_folder = os.path.join(project_root, 'database', 'dataset')
if os.path.exists(default_folder) and os.path.isdir(default_folder):
training_params['save_liquid_data_path'] = default_folder
return training_params
......@@ -1360,6 +1490,78 @@ class ModelTrainingHandler(ModelTestHandler):
except Exception as e:
return False, f"验证过程出错: {str(e)}"
def _validateDatasetFolders(self, dataset_folders):
"""
验证多个数据集文件夹
Args:
dataset_folders: 数据集文件夹路径列表
Returns:
tuple: (是否有效, 错误消息)
"""
try:
if not dataset_folders:
return False, "数据集文件夹列表为空"
# 检查每个文件夹
total_train_images = 0
total_val_images = 0
total_labels = 0
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']
for folder in dataset_folders:
if not os.path.exists(folder):
return False, f"文件夹不存在: {folder}"
if not os.path.isdir(folder):
return False, f"路径不是文件夹: {folder}"
# 检查文件夹结构(YOLO格式)
# 期望结构: folder/images/train, folder/images/val, folder/labels/train, folder/labels/val
images_train_dir = os.path.join(folder, 'images', 'train')
images_val_dir = os.path.join(folder, 'images', 'val')
labels_train_dir = os.path.join(folder, 'labels', 'train')
labels_val_dir = os.path.join(folder, 'labels', 'val')
# 检查是否存在训练图片目录
if os.path.exists(images_train_dir):
train_count = sum(1 for f in os.listdir(images_train_dir)
if any(f.lower().endswith(ext) for ext in image_extensions))
total_train_images += train_count
# 检查是否存在验证图片目录
if os.path.exists(images_val_dir):
val_count = sum(1 for f in os.listdir(images_val_dir)
if any(f.lower().endswith(ext) for ext in image_extensions))
total_val_images += val_count
# 检查标签文件
if os.path.exists(labels_train_dir):
label_count = sum(1 for f in os.listdir(labels_train_dir)
if f.lower().endswith('.txt'))
total_labels += label_count
# 验证是否有足够的数据
if total_train_images == 0 and total_val_images == 0:
return False, f"所有数据集文件夹中都没有找到图片文件\n请确保文件夹包含 images/train 或 images/val 子目录"
if total_train_images == 0:
return False, f"未找到训练图片\n请确保至少一个文件夹包含 images/train 子目录及图片文件"
# 返回验证结果
msg = f"数据集验证通过\n"
msg += f"总训练图片: {total_train_images} 张\n"
if total_val_images > 0:
msg += f"总验证图片: {total_val_images} 张\n"
if total_labels > 0:
msg += f"总标注文件: {total_labels} 个"
return True, msg
except Exception as e:
return False, f"验证过程出错: {str(e)}"
def _validateTrainingData(self, save_liquid_data_path):
"""验证训练数据(简化版,用于向后兼容)"""
result, _ = self._validateTrainingDataWithDetails(save_liquid_data_path)
......@@ -1554,15 +1756,15 @@ class ModelTrainingHandler(ModelTestHandler):
)
return
# 改为从QComboBox获取数据
test_file_path = self.training_panel.test_file_input.currentData() or ""
test_file_display = self.training_panel.test_file_input.currentText()
# 从QLineEdit获取测试文件路径(浏览选择的文件)
test_file_path = self.training_panel.test_file_input.text().strip()
test_file_display = os.path.basename(test_file_path) if test_file_path else ""
if not test_file_path:
QtWidgets.QMessageBox.warning(
self.training_panel,
"提示",
"请先选择测试文件"
'请点击"浏览..."按钮选择测试文件'
)
return
......@@ -2183,11 +2385,21 @@ class ModelTrainingHandler(ModelTestHandler):
'training_date': self._getCurrentTimestamp()
}
# 获取训练笔记
training_notes = self._getTrainingNotes()
# 将模型移动到detection_model目录
detection_model_path = self._moveModelToDetectionDir(best_model_path, model_name, weights_dir, training_notes)
if detection_model_path:
# 更新模型参数中的路径
model_params['path'] = detection_model_path
self._appendLog(f"模型已移动到detection_model目录: {os.path.basename(detection_model_path)}\n")
# 保存到模型集配置
self._saveModelToConfig(model_name, model_params)
self._appendLog(f"新模型已添加到模型集: {model_name}\n")
self._appendLog(f" 路径: {best_model_path}\n")
self._appendLog(f" 路径: {model_params['path']}\n")
self._appendLog(f" 大小: {model_size}\n")
except Exception as e:
......@@ -2195,8 +2407,216 @@ class ModelTrainingHandler(ModelTestHandler):
import traceback
traceback.print_exc()
def _getTrainingNotes(self):
"""获取训练页面的笔记内容"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'getTrainingNotes'):
notes = self.training_panel.getTrainingNotes()
if notes:
self._appendLog(f"[笔记] 获取到训练笔记,长度: {len(notes)} 字符\n")
return notes
else:
self._appendLog("[笔记] 未输入训练笔记\n")
return ""
else:
self._appendLog("[笔记] 无法获取训练页面笔记接口\n")
return ""
except Exception as e:
self._appendLog(f"[ERROR] 获取训练笔记失败: {str(e)}\n")
return ""
def _clearTrainingNotes(self):
"""清空训练页面的笔记内容"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'clearTrainingNotes'):
self.training_panel.clearTrainingNotes()
self._appendLog("[笔记] 训练笔记已清空\n")
else:
self._appendLog("[笔记] 无法获取训练页面笔记清空接口\n")
except Exception as e:
self._appendLog(f"[ERROR] 清空训练笔记失败: {str(e)}\n")
def _enableNotesButtons(self):
"""启用训练页面的笔记保存和提交按钮"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'enableNotesButtons'):
self.training_panel.enableNotesButtons()
self._appendLog("[笔记] 笔记保存和提交按钮已启用\n")
else:
self._appendLog("[笔记] 无法获取训练页面笔记按钮接口\n")
except Exception as e:
self._appendLog(f"[ERROR] 启用笔记按钮失败: {str(e)}\n")
def _disableNotesButtons(self):
"""禁用训练页面的笔记保存和提交按钮"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'disableNotesButtons'):
self.training_panel.disableNotesButtons()
self._appendLog("[笔记] 笔记保存和提交按钮已禁用\n")
else:
self._appendLog("[笔记] 无法获取训练页面笔记按钮接口\n")
except Exception as e:
self._appendLog(f"[ERROR] 禁用笔记按钮失败: {str(e)}\n")
def saveNotesToLatestModel(self, notes):
"""保存笔记到最新训练的模型目录"""
try:
if not notes or not notes.strip():
self._appendLog("[笔记] 笔记内容为空,无需保存\n")
return False
# 获取最新的模型目录
latest_model_dir = self._getLatestModelDirectory()
if not latest_model_dir:
self._appendLog("[ERROR] 无法找到最新的模型目录\n")
return False
# 保存笔记文件
notes_file = os.path.join(latest_model_dir, 'training_notes.txt')
try:
with open(notes_file, 'w', encoding='utf-8') as f:
# 添加时间戳和更新信息
f.write(f"训练笔记 - {os.path.basename(latest_model_dir)}\n")
f.write(f"最后更新: {self._getCurrentTimestamp()}\n")
f.write("="*50 + "\n\n")
f.write(notes)
self._appendLog(f"[笔记] 笔记已保存到: {notes_file}\n")
return True
except Exception as e:
self._appendLog(f"[ERROR] 保存笔记文件失败: {str(e)}\n")
return False
except Exception as e:
self._appendLog(f"[ERROR] 保存笔记到最新模型失败: {str(e)}\n")
import traceback
traceback.print_exc()
return False
def _getLatestModelDirectory(self):
"""获取最新的detection_model目录"""
try:
project_root = get_project_root()
detection_model_dir = os.path.join(project_root, 'database', 'model', 'detection_model')
if not os.path.exists(detection_model_dir):
return None
# 获取所有数字目录
digit_dirs = []
for item in os.listdir(detection_model_dir):
item_path = os.path.join(detection_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
digit_dirs.append(int(item))
if not digit_dirs:
return None
# 返回最大数字的目录
latest_id = max(digit_dirs)
latest_dir = os.path.join(detection_model_dir, str(latest_id))
self._appendLog(f"[笔记] 找到最新模型目录: detection_model/{latest_id}\n")
return latest_dir
except Exception as e:
self._appendLog(f"[ERROR] 获取最新模型目录失败: {str(e)}\n")
return None
def _moveModelToDetectionDir(self, model_path, model_name, weights_dir, training_notes=""):
"""将训练完成的模型移动到detection_model目录"""
try:
import shutil
from pathlib import Path
self._appendLog(f"\n开始移动模型到detection_model目录...\n")
# 获取项目根目录
project_root = get_project_root()
detection_model_dir = os.path.join(project_root, 'database', 'model', 'detection_model')
# 确保detection_model目录存在
os.makedirs(detection_model_dir, exist_ok=True)
# 获取下一个可用的数字ID
existing_dirs = []
for item in os.listdir(detection_model_dir):
item_path = os.path.join(detection_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
existing_dirs.append(int(item))
next_id = max(existing_dirs) + 1 if existing_dirs else 1
target_model_dir = os.path.join(detection_model_dir, str(next_id))
self._appendLog(f" 目标目录: {target_model_dir}\n")
# 创建目标目录结构
os.makedirs(target_model_dir, exist_ok=True)
target_weights_dir = os.path.join(target_model_dir, 'weights')
os.makedirs(target_weights_dir, exist_ok=True)
# 移动整个weights目录的内容
if weights_dir and os.path.exists(weights_dir):
self._appendLog(f" 复制weights目录内容...\n")
# 复制所有文件到目标weights目录
for filename in os.listdir(weights_dir):
source_file = os.path.join(weights_dir, filename)
target_file = os.path.join(target_weights_dir, filename)
if os.path.isfile(source_file):
shutil.copy2(source_file, target_file)
self._appendLog(f" 复制: {filename}\n")
# 复制训练目录的其他文件(如config.yaml, results.csv等)
train_exp_dir = os.path.dirname(weights_dir)
if os.path.exists(train_exp_dir):
for filename in os.listdir(train_exp_dir):
if filename != 'weights': # 跳过weights目录
source_file = os.path.join(train_exp_dir, filename)
target_file = os.path.join(target_model_dir, filename)
if os.path.isfile(source_file):
shutil.copy2(source_file, target_file)
self._appendLog(f" 复制配置: {filename}\n")
# 保存训练笔记(如果有)
if training_notes:
notes_file = os.path.join(target_model_dir, 'training_notes.txt')
try:
with open(notes_file, 'w', encoding='utf-8') as f:
# 添加时间戳和模型信息
f.write(f"训练笔记 - {model_name}\n")
f.write(f"训练时间: {self._getCurrentTimestamp()}\n")
f.write(f"模型ID: {next_id}\n")
f.write("="*50 + "\n\n")
f.write(training_notes)
self._appendLog(f" 保存训练笔记: training_notes.txt\n")
except Exception as e:
self._appendLog(f" 保存训练笔记失败: {str(e)}\n")
# 确定最终的模型文件路径
model_filename = os.path.basename(model_path)
final_model_path = os.path.join(target_weights_dir, model_filename)
if os.path.exists(final_model_path):
self._appendLog(f"✅ 模型已成功移动到detection_model/{next_id}/weights/{model_filename}\n")
if training_notes:
self._appendLog(f"✅ 训练笔记已保存到detection_model/{next_id}/training_notes.txt\n")
return final_model_path
else:
self._appendLog(f"❌ 模型移动失败,目标文件不存在: {final_model_path}\n")
return None
except Exception as e:
self._appendLog(f"❌ [ERROR] 移动模型到detection_model失败: {str(e)}\n")
import traceback
traceback.print_exc()
return None
def _saveModelToConfig(self, model_name, model_params):
"""保存模型配置文件(模型已经在train_model目录中)"""
"""保存模型配置文件(模型已经在detection_model目录中)"""
try:
from pathlib import Path
import yaml
......@@ -2205,7 +2625,7 @@ class ModelTrainingHandler(ModelTestHandler):
self._appendLog(f" 模型名称: {model_name}\n")
self._appendLog(f" 模型路径: {model_params['path']}\n")
# 获取模型所在目录(应该已经在train_model/{数字ID}/weights/中)
# 获取模型所在目录(现在应该在detection_model/{数字ID}/weights/中)
model_path = model_params['path']
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
......@@ -2254,7 +2674,8 @@ class ModelTrainingHandler(ModelTestHandler):
yaml.dump(model_params, f, allow_unicode=True, default_flow_style=False)
# 输出总结信息
self._appendLog(f"\n✅ 模型配置已保存: {model_dir}\n")
model_id = os.path.basename(model_dir)
self._appendLog(f"\n✅ 模型配置已保存到detection_model/{model_id}/\n")
self._appendLog(f" 找到的文件:\n")
for info in model_files_found:
self._appendLog(f" - {info}\n")
......@@ -2262,6 +2683,7 @@ class ModelTrainingHandler(ModelTestHandler):
for log_file in log_files_found:
self._appendLog(f" - 训练日志: {log_file}\n")
self._appendLog(f" - 配置文件: config.yaml\n")
self._appendLog(f"\n🎉 模型已成功保存到detection_model目录,可在模型集管理中查看!\n")
except Exception as e:
self._appendLog(f"❌ [ERROR] 保存模型配置失败: {str(e)}\n")
......@@ -2652,3 +3074,178 @@ class ModelTrainingHandler(ModelTestHandler):
return "未知"
except:
return "未知"
def _mergeMultipleDatasets(self, dataset_folders, exp_name):
"""
合并多个数据集文件夹为统一的训练配置
Args:
dataset_folders: 数据集文件夹路径列表
exp_name: 实验名称
Returns:
str: 合并后的data.yaml文件路径,失败返回None
"""
try:
import shutil
import tempfile
from pathlib import Path
# 创建临时合并目录
project_root = get_project_root()
temp_dataset_dir = os.path.join(project_root, 'database', 'temp_datasets', exp_name)
# 如果目录已存在,先删除
if os.path.exists(temp_dataset_dir):
shutil.rmtree(temp_dataset_dir)
# 创建合并后的目录结构
merged_images_train = os.path.join(temp_dataset_dir, 'images', 'train')
merged_images_val = os.path.join(temp_dataset_dir, 'images', 'val')
merged_labels_train = os.path.join(temp_dataset_dir, 'labels', 'train')
merged_labels_val = os.path.join(temp_dataset_dir, 'labels', 'val')
os.makedirs(merged_images_train, exist_ok=True)
os.makedirs(merged_images_val, exist_ok=True)
os.makedirs(merged_labels_train, exist_ok=True)
os.makedirs(merged_labels_val, exist_ok=True)
self._appendLog(f"创建合并目录: {temp_dataset_dir}\n")
# 合并所有数据集
total_train_images = 0
total_val_images = 0
total_train_labels = 0
total_val_labels = 0
for i, folder in enumerate(dataset_folders):
self._appendLog(f"正在合并数据集 {i+1}/{len(dataset_folders)}: {os.path.basename(folder)}\n")
# 检查源目录结构
src_images_train = os.path.join(folder, 'images', 'train')
src_images_val = os.path.join(folder, 'images', 'val')
src_labels_train = os.path.join(folder, 'labels', 'train')
src_labels_val = os.path.join(folder, 'labels', 'val')
# 复制训练图片
if os.path.exists(src_images_train):
for filename in os.listdir(src_images_train):
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
src_file = os.path.join(src_images_train, filename)
# 添加前缀避免文件名冲突
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_images_train, dst_filename)
shutil.copy2(src_file, dst_file)
total_train_images += 1
# 复制验证图片
if os.path.exists(src_images_val):
for filename in os.listdir(src_images_val):
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
src_file = os.path.join(src_images_val, filename)
# 添加前缀避免文件名冲突
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_images_val, dst_filename)
shutil.copy2(src_file, dst_file)
total_val_images += 1
# 复制训练标签
if os.path.exists(src_labels_train):
for filename in os.listdir(src_labels_train):
if filename.lower().endswith('.txt'):
src_file = os.path.join(src_labels_train, filename)
# 添加前缀避免文件名冲突,保持与图片文件名对应
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_labels_train, dst_filename)
shutil.copy2(src_file, dst_file)
total_train_labels += 1
# 复制验证标签
if os.path.exists(src_labels_val):
for filename in os.listdir(src_labels_val):
if filename.lower().endswith('.txt'):
src_file = os.path.join(src_labels_val, filename)
# 添加前缀避免文件名冲突,保持与图片文件名对应
dst_filename = f"ds{i+1}_{filename}"
dst_file = os.path.join(merged_labels_val, dst_filename)
shutil.copy2(src_file, dst_file)
total_val_labels += 1
self._appendLog(f"合并完成: 训练图片 {total_train_images} 张, 验证图片 {total_val_images} 张\n")
self._appendLog(f"合并完成: 训练标签 {total_train_labels} 个, 验证标签 {total_val_labels} 个\n")
# 创建data.yaml配置文件
data_yaml_path = os.path.join(temp_dataset_dir, 'data.yaml')
data_config = {
'train': os.path.join(temp_dataset_dir, 'images', 'train'),
'val': os.path.join(temp_dataset_dir, 'images', 'val'),
'nc': 1, # 类别数量,液位检测通常是单类别
'names': ['liquid_level'] # 类别名称
}
with open(data_yaml_path, 'w', encoding='utf-8') as f:
yaml.dump(data_config, f, default_flow_style=False, allow_unicode=True)
self._appendLog(f"创建配置文件: {data_yaml_path}\n")
return data_yaml_path
except Exception as e:
self._appendLog(f"数据集合并失败: {str(e)}\n")
import traceback
self._appendLog(traceback.format_exc())
return None
def _createDataYamlForSingleFolder(self, dataset_folder, exp_name):
"""
为单个数据集文件夹创建data.yaml配置文件
Args:
dataset_folder: 数据集文件夹路径
exp_name: 实验名称
Returns:
str: data.yaml文件路径,失败返回None
"""
try:
# 检查是否已经有data.yaml文件
existing_yaml = os.path.join(dataset_folder, 'data.yaml')
if os.path.exists(existing_yaml):
self._appendLog(f"使用现有配置文件: {existing_yaml}\n")
return existing_yaml
# 创建临时配置文件
project_root = get_project_root()
temp_config_dir = os.path.join(project_root, 'database', 'temp_configs')
os.makedirs(temp_config_dir, exist_ok=True)
data_yaml_path = os.path.join(temp_config_dir, f"{exp_name}_data.yaml")
# 检查数据集目录结构
images_train_dir = os.path.join(dataset_folder, 'images', 'train')
images_val_dir = os.path.join(dataset_folder, 'images', 'val')
if not os.path.exists(images_train_dir):
self._appendLog(f"错误: 训练图片目录不存在: {images_train_dir}\n")
return None
# 创建data.yaml配置
data_config = {
'train': images_train_dir,
'val': images_val_dir if os.path.exists(images_val_dir) else images_train_dir,
'nc': 1, # 类别数量,液位检测通常是单类别
'names': ['liquid_level'] # 类别名称
}
with open(data_yaml_path, 'w', encoding='utf-8') as f:
yaml.dump(data_config, f, default_flow_style=False, allow_unicode=True)
self._appendLog(f"创建配置文件: {data_yaml_path}\n")
return data_yaml_path
except Exception as e:
self._appendLog(f"创建配置文件失败: {str(e)}\n")
import traceback
self._appendLog(traceback.format_exc())
return None
......@@ -69,6 +69,7 @@ class TrainingWorker(QThread):
super().__init__()
self.training_params = training_params
self.is_running = True
self.training_actually_started = False # 标记训练是否已经真正开始(第一个epoch开始)
self.train_config = None
# 调试信息:显示传入的训练参数
......@@ -522,6 +523,8 @@ class TrainingWorker(QThread):
def on_train_start(trainer):
"""训练开始回调 - 只输出到终端,不发送到UI"""
# 标记训练已经真正开始
self.training_actually_started = True
# 记录开始时间
epoch_start_time[0] = time.time()
# 不发送任何格式化消息到UI,让LogCapture直接捕获原生输出
......@@ -1264,3 +1267,7 @@ class TrainingWorker(QThread):
def stop_training(self):
"""停止训练"""
self.is_running = False
def has_training_started(self):
"""检查训练是否已经真正开始"""
return self.training_actually_started
......@@ -1900,6 +1900,10 @@ class ChannelPanelHandler:
def _initConfigFileWatcher(self):
"""初始化配置文件监控器"""
try:
# 临时禁用配置文件监控器以解决 QWidget 创建顺序问题
print(f"[ConfigWatcher] 配置文件监控器已禁用(避免 QWidget 创建顺序问题)")
return
# 获取配置文件路径
project_root = get_project_root()
config_path = os.path.join(project_root, 'database', 'config', 'default_config.yaml')
......@@ -1928,7 +1932,12 @@ class ChannelPanelHandler:
print(f"🔄 [ConfigWatcher] 检测到配置文件变化: {path}")
# 延迟一小段时间,确保文件写入完成
QtCore.QTimer.singleShot(100, self._reloadChannelConfig)
# 修复:检查 QApplication 是否存在
if QtWidgets.QApplication.instance() is not None:
QtCore.QTimer.singleShot(100, self._reloadChannelConfig)
else:
# 如果没有 QApplication,直接调用重载函数
self._reloadChannelConfig()
except Exception as e:
print(f"[ConfigWatcher] 处理配置文件变化失败: {e}")
......
......@@ -17,7 +17,7 @@ import csv
import datetime
import numpy as np
from qtpy import QtWidgets, QtCore
from PyQt5.QtCore import QThread, pyqtSignal
from qtpy.QtCore import QThread, Signal as pyqtSignal
# 导入图标工具
try:
......@@ -930,7 +930,7 @@ class CurvePanelHandler:
QtWidgets.QApplication.processEvents()
# 🔥 延迟关闭进度条,确保用户能看到(至少显示500ms)
from PyQt5.QtCore import QTimer
from qtpy.QtCore import QTimer
QTimer.singleShot(500, progress_dialog.close)
print(f"✅ [进度条] 将在500ms后关闭")
......
......@@ -11,7 +11,7 @@ widgets/videopage/historypanel.py (HistoryPanel)
-
"""
from PyQt5 import QtWidgets, QtCore
from qtpy import QtWidgets, QtCore
class HistoryPanelHandler:
......@@ -80,8 +80,8 @@ class HistoryPanelHandler:
return
import os
from PyQt5.QtMultimedia import QMediaPlayer
from PyQt5 import QtWidgets
from qtpy.QtMultimedia import QMediaPlayer
from qtpy import QtWidgets
player = self.history_panel._media_player
print(f"[HistoryPanelHandler] 播放器: {player}")
......@@ -158,7 +158,7 @@ class HistoryPanelHandler:
if not self.history_panel or not hasattr(self.history_panel, 'play_pause_button'):
return
from PyQt5.QtMultimedia import QMediaPlayer
from qtpy.QtMultimedia import QMediaPlayer
from widgets.style_manager import newIcon
button = self.history_panel.play_pause_button
......
......@@ -699,7 +699,13 @@ class ModelSetPage(QtWidgets.QWidget):
menu.addSeparator()
# 5. 删除模型(有实际功能)
# 5. 查看模型信息(新功能)
action_view_info = menu.addAction("查看模型信息")
action_view_info.triggered.connect(lambda: self.viewModelInfo(model_name))
menu.addSeparator()
# 6. 删除模型(有实际功能)
action_delete = menu.addAction("删除模型")
action_delete.triggered.connect(lambda: self.deleteModel(model_name))
......@@ -734,6 +740,9 @@ class ModelSetPage(QtWidgets.QWidget):
self.defaultModelChanged.emit(model_name)
self.setDefaultRequested.emit(model_name)
self._updateStats()
# 通知训练页面刷新基础模型列表
self._notifyTrainingPageRefresh()
def updateModelParams(self, model_name):
"""更新模型参数显示(已移除右侧参数显示,保留方法以保持兼容性)"""
......@@ -1007,7 +1016,8 @@ class ModelSetPage(QtWidgets.QWidget):
self._updateModelOrder()
# 通知训练页面刷新基础模型列表
self._notifyTrainingPageRefresh()
except Exception as e:
showCritical(self, "删除失败", f"删除模型时发生错误: {e}")
......@@ -1039,6 +1049,216 @@ class ModelSetPage(QtWidgets.QWidget):
except Exception as e:
showCritical(self, "操作失败", f"添加至检测模型时发生错误: {e}")
def viewModelInfo(self, model_name):
"""查看模型信息(来源和训练指标)"""
try:
# 获取模型参数
if model_name not in self._model_params:
showWarning(self, "错误", f"未找到模型 '{model_name}' 的信息")
return
model_params = self._model_params[model_name]
model_path = model_params.get('path', '')
if not model_path or not os.path.exists(model_path):
showWarning(self, "错误", f"模型文件不存在: {model_path}")
return
# 读取模型配置和训练指标
model_info = self._readModelTrainingInfo(model_path, model_name)
# 创建并显示信息对话框
self._showModelInfoDialog(model_name, model_params, model_info)
except Exception as e:
import traceback
traceback.print_exc()
showCritical(self, "错误", f"查看模型信息时发生错误: {e}")
def _readModelTrainingInfo(self, model_path, model_name):
"""读取模型的训练配置和指标信息"""
info = {
'config': {},
'metrics': {},
'training_date': None,
'source': '未知'
}
try:
from pathlib import Path
model_file = Path(model_path)
# 1. 确定模型所在目录
if model_file.is_file():
model_dir = model_file.parent
else:
model_dir = model_file
# 2. 尝试读取config.yaml(训练配置)
config_file = model_dir / 'config.yaml'
if not config_file.exists():
# 尝试上一级目录
config_file = model_dir.parent / 'config.yaml'
if config_file.exists():
with open(config_file, 'r', encoding='utf-8') as f:
info['config'] = yaml.safe_load(f) or {}
info['training_date'] = info['config'].get('training_date', '未知')
info['source'] = '训练生成'
# 3. 尝试读取results.csv(训练指标)
results_file = model_dir / 'results.csv'
if not results_file.exists():
# 尝试train子目录
results_file = model_dir.parent / 'train' / 'results.csv'
if results_file.exists():
import pandas as pd
try:
df = pd.read_csv(results_file)
if len(df) > 0:
# 获取最后一行(最终指标)
last_row = df.iloc[-1]
info['metrics'] = {
'epoch': int(last_row.get('epoch', 0)) if 'epoch' in last_row else len(df),
'train_loss': float(last_row.get('train/box_loss', 0)) if 'train/box_loss' in last_row else None,
'val_loss': float(last_row.get('val/box_loss', 0)) if 'val/box_loss' in last_row else None,
'precision': float(last_row.get('metrics/precision(B)', 0)) if 'metrics/precision(B)' in last_row else None,
'recall': float(last_row.get('metrics/recall(B)', 0)) if 'metrics/recall(B)' in last_row else None,
'mAP50': float(last_row.get('metrics/mAP50(B)', 0)) if 'metrics/mAP50(B)' in last_row else None,
'mAP50-95': float(last_row.get('metrics/mAP50-95(B)', 0)) if 'metrics/mAP50-95(B)' in last_row else None,
}
except Exception as e:
print(f"读取results.csv失败: {e}")
# 4. 如果没有配置文件,从文件路径判断来源
if not info['config']:
if 'train_model' in str(model_path):
info['source'] = '本地训练'
elif 'detection_model' in str(model_path):
info['source'] = '检测模型'
else:
info['source'] = '导入模型'
except Exception as e:
import traceback
traceback.print_exc()
print(f"读取模型训练信息失败: {e}")
return info
def _showModelInfoDialog(self, model_name, model_params, model_info):
"""显示模型信息对话框(简化版 - 只显示训练配置)"""
# 创建对话框
dialog = QtWidgets.QDialog(self)
dialog.setWindowTitle(f"模型升级配置 - {model_name}")
dialog.setMinimumWidth(450)
dialog.setMinimumHeight(400)
layout = QtWidgets.QVBoxLayout(dialog)
layout.setSpacing(15)
layout.setContentsMargins(20, 20, 20, 20)
# 标题
title_label = QtWidgets.QLabel(f"<h3>{model_name}</h3>")
layout.addWidget(title_label)
# 创建表单布局显示训练配置
form_widget = QtWidgets.QWidget()
form_layout = QtWidgets.QFormLayout(form_widget)
form_layout.setSpacing(10)
form_layout.setContentsMargins(10, 10, 10, 10)
form_layout.setLabelAlignment(Qt.AlignRight | Qt.AlignVCenter)
# 设置表单样式
form_widget.setStyleSheet("""
QWidget {
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 4px;
}
QLabel {
font-size: 10pt;
padding: 3px;
}
""")
# 从配置中提取训练参数
config = model_info.get('config', {})
# 基础信息
info_label = QtWidgets.QLabel(f"<b>模型来源:</b>{model_info['source']}")
form_layout.addRow("", info_label)
if model_info.get('training_date'):
date_label = QtWidgets.QLabel(f"<b>训练日期:</b>{model_info['training_date']}")
form_layout.addRow("", date_label)
# 添加分隔线
separator1 = QtWidgets.QFrame()
separator1.setFrameShape(QtWidgets.QFrame.HLine)
separator1.setFrameShadow(QtWidgets.QFrame.Sunken)
form_layout.addRow(separator1)
# 训练配置参数(仿照训练页面的格式)
epochs = config.get('epochs', '未知')
batch_size = config.get('batch', config.get('batch_size', '未知'))
imgsz = config.get('imgsz', '未知')
workers = config.get('workers', '未知')
device = config.get('device', '未知')
optimizer = config.get('optimizer', '未知')
form_layout.addRow("训练轮数:", QtWidgets.QLabel(f"<b>{epochs} 轮</b>"))
form_layout.addRow("批次大小:", QtWidgets.QLabel(f"<b>{batch_size}</b>"))
form_layout.addRow("图像尺寸:", QtWidgets.QLabel(f"<b>{imgsz} px</b>"))
form_layout.addRow("Workers:", QtWidgets.QLabel(f"<b>{workers} 线程</b>"))
form_layout.addRow("训练设备:", QtWidgets.QLabel(f"<b>{device}</b>"))
form_layout.addRow("优化器:", QtWidgets.QLabel(f"<b>{optimizer}</b>"))
# 如果有训练指标,显示最终性能
metrics = model_info.get('metrics', {})
if metrics:
separator2 = QtWidgets.QFrame()
separator2.setFrameShape(QtWidgets.QFrame.HLine)
separator2.setFrameShadow(QtWidgets.QFrame.Sunken)
form_layout.addRow(separator2)
mAP50 = metrics.get('mAP50')
mAP5095 = metrics.get('mAP50-95')
if mAP50 is not None:
form_layout.addRow("mAP@0.5:", QtWidgets.QLabel(f"<b style='color: #28a745;'>{mAP50:.4f}</b>"))
if mAP5095 is not None:
form_layout.addRow("mAP@0.5:0.95:", QtWidgets.QLabel(f"<b style='color: #28a745;'>{mAP5095:.4f}</b>"))
layout.addWidget(form_widget)
# 如果没有训练配置,显示提示
if not config:
no_config_label = QtWidgets.QLabel(
"<i>该模型没有训练配置信息<br>"
"可能是外部导入的模型</i>"
)
no_config_label.setAlignment(Qt.AlignCenter)
no_config_label.setStyleSheet("color: #999; padding: 20px;")
layout.addWidget(no_config_label)
layout.addStretch()
# 按钮
button_layout = QtWidgets.QHBoxLayout()
button_layout.addStretch()
close_btn = QtWidgets.QPushButton("关闭")
close_btn.setMinimumWidth(100)
close_btn.clicked.connect(dialog.accept)
button_layout.addWidget(close_btn)
layout.addLayout(button_layout)
# 显示对话框
dialog.exec_()
def _moveModelToDetection(self, model_name, source_path):
"""执行模型移动操作"""
try:
......@@ -1149,10 +1369,25 @@ class ModelSetPage(QtWidgets.QWidget):
self._updateModelOrder()
# 通知训练页面刷新基础模型列表
self._notifyTrainingPageRefresh()
except Exception as e:
import traceback
traceback.print_exc()
def _notifyTrainingPageRefresh(self):
"""通知训练页面刷新基础模型列表"""
try:
# 尝试获取训练页面并刷新其基础模型列表
if self._parent and hasattr(self._parent, 'trainingPage'):
training_page = self._parent.trainingPage
if hasattr(training_page, '_loadBaseModelOptions'):
training_page._loadBaseModelOptions()
print("[模型集管理] 已通知训练页面刷新基础模型列表")
except Exception as e:
print(f"[模型集管理] 通知训练页面刷新失败: {e}")
def _loadConfigFile(self):
"""加载配置文件"""
try:
......@@ -1208,50 +1443,122 @@ class ModelSetPage(QtWidgets.QWidget):
return models
def _scanModelDirectory(self):
"""扫描模型目录获取所有模型文件"""
"""扫描模型目录获取所有模型文件(优先detection_model)"""
models = []
try:
# 获取模型目录路径
current_dir = Path(__file__).parent.parent.parent
# 扫描多个模型目录
# 扫描多个模型目录(优先detection_model)
model_dirs = [
(current_dir / "database" / "model" / "detection_model", "检测模型"),
(current_dir / "database" / "model" / "train_model", "训练模型"),
(current_dir / "database" / "model" / "test_model", "测试模型")
(current_dir / "database" / "model" / "detection_model", "检测模型", True), # 优先级最高
(current_dir / "database" / "model" / "train_model", "训练模型", False),
(current_dir / "database" / "model" / "test_model", "测试模型", False)
]
for model_dir, dir_type in model_dirs:
for model_dir, dir_type, is_primary in model_dirs:
if not model_dir.exists():
continue
# 遍历所有子目录,并按目录名排序(确保模型1最先)
sorted_subdirs = sorted(model_dir.iterdir(), key=lambda x: x.name if x.is_dir() else '')
# 遍历所有子目录,按数字排序(降序,最新的在前)
all_subdirs = [d for d in model_dir.iterdir() if d.is_dir()]
digit_subdirs = [d for d in all_subdirs if d.name.isdigit()]
sorted_subdirs = sorted(digit_subdirs, key=lambda x: int(x.name), reverse=True)
for subdir in sorted_subdirs:
if subdir.is_dir():
# 查找 .dat 文件(优先)
for model_file in sorted(subdir.glob("*.dat")):
models.append({
'name': f"{dir_type}-{subdir.name}-{model_file.stem}",
'path': str(model_file),
'subdir': subdir.name,
'source': 'scan',
'format': 'dat',
'model_type': dir_type
})
# 检查是否有weights子目录
weights_dir = subdir / "weights"
search_dir = weights_dir if weights_dir.exists() else subdir
# 尝试读取config.yaml获取模型名称
config_file = subdir / "config.yaml"
model_display_name = None
if config_file.exists():
try:
import yaml
with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
if config_data and 'name' in config_data:
model_display_name = config_data['name']
except Exception:
pass
# 如果没有配置文件中的名称,使用默认命名
if not model_display_name:
if is_primary:
model_display_name = f"模型-{subdir.name}"
else:
model_display_name = f"{dir_type}-{subdir.name}"
# 按优先级查找模型文件:best > last > epoch1
selected_model = None
# 优先级1: best模型(.dat优先)
for ext in ['.dat', '']: # 无扩展名的也考虑
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('best.'):
if ext == '' and '.' in file.name[5:]: # 有其他扩展名
continue
if ext != '' and not file.name.endswith(ext):
continue
if ext == '' or file.name.endswith(ext):
selected_model = file
break
if selected_model:
break
# 优先级2: last模型
if not selected_model:
for ext in ['.dat', '']:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('last.'):
if ext == '' and '.' in file.name[5:]:
continue
if ext != '' and not file.name.endswith(ext):
continue
if ext == '' or file.name.endswith(ext):
selected_model = file
break
if selected_model:
break
# 优先级3: epoch1模型
if not selected_model:
for ext in ['.dat', '']:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('epoch1.'):
if ext == '' and '.' in file.name[7:]:
continue
if ext != '' and not file.name.endswith(ext):
continue
if ext == '' or file.name.endswith(ext):
selected_model = file
break
if selected_model:
break
# 如果找到了模型文件,添加到列表
if selected_model:
# 获取文件格式
file_ext = selected_model.suffix.lstrip('.')
if not file_ext:
# 处理无扩展名的情况
if '.' in selected_model.name:
file_ext = selected_model.name.split('.')[-1]
else:
file_ext = 'dat' # 默认为dat格式
# 然后查找 .pt 文件
for model_file in sorted(subdir.glob("*.pt")):
models.append({
'name': f"{dir_type}-{subdir.name}-{model_file.stem}",
'path': str(model_file),
'subdir': subdir.name,
'source': 'scan',
'format': 'pt',
'model_type': dir_type
})
models.append({
'name': model_display_name,
'path': str(selected_model),
'subdir': subdir.name,
'source': 'scan',
'format': file_ext,
'model_type': dir_type,
'is_primary': is_primary, # 标记是否为主要模型目录
'file_name': selected_model.name
})
except Exception as e:
import traceback
......@@ -1260,29 +1567,47 @@ class ModelSetPage(QtWidgets.QWidget):
return models
def _mergeModelInfo(self, channel_models, scanned_models):
"""合并模型信息,避免重复"""
"""合并模型信息,避免重复(优先detection_model)"""
all_models = []
seen_paths = set()
# 优先添加配置文件中的通道模型
# 首先添加detection_model中的主要模型(优先级最高)
primary_models = [m for m in scanned_models if m.get('is_primary', False)]
for model in primary_models:
path = model['path']
if path not in seen_paths:
all_models.append(model)
seen_paths.add(path)
# 然后添加配置文件中的通道模型
for model in channel_models:
path = model['path']
if path not in seen_paths:
all_models.append(model)
seen_paths.add(path)
# 再添加扫描到的模型(跳过已存在的)
for model in scanned_models:
# 最后添加其他扫描到的模型(跳过已存在的)
other_models = [m for m in scanned_models if not m.get('is_primary', False)]
for model in other_models:
path = model['path']
if path not in seen_paths:
all_models.append(model)
seen_paths.add(path)
# 确保有一个默认模型:如果没有默认模型,将第一个模型设为默认
# 确保有一个默认模型:优先选择detection_model中的第一个模型
has_default = any(model.get('is_default', False) for model in all_models)
if not has_default and len(all_models) > 0:
all_models[0]['is_default'] = True
pass
# 优先选择detection_model中的模型作为默认
primary_model_found = False
for model in all_models:
if model.get('is_primary', False):
model['is_default'] = True
primary_model_found = True
break
# 如果没有detection_model中的模型,选择第一个
if not primary_model_found:
all_models[0]['is_default'] = True
return all_models
......
......@@ -11,6 +11,14 @@ from pathlib import Path
from qtpy import QtWidgets, QtCore, QtGui
from qtpy.QtCore import Qt
# 尝试导入 PyQtGraph 用于曲线显示
try:
import pyqtgraph as pg
PYQTGRAPH_AVAILABLE = True
except ImportError:
pg = None
PYQTGRAPH_AVAILABLE = False
# 导入图标工具函数
try:
from ..icons import newIcon, newButton
......@@ -36,11 +44,11 @@ except (ImportError, ValueError):
# 导入样式管理器和响应式布局
try:
from ..style_manager import FontManager, BackgroundStyleManager
from ..style_manager import FontManager, BackgroundStyleManager, TextButtonStyleManager
from ..responsive_layout import ResponsiveLayout, scale_w, scale_h
except (ImportError, ValueError):
try:
from widgets.style_manager import FontManager, BackgroundStyleManager
from widgets.style_manager import FontManager, BackgroundStyleManager, TextButtonStyleManager
from widgets.responsive_layout import ResponsiveLayout, scale_w, scale_h
except ImportError:
try:
......@@ -48,7 +56,7 @@ except (ImportError, ValueError):
from pathlib import Path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from widgets.style_manager import FontManager, BackgroundStyleManager
from widgets.style_manager import FontManager, BackgroundStyleManager, TextButtonStyleManager
from widgets.responsive_layout import ResponsiveLayout, scale_w, scale_h
except ImportError:
# 如果导入失败,创建一个简单的替代类
......@@ -96,8 +104,9 @@ class TrainingPage(QtWidgets.QWidget):
# 🔥 连接模板按钮组信号
self.template_button_group.buttonClicked.connect(self._onTemplateChecked)
self._loadBaseModelOptions() # 🔥 加载基础模型选项
self._loadTestModelOptions() # 加载测试模型选项
self._loadTestFileList() # 🔥 加载测试文件列表
# self._loadTestFileList() # 🔥 不再需要加载测试文件列表(改用浏览方式)
def _increaseFontSize(self):
"""增加日志字体大小"""
......@@ -251,6 +260,10 @@ class TrainingPage(QtWidgets.QWidget):
# 将视频面板添加到显示布局中
display_layout.addWidget(self.video_panel)
# === 添加曲线显示面板 ===
self._createCurvePanel()
display_layout.addWidget(self.curve_panel)
# 🔥 设置整体窗口的最小尺寸 - 使用响应式布局
min_w, min_h = scale_w(1000), scale_h(700)
self.setMinimumSize(min_w, min_h)
......@@ -311,24 +324,38 @@ class TrainingPage(QtWidgets.QWidget):
test_file_label.setStyleSheet("color: #495057; font-weight: bold;")
test_file_layout.addWidget(test_file_label)
self.test_file_input = QtWidgets.QComboBox()
# 测试文件路径输入框和浏览按钮
test_file_input_layout = QtWidgets.QHBoxLayout()
test_file_input_layout.setSpacing(scale_spacing(5))
# 文件路径输入框(可编辑)
self.test_file_input = QtWidgets.QLineEdit()
self.test_file_input.setPlaceholderText("选择测试图片或视频文件...")
FontManager.applyToWidget(self.test_file_input)
self.test_file_input.setStyleSheet("""
QComboBox {
QLineEdit {
padding: 6px 10px;
border: 1px solid #ced4da;
border-radius: 4px;
background-color: white;
}
QComboBox:focus {
QLineEdit:focus {
border-color: #0078d7;
outline: none;
}
QComboBox::drop-down {
border: none;
}
""")
test_file_layout.addWidget(self.test_file_input)
test_file_input_layout.addWidget(self.test_file_input)
# 浏览按钮(使用全局样式管理器)
self.test_file_browse_btn = TextButtonStyleManager.createStandardButton(
"浏览...",
parent=self,
slot=self._browseTestFile
)
self.test_file_browse_btn.setMinimumWidth(scale_w(60))
test_file_input_layout.addWidget(self.test_file_browse_btn)
test_file_layout.addLayout(test_file_input_layout)
right_control_layout.addLayout(test_file_layout)
# 添加垂直间距
......@@ -337,14 +364,32 @@ class TrainingPage(QtWidgets.QWidget):
test_button_layout = QtWidgets.QHBoxLayout()
ResponsiveLayout.apply_to_layout(test_button_layout, base_spacing=10, base_margins=0)
self.start_annotation_btn = QtWidgets.QPushButton("开始标注")
# 使用全局样式管理器创建测试按钮
self.start_annotation_btn = TextButtonStyleManager.createStandardButton(
"开始标注",
parent=self
)
self.start_annotation_btn.setMinimumWidth(scale_w(80))
test_button_layout.addWidget(self.start_annotation_btn)
self.start_test_btn = QtWidgets.QPushButton("开始测试")
self.start_test_btn = TextButtonStyleManager.createStandardButton(
"开始测试",
parent=self
)
self.start_test_btn.setMinimumWidth(scale_w(80))
test_button_layout.addWidget(self.start_test_btn)
# 查看曲线按钮(使用全局样式管理器)
self.view_curve_btn = TextButtonStyleManager.createStandardButton(
"查看曲线",
parent=self,
slot=self._onViewCurveClicked
)
self.view_curve_btn.setMinimumWidth(scale_w(80))
self.view_curve_btn.setEnabled(False) # 初始状态禁用
self.view_curve_btn.setToolTip("测试完成后可查看曲线结果")
test_button_layout.addWidget(self.view_curve_btn)
# 将按钮布局添加到主布局
right_control_layout.addLayout(test_button_layout)
......@@ -389,29 +434,37 @@ class TrainingPage(QtWidgets.QWidget):
FontManager.applyToWidget(self.auto_scroll_check)
log_header.addWidget(self.auto_scroll_check)
# 清空日志按钮(使用系统默认样式 + 响应式布局)
clear_log_btn = QtWidgets.QPushButton("清空日志")
# 清空日志按钮(使用全局样式管理器)
clear_log_btn = TextButtonStyleManager.createStandardButton(
"清空日志",
parent=self,
slot=self.clearLog
)
clear_log_btn.setMinimumWidth(scale_w(60))
# 🔥 添加字体大小调整按钮
font_size_label = QtWidgets.QLabel("字体:")
# 使用系统默认样式
FontManager.applyToWidget(font_size_label)
log_header.addWidget(font_size_label)
self.font_decrease_btn = QtWidgets.QPushButton("-")
# 字体调整按钮(使用全局样式管理器)
self.font_decrease_btn = TextButtonStyleManager.createStandardButton(
"-",
parent=self,
slot=self._decreaseFontSize
)
btn_size = scale_w(20) # 响应式按钮尺寸
self.font_decrease_btn.setFixedSize(btn_size, btn_size)
# 使用系统默认样式
self.font_decrease_btn.clicked.connect(self._decreaseFontSize)
log_header.addWidget(self.font_decrease_btn)
self.font_increase_btn = QtWidgets.QPushButton("+")
self.font_increase_btn = TextButtonStyleManager.createStandardButton(
"+",
parent=self,
slot=self._increaseFontSize
)
self.font_increase_btn.setFixedSize(btn_size, btn_size)
# 使用系统默认样式
self.font_increase_btn.clicked.connect(self._increaseFontSize)
log_header.addWidget(self.font_increase_btn)
clear_log_btn.clicked.connect(self.clearLog)
log_header.addWidget(clear_log_btn)
right_layout.addLayout(log_header)
......@@ -467,6 +520,199 @@ class TrainingPage(QtWidgets.QWidget):
pass
def _createCurvePanel(self):
"""创建曲线显示面板"""
self.curve_panel = QtWidgets.QWidget()
curve_layout = QtWidgets.QVBoxLayout(self.curve_panel)
curve_layout.setContentsMargins(5, 5, 5, 5)
curve_layout.setSpacing(5)
# 曲线面板标题
curve_title_layout = QtWidgets.QHBoxLayout()
curve_title = QtWidgets.QLabel("测试结果曲线")
curve_title.setStyleSheet("color: #495057; font-weight: bold; font-size: 12pt;")
FontManager.applyToWidget(curve_title, weight=FontManager.WEIGHT_BOLD)
curve_title_layout.addWidget(curve_title)
# 添加清空曲线按钮(使用全局样式管理器)
self.clear_curve_btn = TextButtonStyleManager.createStandardButton(
"清空曲线",
parent=self,
slot=self._clearCurve
)
self.clear_curve_btn.setMinimumWidth(scale_w(80))
curve_title_layout.addStretch()
curve_title_layout.addWidget(self.clear_curve_btn)
curve_layout.addLayout(curve_title_layout)
# 检查 PyQtGraph 是否可用
if PYQTGRAPH_AVAILABLE:
# 创建 PyQtGraph 绘图控件
self.curve_plot_widget = pg.PlotWidget()
self.curve_plot_widget.setBackground('#ffffff')
self.curve_plot_widget.showGrid(x=True, y=True, alpha=0.3)
# 设置坐标轴标签
self.curve_plot_widget.setLabel('left', '液位高度', units='mm')
self.curve_plot_widget.setLabel('bottom', '帧序号')
self.curve_plot_widget.setTitle('液位检测曲线', color='#495057', size='12pt')
# 添加图例
self.curve_plot_widget.addLegend()
# 存储曲线数据
self.curve_data_x = [] # X轴数据(帧序号)
self.curve_data_y = [] # Y轴数据(液位高度)
self.curve_line = None # 曲线对象
curve_layout.addWidget(self.curve_plot_widget)
else:
# PyQtGraph 不可用,显示提示信息
placeholder = QtWidgets.QLabel(
"曲线显示功能需要 PyQtGraph 库\n\n"
"请安装: pip install pyqtgraph"
)
placeholder.setAlignment(Qt.AlignCenter)
placeholder.setStyleSheet("color: #999; font-size: 11pt; padding: 50px;")
FontManager.applyToWidget(placeholder)
curve_layout.addWidget(placeholder)
def _clearCurve(self):
"""清空曲线数据并隐藏曲线面板"""
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget'):
# 清空数据
self.curve_data_x = []
self.curve_data_y = []
# 清空曲线
if self.curve_line:
self.curve_plot_widget.removeItem(self.curve_line)
self.curve_line = None
print("[曲线] 已清空曲线数据")
# 自动隐藏曲线面板,返回到初始显示状态
self.hideCurvePanel()
print("[曲线] 曲线面板已隐藏,返回初始显示状态")
def addCurvePoint(self, frame_index, height_mm):
"""添加曲线数据点
Args:
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_plot_widget'):
return
try:
# 添加数据点
self.curve_data_x.append(frame_index)
self.curve_data_y.append(height_mm)
# 如果曲线不存在,创建曲线
if self.curve_line is None:
self.curve_line = self.curve_plot_widget.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
else:
# 更新曲线数据
self.curve_line.setData(self.curve_data_x, self.curve_data_y)
except Exception as e:
print(f"[曲线] 添加数据点失败: {e}")
def showCurvePanel(self):
"""显示曲线面板"""
if hasattr(self, 'display_layout'):
# 切换到曲线面板(索引3:hint, display_panel, video_panel, curve_panel)
self.display_layout.setCurrentIndex(3)
def hideCurvePanel(self):
"""隐藏曲线面板,返回到显示面板"""
if hasattr(self, 'display_layout'):
self.display_layout.setCurrentIndex(1) # 显示 display_panel
def saveCurveData(self, csv_path):
"""保存曲线数据为CSV文件
Args:
csv_path: CSV文件保存路径
Returns:
bool: 是否成功保存
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_data_x'):
return False
try:
import csv
# 检查是否有数据
if len(self.curve_data_x) == 0:
print("[曲线保存] 没有曲线数据可保存")
return False
# 写入CSV文件
with open(csv_path, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
# 写入表头
writer.writerow(['帧序号', '液位高度(mm)'])
# 写入数据
for x, y in zip(self.curve_data_x, self.curve_data_y):
writer.writerow([x, y])
print(f"[曲线保存] CSV数据已保存: {csv_path}")
print(f"[曲线保存] 共保存 {len(self.curve_data_x)} 个数据点")
return True
except Exception as e:
print(f"[曲线保存] 保存CSV失败: {e}")
import traceback
traceback.print_exc()
return False
def saveCurveImage(self, image_path):
"""保存曲线为图片文件
Args:
image_path: 图片文件保存路径(支持 .png, .jpg, .svg等格式)
Returns:
bool: 是否成功保存
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_plot_widget'):
return False
try:
# 检查是否有数据
if len(self.curve_data_x) == 0:
print("[曲线保存] 没有曲线数据可保存")
return False
# 使用PyQtGraph的导出功能
exporter = pg.exporters.ImageExporter(self.curve_plot_widget.plotItem)
# 设置导出参数
exporter.parameters()['width'] = 1200 # 设置宽度
exporter.parameters()['height'] = 600 # 设置高度
# 导出图片
exporter.export(image_path)
print(f"[曲线保存] 曲线图片已保存: {image_path}")
return True
except Exception as e:
print(f"[曲线保存] 保存图片失败: {e}")
import traceback
traceback.print_exc()
return False
def _createParametersGroup(self):
"""创建参数配置组"""
group = QtWidgets.QGroupBox("")
......@@ -502,38 +748,53 @@ class TrainingPage(QtWidgets.QWidget):
# 创建表单布局
layout = QtWidgets.QFormLayout()
from ..responsive_layout import scale_spacing
layout.setSpacing(scale_spacing(8))
layout.setSpacing(8)
layout.setContentsMargins(0, 0, 0, 0)
layout.setLabelAlignment(Qt.AlignRight | Qt.AlignVCenter)
# 🔥 删除CSS样式表,改为纯控件方式,由字体管理器统一管理
# 基础模型路径
model_layout = QtWidgets.QHBoxLayout()
from ..responsive_layout import scale_spacing
model_layout.setSpacing(scale_spacing(6))
self.base_model_edit = QtWidgets.QLineEdit()
self.base_model_edit.setPlaceholderText("选择基础模型文件 (.pt)")
self.browse_model_btn = QtWidgets.QPushButton("浏览...")
self.browse_model_btn.setFixedWidth(scale_w(80))
self.browse_model_btn.clicked.connect(self._browseModel)
model_layout.addWidget(self.base_model_edit, 1)
model_layout.addWidget(self.browse_model_btn)
layout.addRow("基础模型:", model_layout)
# 数据集路径 (字段名必须是 save_liquid_data_path_edit 以匹配训练处理器)
data_layout = QtWidgets.QHBoxLayout()
from ..responsive_layout import scale_spacing
data_layout.setSpacing(scale_spacing(6))
# 基础模型选择(下拉菜单)
self.base_model_combo = QtWidgets.QComboBox()
self.base_model_combo.setPlaceholderText("请选择基础模型")
self.base_model_combo.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
layout.addRow("基础模型:", self.base_model_combo)
# 数据集文件夹选择(统一格式,支持多个文件夹)
dataset_layout = QtWidgets.QHBoxLayout()
dataset_layout.setSpacing(8)
dataset_layout.setContentsMargins(0, 0, 0, 0)
# 数据集路径显示文本框(可编辑,支持多个路径用分号分隔)
self.dataset_paths_edit = QtWidgets.QLineEdit()
self.dataset_paths_edit.setPlaceholderText("")
self.dataset_paths_edit.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed)
# 移除自定义样式,使用全局字体管理器统一管理
self.dataset_paths_edit.setStyleSheet("")
# 连接文本变化信号,实时更新内部数据集列表
self.dataset_paths_edit.textChanged.connect(self._onDatasetPathsChanged)
dataset_layout.addWidget(self.dataset_paths_edit)
# 浏览按钮(使用全局样式管理器)
self.btn_browse_datasets = TextButtonStyleManager.createStandardButton(
"浏览...",
parent=self,
slot=self._onBrowseDatasets
)
self.btn_browse_datasets.setFixedWidth(100)
self.btn_browse_datasets.setSizePolicy(QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed)
dataset_layout.addWidget(self.btn_browse_datasets)
layout.addRow("数据集:", dataset_layout)
# 保持内部数据集文件夹列表(隐藏,用于兼容现有逻辑)
self.dataset_folders_list = QtWidgets.QListWidget()
self.dataset_folders_list.setVisible(False) # 隐藏,仅用于内部数据管理
# 保留旧的字段名以保持向后兼容(用于获取数据集路径)
# 现在它将存储用分号分隔的多个文件夹路径
self.save_liquid_data_path_edit = QtWidgets.QLineEdit()
self.save_liquid_data_path_edit.setPlaceholderText("选择数据集配置文件 (data.yaml)")
self.save_liquid_data_path_edit.setText("database/dataset/data.yaml") # 默认路径
self.browse_data_btn = QtWidgets.QPushButton("浏览...")
self.browse_data_btn.setFixedWidth(scale_w(80))
self.browse_data_btn.clicked.connect(self._browseDataset)
data_layout.addWidget(self.save_liquid_data_path_edit, 1)
data_layout.addWidget(self.browse_data_btn)
layout.addRow("数据集:", data_layout)
self.save_liquid_data_path_edit.setVisible(False) # 隐藏,仅用于数据传递
# 实验名称
self.exp_name_edit = QtWidgets.QLineEdit()
......@@ -586,6 +847,18 @@ class TrainingPage(QtWidgets.QWidget):
self.optimizer_combo.addItems(["SGD", "Adam", "AdamW"])
layout.addRow("优化器:", self.optimizer_combo)
# 训练笔记按钮(使用全局样式管理器)
self.training_notes_btn = TextButtonStyleManager.createStandardButton(
"训练笔记",
parent=self,
slot=self._openNotesDialog
)
self.training_notes_btn.setToolTip("点击打开训练笔记编辑窗口")
layout.addRow("训练笔记:", self.training_notes_btn)
# 内部存储笔记内容的变量
self._training_notes_content = ""
# 高级选项分隔
separator = QtWidgets.QLabel()
separator.setStyleSheet("border-top: 1px solid #dee2e6; margin: 5px 0;")
......@@ -646,12 +919,18 @@ class TrainingPage(QtWidgets.QWidget):
# 创建别名以匹配训练处理器期望的名称
self.train_status_label = self.status_label
# 控制按钮(使用Qt默认样式 + 响应式布局)
self.start_train_btn = QtWidgets.QPushButton("开始升级")
# 控制按钮(使用全局样式管理器)
self.start_train_btn = TextButtonStyleManager.createStandardButton(
"开始升级",
parent=self
)
self.start_train_btn.setMinimumWidth(scale_w(80))
control_layout.addWidget(self.start_train_btn)
self.stop_train_btn = QtWidgets.QPushButton("停止升级")
self.stop_train_btn = TextButtonStyleManager.createStandardButton(
"停止升级",
parent=self
)
self.stop_train_btn.setMinimumWidth(scale_w(80))
self.stop_train_btn.setEnabled(False)
control_layout.addWidget(self.stop_train_btn)
......@@ -663,7 +942,7 @@ class TrainingPage(QtWidgets.QWidget):
# 🔥 应用全局字体管理器到所有文本框和控件
if FontManager:
# 应用到所有QLineEdit
FontManager.applyToWidget(self.base_model_edit)
FontManager.applyToWidget(self.dataset_paths_edit)
FontManager.applyToWidget(self.save_liquid_data_path_edit)
FontManager.applyToWidget(self.exp_name_edit)
......@@ -674,6 +953,7 @@ class TrainingPage(QtWidgets.QWidget):
FontManager.applyToWidget(self.workers_spin)
# 应用到所有QComboBox
FontManager.applyToWidget(self.base_model_combo)
FontManager.applyToWidget(self.device_combo)
FontManager.applyToWidget(self.optimizer_combo)
......@@ -686,39 +966,181 @@ class TrainingPage(QtWidgets.QWidget):
FontManager.applyToWidget(self.checkbox_template_2)
FontManager.applyToWidget(self.checkbox_template_3)
# 应用到所有QPushButton
FontManager.applyToWidget(self.browse_model_btn)
FontManager.applyToWidget(self.browse_data_btn)
# 应用到状态标签(非按钮样式管理器创建的控件)
FontManager.applyToWidget(self.status_label)
FontManager.applyToWidget(self.start_train_btn)
FontManager.applyToWidget(self.stop_train_btn)
# 应用到数据集列表
FontManager.applyToWidget(self.dataset_folders_list)
# 应用到笔记按钮
FontManager.applyToWidget(self.training_notes_btn)
# 应用到整个group(包括标签)
FontManager.applyToWidgetRecursive(group)
return group
def _browseModel(self):
"""浏览模型文件"""
file_path, _ = QtWidgets.QFileDialog.getOpenFileName(
def _onBrowseDatasets(self):
"""浏览数据集文件夹(支持多选)"""
# 使用简单的单选文件夹对话框,多次选择来实现多选效果
folder_path = QtWidgets.QFileDialog.getExistingDirectory(
self,
"选择基础模型",
"database/model",
"模型文件 (*.pt *.dat);;所有文件 (*.*)"
"选择数据集文件夹",
"database/dataset" if os.path.exists("database/dataset") else ""
)
if file_path:
self.base_model_edit.setText(file_path)
if folder_path:
# 获取当前已有的文件夹路径
current_paths = self.dataset_paths_edit.text().strip()
existing_folders = [p.strip() for p in current_paths.split(';') if p.strip()] if current_paths else []
# 检查是否已经添加过这个文件夹
if folder_path not in existing_folders:
# 添加新文件夹
if existing_folders:
# 如果已有文件夹,用分号连接
new_paths = current_paths + ';' + folder_path
else:
# 如果是第一个文件夹
new_paths = folder_path
self.dataset_paths_edit.setText(new_paths)
print(f"[TrainingPage] 添加数据集文件夹: {folder_path}")
else:
# 使用style_manager中的对话框管理器显示提示
try:
from ..style_manager import DialogManager
DialogManager.show_information(
self, "提示",
f"文件夹已存在:\n{folder_path}"
)
except ImportError:
QtWidgets.QMessageBox.information(
self, "提示",
f"文件夹已存在:\n{folder_path}"
)
def _browseDataset(self):
"""浏览数据集文件"""
file_path, _ = QtWidgets.QFileDialog.getOpenFileName(
self,
"选择数据集配置",
"database/dataset",
"YAML文件 (*.yaml *.yml);;所有文件 (*.*)"
def _onDatasetPathsChanged(self):
"""数据集路径文本变化时的处理"""
# 更新内部数据集列表和隐藏的路径字段
paths_text = self.dataset_paths_edit.text().strip()
folders = [p.strip() for p in paths_text.split(';') if p.strip()]
# 更新内部列表(用于兼容现有逻辑)
self.dataset_folders_list.clear()
for folder in folders:
self.dataset_folders_list.addItem(folder)
# 更新隐藏的路径字段
self.save_liquid_data_path_edit.setText(paths_text)
# 调试信息
print(f"[TrainingPage] 数据集路径更新: {len(folders)} 个文件夹")
for i, folder in enumerate(folders):
print(f" [{i+1}] {folder}")
def _addDatasetFolder(self):
"""添加数据集文件夹(保留兼容性)"""
self._onBrowseDatasets()
def _removeSelectedDatasets(self):
"""删除选中的数据集文件夹"""
selected_items = self.dataset_folders_list.selectedItems()
if not selected_items:
QtWidgets.QMessageBox.information(self, "提示", "请先选择要删除的文件夹")
return
for item in selected_items:
row = self.dataset_folders_list.row(item)
self.dataset_folders_list.takeItem(row)
self._updateDatasetPath()
def _clearAllDatasets(self):
"""清空所有数据集文件夹"""
if self.dataset_folders_list.count() == 0:
return
reply = QtWidgets.QMessageBox.question(
self,
"确认清空",
f"确定要清空所有 {self.dataset_folders_list.count()} 个数据集文件夹吗?",
QtWidgets.QMessageBox.Yes | QtWidgets.QMessageBox.No,
QtWidgets.QMessageBox.No
)
if file_path:
self.save_liquid_data_path_edit.setText(file_path)
if reply == QtWidgets.QMessageBox.Yes:
self.dataset_folders_list.clear()
self._updateDatasetPath()
def _updateDatasetPath(self):
"""更新隐藏的数据集路径字段(用分号分隔多个文件夹)"""
folders = [self.dataset_folders_list.item(i).text()
for i in range(self.dataset_folders_list.count())]
# 使用分号分隔多个文件夹路径
self.save_liquid_data_path_edit.setText(';'.join(folders))
def getDatasetFolders(self):
"""获取所有数据集文件夹路径列表"""
paths_text = self.dataset_paths_edit.text().strip()
return [p.strip() for p in paths_text.split(';') if p.strip()]
def getTrainingNotes(self):
"""获取训练笔记内容"""
return self._training_notes_content.strip()
def setTrainingNotes(self, notes):
"""设置训练笔记内容"""
self._training_notes_content = notes if notes else ""
self._updateNotesButtonText()
def clearTrainingNotes(self):
"""清空训练笔记(不弹确认框)"""
self._training_notes_content = ""
self._updateNotesButtonText()
def _openNotesDialog(self):
"""打开训练笔记编辑对话框"""
dialog = TrainingNotesDialog(self._training_notes_content, self)
if dialog.exec_() == QtWidgets.QDialog.Accepted:
self._training_notes_content = dialog.getNotesContent()
self._updateNotesButtonText()
def _updateNotesButtonText(self):
"""更新笔记按钮的显示文本"""
if self._training_notes_content.strip():
# 显示笔记的前几个字符
preview = self._training_notes_content.strip()[:8] # 减少字符数以适应按钮宽度
if len(self._training_notes_content.strip()) > 8:
preview += "..."
new_text = f"训练笔记 ({preview})"
# 使用全局样式管理器更新按钮文本和大小
TextButtonStyleManager.updateButtonText(self.training_notes_btn, new_text)
# 添加有内容的视觉提示(保持全局样式基础上的微调)
current_style = self.training_notes_btn.styleSheet()
self.training_notes_btn.setStyleSheet(current_style + """
QPushButton {
background-color: #e3f2fd;
border: 1px solid #2196f3;
}
""")
else:
# 使用全局样式管理器重置按钮
TextButtonStyleManager.updateButtonText(self.training_notes_btn, "训练笔记")
def enableNotesButtons(self):
"""启用笔记按钮(训练完成后调用)"""
# 训练笔记按钮始终可用,无需特殊处理
pass
def disableNotesButtons(self):
"""禁用笔记按钮(训练开始前调用)"""
# 训练笔记按钮始终可用,无需特殊处理
pass
@QtCore.Slot(str)
def appendLog(self, text):
......@@ -898,10 +1320,11 @@ class TrainingPage(QtWidgets.QWidget):
self.stop_train_btn.setEnabled(is_training)
# 禁用参数输入
self.base_model_edit.setEnabled(not is_training)
self.browse_model_btn.setEnabled(not is_training)
self.save_liquid_data_path_edit.setEnabled(not is_training)
self.browse_data_btn.setEnabled(not is_training)
self.base_model_combo.setEnabled(not is_training)
self.dataset_folders_list.setEnabled(not is_training)
self.add_dataset_btn.setEnabled(not is_training)
self.remove_dataset_btn.setEnabled(not is_training)
self.clear_dataset_btn.setEnabled(not is_training)
self.exp_name_edit.setEnabled(not is_training)
self.epochs_spin.setEnabled(not is_training)
self.batch_spin.setEnabled(not is_training)
......@@ -915,8 +1338,59 @@ class TrainingPage(QtWidgets.QWidget):
def showEvent(self, event):
"""页面显示时刷新模型列表和测试文件列表(确保与模型集管理页面同步)"""
super(TrainingPage, self).showEvent(event)
self._loadBaseModelOptions() # 🔥 加载基础模型列表
self._loadTestModelOptions()
self._loadTestFileList() # 🔥 刷新测试文件列表
# self._loadTestFileList() # 🔥 不再需要刷新测试文件列表(改用浏览方式)
def _loadBaseModelOptions(self):
"""从模型集管理页面加载基础模型选项"""
# 清空现有选项
self.base_model_combo.clear()
try:
# 尝试从父窗口获取模型集页面
if hasattr(self._parent, 'modelSetPage'):
model_set_page = self._parent.modelSetPage
# 获取模型参数字典
if hasattr(model_set_page, '_model_params'):
model_params = model_set_page._model_params
if not model_params:
self.base_model_combo.addItem("未找到模型", None)
return
# 获取默认模型
default_model = None
if hasattr(model_set_page, '_current_default_model'):
default_model = model_set_page._current_default_model
# 添加所有模型到下拉框
default_index = 0
for idx, (model_name, params) in enumerate(model_params.items()):
model_path = params.get('path', '')
# 构建显示名称
display_name = model_name
if model_name == default_model:
display_name = f"{model_name} (默认)"
default_index = idx
# 添加到下拉框,使用模型路径作为数据
self.base_model_combo.addItem(display_name, model_path)
# 设置默认选择
self.base_model_combo.setCurrentIndex(default_index)
else:
self.base_model_combo.addItem("模型集页面未初始化", None)
else:
self.base_model_combo.addItem("未找到模型集页面", None)
except Exception as e:
print(f"[基础模型] 加载失败: {e}")
import traceback
traceback.print_exc()
self.base_model_combo.addItem("加载失败", None)
def _loadTestModelOptions(self):
"""加载测试模型选项(从 train_model 文件夹读取)"""
......@@ -1118,91 +1592,10 @@ class TrainingPage(QtWidgets.QWidget):
return models
def _loadTestFileList(self):
"""加载测试文件列表到下拉框"""
try:
import os
from pathlib import Path
# 🔥 改进:使用多种方式获取project_root
project_root = None
try:
from ...database.config import get_project_root
project_root = get_project_root()
# print(f"[测试文件] 使用相对导入获取project_root: {project_root}")
except (ImportError, ValueError):
# 相对导入失败是预期行为(作为主程序运行时),静默处理
try:
from database.config import get_project_root
project_root = get_project_root()
except (ImportError, ValueError):
# 备选方案:使用当前文件的父目录
current_file = Path(__file__).resolve()
project_root = current_file.parent.parent.parent
if not project_root:
raise RuntimeError("无法获取project_root")
test_file_dir = os.path.join(str(project_root), 'database', 'model_test_file')
# 清空下拉框
self.test_file_input.clear()
if not os.path.exists(test_file_dir):
self.test_file_input.addItem("(目录不存在)")
return
# 扫描目录
items = []
try:
dir_items = os.listdir(test_file_dir)
video_extensions = ['.mp4', '.avi', '.mov', '.mkv', '.flv', '.wmv']
image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp']
for item in dir_items:
item_path = os.path.join(test_file_dir, item)
# 检查是否为视频文件
if os.path.isfile(item_path):
if any(item.lower().endswith(ext) for ext in video_extensions):
items.append((item, item_path))
# 检查是否为文件夹
elif os.path.isdir(item_path):
# 检查文件夹内是否有图片或视频
has_images = False
has_videos = False
try:
folder_items = os.listdir(item_path)
for file in folder_items:
file_lower = file.lower()
if any(file_lower.endswith(ext) for ext in image_extensions):
has_images = True
break
elif any(file_lower.endswith(ext) for ext in video_extensions):
has_videos = True
except Exception as folder_e:
pass
# 🔥 改进:文件夹包含图片或视频都添加
if has_images or has_videos:
items.append((f"📁 {item}", item_path))
except Exception as e:
import traceback
traceback.print_exc()
# 添加到下拉框
if items:
for display_name, full_path in sorted(items):
self.test_file_input.addItem(display_name, full_path)
else:
self.test_file_input.addItem("(未找到测试文件)")
except Exception as e:
import traceback
traceback.print_exc()
self.test_file_input.addItem("(加载失败)")
"""加载测试文件列表(保留方法以保持向后兼容,但现在使用浏览方式)"""
# 该方法现已改为使用浏览按钮选择文件,不再从固定目录加载
# 保留此方法以防其他代码调用
pass
def _getRememberedTestModel(self, project_root):
"""从配置文件获取记忆的测试模型路径"""
......@@ -1254,8 +1647,8 @@ class TrainingPage(QtWidgets.QWidget):
def getTestFilePath(self):
"""获取选中的测试文件路径"""
# 🔥 改为从QComboBox获取数据
return self.test_file_input.currentData() or ""
# 从 QLineEdit 获取文件路径
return self.test_file_input.text().strip()
def isTestingInProgress(self):
"""检查是否正在测试中"""
......@@ -1267,27 +1660,14 @@ class TrainingPage(QtWidgets.QWidget):
if is_testing:
# 切换为"停止测试"状态
self.start_test_btn.setText("停止测试")
# 使用红色样式表示危险操作
self.start_test_btn.setStyleSheet("""
QPushButton {
background-color: #dc3545;
color: white;
border: none;
padding: 8px 16px;
border-radius: 4px;
font-weight: bold;
min-width: 100px;
}
QPushButton:hover {
background-color: #c82333;
}
""")
TextButtonStyleManager.updateButtonText(self.start_test_btn, "停止测试")
# 使用全局样式管理器的危险按钮样式
TextButtonStyleManager.applyDangerStyle(self.start_test_btn)
else:
# 切换为"开始测试"状态
self.start_test_btn.setText("开始测试")
# 恢复默认样式
self.start_test_btn.setStyleSheet("")
TextButtonStyleManager.updateButtonText(self.start_test_btn, "开始测试")
# 恢复标准样式
TextButtonStyleManager.applyStandardStyle(self.start_test_btn)
def _onTemplateChecked(self, button):
"""处理模板复选框选中事件"""
......@@ -1326,8 +1706,14 @@ class TrainingPage(QtWidgets.QWidget):
try:
# 基础模型
if 'model' in config:
self.base_model_edit.setText(str(config['model']))
print(f"[模板] 设置基础模型: {config['model']}")
model_path = str(config['model'])
# 在下拉菜单中查找匹配的模型
for i in range(self.base_model_combo.count()):
item_data = self.base_model_combo.itemData(i)
if item_data and item_data == model_path:
self.base_model_combo.setCurrentIndex(i)
print(f"[模板] 设置基础模型: {model_path}")
break
# 数据集配置
if 'data' in config:
......@@ -1392,6 +1778,134 @@ class TrainingPage(QtWidgets.QWidget):
import traceback
traceback.print_exc()
def _browseTestFile(self):
"""浏览选择测试文件(支持图片和视频)"""
try:
# 定义支持的文件类型
image_formats = "图片文件 (*.jpg *.jpeg *.png *.bmp *.tiff *.webp)"
video_formats = "视频文件 (*.mp4 *.avi *.mov *.mkv *.flv *.wmv)"
all_formats = "所有支持的文件 (*.jpg *.jpeg *.png *.bmp *.tiff *.webp *.mp4 *.avi *.mov *.mkv *.flv *.wmv)"
# 构建文件过滤器
file_filter = f"{all_formats};;{image_formats};;{video_formats};;所有文件 (*.*)"
# 打开文件选择对话框
file_path, _ = QtWidgets.QFileDialog.getOpenFileName(
self,
"选择测试图片或视频文件",
"", # 默认目录为空,使用系统默认
file_filter
)
# 如果用户选择了文件,则设置到输入框
if file_path:
self.test_file_input.setText(file_path)
print(f"[测试文件] 已选择: {file_path}")
except Exception as e:
import traceback
traceback.print_exc()
QtWidgets.QMessageBox.warning(
self,
"文件选择失败",
f"选择测试文件时发生错误:\n{str(e)}"
)
def _onViewCurveClicked(self):
"""查看曲线按钮点击处理"""
try:
# 检查是否有曲线数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
self._showNoCurveMessage()
return
if len(self.curve_data_x) == 0:
self._showNoCurveMessage()
return
# 切换到曲线面板显示
self.showCurvePanel()
# 显示曲线信息提示
data_count = len(self.curve_data_x)
if data_count == 1:
# 图片测试
liquid_level = self.curve_data_y[0]
QtWidgets.QMessageBox.information(
self,
"曲线信息",
f"图片测试结果:\n液位高度: {liquid_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
)
else:
# 视频测试
min_level = min(self.curve_data_y)
max_level = max(self.curve_data_y)
avg_level = sum(self.curve_data_y) / len(self.curve_data_y)
QtWidgets.QMessageBox.information(
self,
"曲线信息",
f"视频测试结果:\n"
f"数据点数: {data_count} 个\n"
f"液位范围: {min_level:.1f} - {max_level:.1f} mm\n"
f"平均液位: {avg_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
)
except Exception as e:
print(f"[查看曲线] 显示曲线失败: {e}")
QtWidgets.QMessageBox.warning(
self,
"显示失败",
f"显示曲线时发生错误:\n{str(e)}"
)
def _showNoCurveMessage(self):
"""显示无曲线数据的提示"""
QtWidgets.QMessageBox.information(
self,
"无曲线数据",
"当前没有可显示的曲线数据。\n\n"
"请先进行模型测试:\n"
"1. 选择测试模型\n"
"2. 选择测试文件\n"
"3. 点击\"开始标注\"\n"
"4. 点击\"开始测试\"\n\n"
"测试完成后即可查看曲线结果。"
)
def enableViewCurveButton(self):
"""启用查看曲线按钮(测试完成后调用)"""
try:
self.view_curve_btn.setEnabled(True)
self.view_curve_btn.setToolTip("点击查看测试结果曲线")
# 检查曲线数据类型并更新按钮文本
if hasattr(self, 'curve_data_x') and len(self.curve_data_x) > 0:
data_count = len(self.curve_data_x)
if data_count == 1:
self.view_curve_btn.setText("查看曲线(图)")
else:
self.view_curve_btn.setText("查看曲线(视)")
else:
self.view_curve_btn.setText("查看曲线")
print(f"[查看曲线] 按钮已启用,数据点数: {len(self.curve_data_x) if hasattr(self, 'curve_data_x') else 0}")
except Exception as e:
print(f"[查看曲线] 启用按钮失败: {e}")
def disableViewCurveButton(self):
"""禁用查看曲线按钮(测试开始前调用)"""
try:
self.view_curve_btn.setEnabled(False)
self.view_curve_btn.setText("查看曲线")
self.view_curve_btn.setToolTip("测试完成后可查看曲线结果")
print(f"[查看曲线] 按钮已禁用")
except Exception as e:
print(f"[查看曲线] 禁用按钮失败: {e}")
def getTemplateConfig(self):
"""获取当前选中的模板配置名称"""
if self.checkbox_template_1.isChecked():
......@@ -1414,3 +1928,148 @@ if __name__ == '__main__':
page.show()
sys.exit(app.exec_())
class TrainingNotesDialog(QtWidgets.QDialog):
"""训练笔记编辑对话框"""
def __init__(self, initial_content="", parent=None):
super().__init__(parent)
self.setWindowTitle("训练笔记编辑")
self.setMinimumSize(500, 400)
self.resize(600, 500)
# 设置窗口图标(如果有的话)
self.setWindowFlags(self.windowFlags() & ~QtCore.Qt.WindowContextHelpButtonHint)
# 应用全局样式管理器
from ..style_manager import FontManager, BackgroundStyleManager
BackgroundStyleManager.applyToWidget(self)
self._setupUI()
self.text_edit.setPlainText(initial_content)
# 应用全局字体管理器到整个对话框
FontManager.applyToDialog(self)
def _setupUI(self):
"""设置UI界面"""
layout = QtWidgets.QVBoxLayout(self)
layout.setSpacing(10)
layout.setContentsMargins(15, 15, 15, 15)
# 标题标签(使用全局字体管理器)
from ..style_manager import FontManager
title_label = QtWidgets.QLabel("训练笔记")
title_label.setFont(FontManager.getTitleFont())
title_label.setStyleSheet("""
QLabel {
color: #333;
margin-bottom: 5px;
}
""")
layout.addWidget(title_label)
# 说明文字(使用全局字体管理器)
info_label = QtWidgets.QLabel("在此记录本次训练的相关信息,如数据集变化、参数调整原因、预期效果等...")
info_label.setFont(FontManager.getSmallFont())
info_label.setStyleSheet("""
QLabel {
color: #666;
margin-bottom: 10px;
}
""")
info_label.setWordWrap(True)
layout.addWidget(info_label)
# 文本编辑区域(使用全局字体管理器)
self.text_edit = QtWidgets.QTextEdit()
self.text_edit.setFont(FontManager.getMediumFont())
self.text_edit.setStyleSheet("""
QTextEdit {
border: 1px solid #ccc;
background-color: white;
padding: 10px;
line-height: 1.4;
}
""")
self.text_edit.setPlaceholderText(
"示例内容:\n"
"• 数据集:新增了100张液位图片\n"
"• 参数调整:学习率从0.01调整为0.005\n"
"• 预期效果:提高小液位目标的检测精度\n"
"• 其他备注:..."
)
layout.addWidget(self.text_edit)
# 字符计数标签(使用全局字体管理器)
self.char_count_label = QtWidgets.QLabel("字符数: 0")
self.char_count_label.setFont(FontManager.getSmallFont())
self.char_count_label.setStyleSheet("color: #666;")
self.char_count_label.setAlignment(QtCore.Qt.AlignRight)
layout.addWidget(self.char_count_label)
# 按钮区域(使用全局样式管理器)
button_layout = QtWidgets.QHBoxLayout()
button_layout.addStretch()
# 清空按钮
clear_btn = TextButtonStyleManager.createStandardButton("清空", self, self._clearText)
button_layout.addWidget(clear_btn)
# 取消按钮
cancel_btn = TextButtonStyleManager.createStandardButton("取消", self, self.reject)
button_layout.addWidget(cancel_btn)
# 保存按钮(使用全局样式管理器创建,然后添加特殊样式)
save_btn = TextButtonStyleManager.createStandardButton("保存", self, self.accept)
save_btn.setDefault(True)
save_btn.setStyleSheet("""
QPushButton {
background-color: #2196f3;
color: white;
border: none;
padding: 8px;
font-weight: bold;
}
QPushButton:hover {
background-color: #1976d2;
}
""")
button_layout.addWidget(save_btn)
layout.addLayout(button_layout)
# 连接信号
self.text_edit.textChanged.connect(self._updateCharCount)
self._updateCharCount()
def _clearText(self):
"""清空文本"""
if self.text_edit.toPlainText().strip():
from ..style_manager import DialogManager
if DialogManager.show_question_warning(
self,
"确认清空",
"确定要清空所有笔记内容吗?",
"是", "否"
):
self.text_edit.clear()
def _updateCharCount(self):
"""更新字符计数"""
text = self.text_edit.toPlainText()
char_count = len(text)
self.char_count_label.setText(f"字符数: {char_count}")
# 字符数过多时显示警告颜色(保持全局字体设置)
if char_count > 1000:
self.char_count_label.setStyleSheet("color: #f44336;")
elif char_count > 500:
self.char_count_label.setStyleSheet("color: #ff9800;")
else:
self.char_count_label.setStyleSheet("color: #666;")
def getNotesContent(self):
"""获取笔记内容"""
return self.text_edit.toPlainText().strip()
......@@ -5,7 +5,7 @@
根据屏幕分辨率自动调整UI尺寸
"""
from PyQt5 import QtWidgets, QtCore
from qtpy import QtWidgets, QtCore
class ResponsiveLayout:
......
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
"""
统一样式管理器
......@@ -40,7 +40,7 @@ import os.path as osp
import sys
try:
from PyQt5 import QtGui, QtWidgets, QtCore
from qtpy import QtGui, QtWidgets, QtCore
except ImportError as e:
raise
......@@ -78,12 +78,11 @@ class FontManager:
@staticmethod
def getDefaultFont():
"""获取默认字体"""
font = QtGui.QFont(
FontManager.DEFAULT_FONT_FAMILY,
FontManager.DEFAULT_FONT_SIZE
return FontManager.getFont(
size=FontManager.DEFAULT_FONT_SIZE,
weight=FontManager.DEFAULT_FONT_WEIGHT,
family=FontManager.DEFAULT_FONT_FAMILY
)
font.setWeight(FontManager.DEFAULT_FONT_WEIGHT)
return font
@staticmethod
def getFont(size=None, weight=None, italic=False, underline=False, family=None):
......@@ -96,7 +95,31 @@ class FontManager:
weight = FontManager.DEFAULT_FONT_WEIGHT
font = QtGui.QFont(family, size)
font.setWeight(weight)
# PySide6兼容性:将整数权重转换为QFont.Weight枚举
try:
# 尝试使用PySide6的Weight枚举
if hasattr(QtGui.QFont, 'Weight'):
if weight <= 25:
font.setWeight(QtGui.QFont.Weight.Light)
elif weight <= 50:
font.setWeight(QtGui.QFont.Weight.Normal)
elif weight <= 63:
font.setWeight(QtGui.QFont.Weight.DemiBold)
elif weight <= 75:
font.setWeight(QtGui.QFont.Weight.Bold)
else:
font.setWeight(QtGui.QFont.Weight.Black)
else:
# 回退到旧的整数方式(PyQt5)
font.setWeight(weight)
except (AttributeError, TypeError):
# 如果出现任何错误,使用默认权重
try:
font.setWeight(QtGui.QFont.Weight.Normal)
except:
pass
font.setItalic(italic)
font.setUnderline(underline)
return font
......@@ -157,8 +180,17 @@ class FontManager:
@staticmethod
def applyToApplication(app):
"""应用默认字体到整个应用程序"""
font = FontManager.getDefaultFont()
app.setFont(font)
try:
# 确保使用当前 Qt 后端创建字体对象
from qtpy import QtGui
font = QtGui.QFont(
FontManager.DEFAULT_FONT_FAMILY,
FontManager.DEFAULT_FONT_SIZE
)
font.setWeight(FontManager.DEFAULT_FONT_WEIGHT)
app.setFont(font)
except Exception as e:
print(f"字体设置失败,使用默认字体: {e}")
@staticmethod
def applyToWidgetRecursive(widget, size=None, weight=None):
......@@ -827,7 +859,7 @@ class TextButtonStyleManager:
# 计算文本宽度:字符数 * 每字符宽度 + 内边距
text_width = len(text) * cls.CHAR_WIDTH + cls.MIN_PADDING
# 返回基础宽度和计算宽度的较大值
return max(cls.BASE_WIDTH, text_width)
@classmethod
......@@ -850,7 +882,14 @@ class TextButtonStyleManager:
@classmethod
def createStandardButton(cls, text, parent=None, slot=None):
"""创建标准样式的文本按钮"""
button = QtWidgets.QPushButton(text, parent)
# 修复 PySide6 兼容性问题 - 分步创建
try:
button = QtWidgets.QPushButton(text)
if parent is not None:
button.setParent(parent)
except Exception as e:
print(f"按钮创建失败: {e}")
button = QtWidgets.QPushButton("按钮")
# 应用标准样式
cls.applyToButton(button, text)
......@@ -866,6 +905,37 @@ class TextButtonStyleManager:
"""更新按钮文本并重新调整大小"""
button.setText(new_text)
cls.applyToButton(button, new_text)
@classmethod
def applyDangerStyle(cls, button):
"""应用危险按钮样式(红色)"""
button.setStyleSheet("""
QPushButton {
background-color: #dc3545;
color: white;
border: none;
padding: 8px 16px;
border-radius: 4px;
font-weight: bold;
min-width: 100px;
}
QPushButton:hover {
background-color: #c82333;
}
QPushButton:pressed {
background-color: #bd2130;
}
QPushButton:disabled {
background-color: #6c757d;
color: #ffffff;
}
""")
@classmethod
def applyStandardStyle(cls, button):
"""应用标准按钮样式"""
# 重新应用标准样式
cls.applyToButton(button)
class BackgroundStyleManager:
"""全局背景颜色管理器"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment