Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
O
Oil_Level_Recognition_System
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
Oil_Level_Recognition_System
Commits
da148b05
Commit
da148b05
authored
Nov 28, 2025
by
yhb
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
123456
parent
7cf1afc7
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
2155 additions
and
364 deletions
+2155
-364
__main__.py
__main__.py
+9
-3
app.py
app.py
+8
-0
model_test_handler.py
handlers/modelpage/model_test_handler.py
+119
-9
model_training_handler.py
handlers/modelpage/model_training_handler.py
+703
-106
model_trainingworker_handler.py
handlers/modelpage/model_trainingworker_handler.py
+7
-0
channelpanel_handler.py
handlers/videopage/channelpanel_handler.py
+10
-1
curvepanel_handler.py
handlers/videopage/curvepanel_handler.py
+2
-2
historypanel_handler.py
handlers/videopage/historypanel_handler.py
+4
-4
modelset_page.py
widgets/modelpage/modelset_page.py
+363
-38
training_page.py
widgets/modelpage/training_page.py
+847
-188
responsive_layout.py
widgets/responsive_layout.py
+1
-1
style_manager.py
widgets/style_manager.py
+82
-12
No files found.
__main__.py
View file @
da148b05
...
@@ -29,10 +29,11 @@ print("[环境变量] OpenMP冲突已修复")
...
@@ -29,10 +29,11 @@ print("[环境变量] OpenMP冲突已修复")
from
qtpy
import
QtWidgets
from
qtpy
import
QtWidgets
from
app
import
MainWindow
# 延迟导入,避免在 QApplication 创建前创建 QWidget
# from app import MainWindow
from
database.config
import
get_config
,
get_temp_models_dir
from
database.config
import
get_config
,
get_temp_models_dir
from
widgets.style_manager
import
FontManager
#
from widgets.style_manager import FontManager
from
widgets.responsive_layout
import
ResponsiveLayout
#
from widgets.responsive_layout import ResponsiveLayout
def
setup_logging
(
level
:
str
=
"info"
):
def
setup_logging
(
level
:
str
=
"info"
):
...
@@ -270,6 +271,11 @@ def _main():
...
@@ -270,6 +271,11 @@ def _main():
app
.
setApplicationName
(
'Detection'
)
app
.
setApplicationName
(
'Detection'
)
app
.
setOrganizationName
(
'Detection'
)
app
.
setOrganizationName
(
'Detection'
)
# 在 QApplication 创建后导入可能创建 QWidget 的模块
from
app
import
MainWindow
from
widgets.style_manager
import
FontManager
from
widgets.responsive_layout
import
ResponsiveLayout
# 初始化响应式布局系统
# 初始化响应式布局系统
ResponsiveLayout
.
initialize
(
app
)
ResponsiveLayout
.
initialize
(
app
)
...
...
app.py
View file @
da148b05
...
@@ -686,6 +686,10 @@ class MainWindow(
...
@@ -686,6 +686,10 @@ class MainWindow(
from
widgets
import
ChannelPanel
,
MissionPanel
,
CurvePanel
from
widgets
import
ChannelPanel
,
MissionPanel
,
CurvePanel
# 主页面容器
# 主页面容器
# 修复:确保 QApplication 完全初始化后再创建 QWidget
app
=
QtWidgets
.
QApplication
.
instance
()
if
app
is
None
:
raise
RuntimeError
(
"QApplication 未初始化"
)
page
=
QtWidgets
.
QWidget
()
page
=
QtWidgets
.
QWidget
()
page_layout
=
QtWidgets
.
QVBoxLayout
(
page
)
page_layout
=
QtWidgets
.
QVBoxLayout
(
page
)
page_layout
.
setContentsMargins
(
0
,
0
,
0
,
0
)
page_layout
.
setContentsMargins
(
0
,
0
,
0
,
0
)
...
@@ -714,6 +718,10 @@ class MainWindow(
...
@@ -714,6 +718,10 @@ class MainWindow(
except
ImportError
:
except
ImportError
:
from
widgets
import
ChannelPanel
,
MissionPanel
from
widgets
import
ChannelPanel
,
MissionPanel
# 确保 QApplication 存在
app
=
QtWidgets
.
QApplication
.
instance
()
if
app
is
None
:
raise
RuntimeError
(
"QApplication 未初始化"
)
layout_widget
=
QtWidgets
.
QWidget
()
layout_widget
=
QtWidgets
.
QWidget
()
main_layout
=
QtWidgets
.
QHBoxLayout
(
layout_widget
)
main_layout
=
QtWidgets
.
QHBoxLayout
(
layout_widget
)
main_layout
.
setContentsMargins
(
10
,
10
,
10
,
10
)
main_layout
.
setContentsMargins
(
10
,
10
,
10
,
10
)
...
...
handlers/modelpage/model_test_handler.py
View file @
da148b05
...
@@ -283,15 +283,22 @@ class ModelTestHandler:
...
@@ -283,15 +283,22 @@ class ModelTestHandler:
def
_handleStartTestExecution
(
self
):
def
_handleStartTestExecution
(
self
):
"""执行开始测试操作 - 液位检测测试功能"""
"""执行开始测试操作 - 液位检测测试功能"""
try
:
try
:
# 🔥 清空曲线数据,准备新的测试
self
.
_clearCurve
()
# 禁用查看曲线按钮
if
hasattr
(
self
.
training_panel
,
'disableViewCurveButton'
):
self
.
training_panel
.
disableViewCurveButton
()
# 切换按钮状态为“停止测试”
# 切换按钮状态为“停止测试”
self
.
training_panel
.
setTestButtonState
(
True
)
self
.
training_panel
.
setTestButtonState
(
True
)
# 获取选择的测试模型和测试文件
# 获取选择的测试模型和测试文件
test_model_display
=
self
.
training_panel
.
test_model_combo
.
currentText
()
test_model_display
=
self
.
training_panel
.
test_model_combo
.
currentText
()
test_model_path_raw
=
self
.
training_panel
.
test_model_combo
.
currentData
()
test_model_path_raw
=
self
.
training_panel
.
test_model_combo
.
currentData
()
#
改为从QComboBox获取数据
#
从QLineEdit获取测试文件路径(浏览选择的文件)
test_file_path_raw
=
self
.
training_panel
.
test_file_input
.
currentData
()
or
""
test_file_path_raw
=
self
.
training_panel
.
test_file_input
.
text
()
.
strip
()
test_file_display
=
self
.
training_panel
.
test_file_input
.
currentText
()
test_file_display
=
os
.
path
.
basename
(
test_file_path_raw
)
if
test_file_path_raw
else
""
# 关键修复:路径规范化处理,确保相对路径转换为绝对路径
# 关键修复:路径规范化处理,确保相对路径转换为绝对路径
project_root
=
get_project_root
()
project_root
=
get_project_root
()
...
@@ -342,7 +349,7 @@ class ModelTestHandler:
...
@@ -342,7 +349,7 @@ class ModelTestHandler:
<p style="margin: 0; font-size: 12px; color: #ffffff;"><strong>解决方法:</strong></p>
<p style="margin: 0; font-size: 12px; color: #ffffff;"><strong>解决方法:</strong></p>
<ul style="margin: 5px 0; padding-left: 20px; font-size: 12px; color: #ffffff;">
<ul style="margin: 5px 0; padding-left: 20px; font-size: 12px; color: #ffffff;">
<li>请在上方下拉框中选择测试模型</li>
<li>请在上方下拉框中选择测试模型</li>
<li>请
在上方下拉框中选择测试
文件</li>
<li>请
点击"浏览..."按钮选择测试图片或视频
文件</li>
<li>确保选择的文件存在且可访问</li>
<li>确保选择的文件存在且可访问</li>
</ul>
</ul>
</div>
</div>
...
@@ -353,7 +360,7 @@ class ModelTestHandler:
...
@@ -353,7 +360,7 @@ class ModelTestHandler:
QtWidgets
.
QMessageBox
.
warning
(
QtWidgets
.
QMessageBox
.
warning
(
self
.
training_panel
,
self
.
training_panel
,
"参数缺失"
,
"参数缺失"
,
error_msg
+
"请在上方下拉框中选择测试模型和测试文件"
error_msg
+
'请在上方下拉框中选择测试模型,并点击"浏览..."按钮选择测试文件'
)
)
return
return
...
@@ -541,11 +548,24 @@ class ModelTestHandler:
...
@@ -541,11 +548,24 @@ class ModelTestHandler:
if
detection_result
:
if
detection_result
:
# 显示检测结果
# 显示检测结果
self
.
_showDetectionResult
(
detection_result
)
self
.
_showDetectionResult
(
detection_result
)
# 🔥 添加曲线数据点
if
'liquid_level_mm'
in
detection_result
:
liquid_level
=
detection_result
[
'liquid_level_mm'
]
# 图片测试只有一个数据点,帧序号为0
self
.
_addCurveDataPoint
(
0
,
liquid_level
)
# 显示曲线面板
if
hasattr
(
self
.
training_panel
,
'showCurvePanel'
):
self
.
training_panel
.
showCurvePanel
()
# 启用查看曲线按钮
if
hasattr
(
self
.
training_panel
,
'enableViewCurveButton'
):
self
.
training_panel
.
enableViewCurveButton
()
QtWidgets
.
QMessageBox
.
information
(
QtWidgets
.
QMessageBox
.
information
(
self
.
training_panel
,
self
.
training_panel
,
"测试完成"
,
"测试完成"
,
"模型测试已成功完成!"
"模型测试已成功完成!
可查看曲线分析结果。
"
)
)
# 恢复按钮状态
# 恢复按钮状态
...
@@ -585,6 +605,29 @@ class ModelTestHandler:
...
@@ -585,6 +605,29 @@ class ModelTestHandler:
self
.
_test_thread
.
wait
()
self
.
_test_thread
.
wait
()
self
.
_test_thread
=
None
self
.
_test_thread
=
None
def
_addCurveDataPoint
(
self
,
frame_index
,
height_mm
):
"""添加曲线数据点
Args:
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
try
:
if
hasattr
(
self
.
training_panel
,
'addCurvePoint'
):
self
.
training_panel
.
addCurvePoint
(
frame_index
,
height_mm
)
print
(
f
"[曲线] 添加数据点: 帧{frame_index}, 液位{height_mm:.1f}mm"
)
except
Exception
as
e
:
print
(
f
"[曲线] 添加数据点失败: {e}"
)
def
_clearCurve
(
self
):
"""清空曲线数据"""
try
:
if
hasattr
(
self
.
training_panel
,
'_clearCurve'
):
self
.
training_panel
.
_clearCurve
()
print
(
f
"[曲线] 已清空曲线"
)
except
Exception
as
e
:
print
(
f
"[曲线] 清空曲线失败: {e}"
)
def
_showDetectionResult
(
self
,
detection_result
):
def
_showDetectionResult
(
self
,
detection_result
):
"""显示检测结果"""
"""显示检测结果"""
try
:
try
:
...
@@ -592,12 +635,16 @@ class ModelTestHandler:
...
@@ -592,12 +635,16 @@ class ModelTestHandler:
# 这里可以添加结果显示逻辑
# 这里可以添加结果显示逻辑
# 例如在display_panel中显示检测结果
# 例如在display_panel中显示检测结果
if
hasattr
(
self
.
training_panel
,
'display_panel'
)
and
detection_result
:
if
hasattr
(
self
.
training_panel
,
'display_panel'
)
and
detection_result
:
# 提取液位高度
liquid_level
=
detection_result
.
get
(
'liquid_level_mm'
,
0
)
result_html
=
f
"""
result_html
=
f
"""
<div style="padding: 15px; background: #000000; border: 1px solid #28a745; border-radius: 5px; color: #ffffff;">
<div style="padding: 15px; background: #000000; border: 1px solid #28a745; border-radius: 5px; color: #ffffff;">
<h3 style="margin-top: 0; color: #28a745;">液位检测测试成功</h3>
<h3 style="margin-top: 0; color: #28a745;">液位检测测试成功</h3>
<p style="color: #ffffff;"><strong>检测结果:</strong> 已完成液位检测</p>
<p style="color: #ffffff;"><strong>检测结果:</strong> 已完成液位检测</p>
<p style="color: #ffffff;"><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
<div style="margin-top: 15px; padding: 10px; background: #1a1a1a; border-radius: 3px;">
<div style="margin-top: 15px; padding: 10px; background: #1a1a1a; border-radius: 3px;">
<p style="margin: 0; font-size: 12px; color: #ffffff;">检测结果已完成,可以在
结果面板中查看详细信息
。</p>
<p style="margin: 0; font-size: 12px; color: #ffffff;">检测结果已完成,可以在
曲线面板中查看详细分析
。</p>
</div>
</div>
</div>
</div>
"""
"""
...
@@ -1197,6 +1244,9 @@ class ModelTestHandler:
...
@@ -1197,6 +1244,9 @@ class ModelTestHandler:
success_count
=
0
success_count
=
0
fail_count
=
0
fail_count
=
0
# 🔥 清空曲线数据,准备添加新的视频检测曲线
self
.
_clearCurve
()
# 关闭进度对话框
# 关闭进度对话框
if
progress_dialog
:
if
progress_dialog
:
progress_dialog
.
setLabelText
(
"正在检测中..."
)
progress_dialog
.
setLabelText
(
"正在检测中..."
)
...
@@ -1245,6 +1295,13 @@ class ModelTestHandler:
...
@@ -1245,6 +1295,13 @@ class ModelTestHandler:
last_detection_result
=
detection_result
last_detection_result
=
detection_result
detection_count
+=
1
detection_count
+=
1
success_count
+=
1
success_count
+=
1
# 🔥 添加曲线数据点(取第一个区域的液位高度)
if
detection_result
and
len
(
detection_result
)
>
0
:
first_area_result
=
detection_result
[
0
]
if
'liquid_level_mm'
in
first_area_result
:
liquid_level
=
first_area_result
[
'liquid_level_mm'
]
self
.
_addCurveDataPoint
(
frame_index
,
liquid_level
)
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"[视频检测] 第 {frame_index} 帧检测失败: {e}"
)
print
(
f
"[视频检测] 第 {frame_index} 帧检测失败: {e}"
)
fail_count
+=
1
fail_count
+=
1
...
@@ -1317,6 +1374,15 @@ class ModelTestHandler:
...
@@ -1317,6 +1374,15 @@ class ModelTestHandler:
if
not
self
.
_detection_stopped
:
if
not
self
.
_detection_stopped
:
print
(
f
"[视频检测] 显示检测结果视频..."
)
print
(
f
"[视频检测] 显示检测结果视频..."
)
self
.
_showDetectionVideo
(
output_video_path
,
frame_index
,
detection_count
,
success_count
,
fail_count
)
self
.
_showDetectionVideo
(
output_video_path
,
frame_index
,
detection_count
,
success_count
,
fail_count
)
# 🔥 显示曲线面板
if
hasattr
(
self
.
training_panel
,
'showCurvePanel'
):
self
.
training_panel
.
showCurvePanel
()
print
(
f
"[曲线] 曲线面板已显示,共{len(self.training_panel.curve_data_x if hasattr(self.training_panel, 'curve_data_x') else [])}个数据点"
)
# 启用查看曲线按钮
if
hasattr
(
self
.
training_panel
,
'enableViewCurveButton'
):
self
.
training_panel
.
enableViewCurveButton
()
else
:
else
:
print
(
f
"[视频检测] 检测被用户停止"
)
print
(
f
"[视频检测] 检测被用户停止"
)
...
@@ -1844,6 +1910,9 @@ class ModelTestHandler:
...
@@ -1844,6 +1910,9 @@ class ModelTestHandler:
"生成文件:"
,
"生成文件:"
,
f
" 结果视频: {result_video_filename}"
,
f
" 结果视频: {result_video_filename}"
,
f
" 测试报告: {report_filename}"
,
f
" 测试报告: {report_filename}"
,
f
" JSON结果: {json_filename}"
,
f
" 曲线数据: {file_prefix}_curve.csv"
,
f
" 曲线图片: {file_prefix}_curve.png"
,
""
,
""
,
"="
*
60
,
"="
*
60
,
]
]
...
@@ -1876,7 +1945,9 @@ class ModelTestHandler:
...
@@ -1876,7 +1945,9 @@ class ModelTestHandler:
"files"
:
{
"files"
:
{
"result_video"
:
result_video_filename
,
"result_video"
:
result_video_filename
,
"report"
:
report_filename
,
"report"
:
report_filename
,
"json_result"
:
json_filename
"json_result"
:
json_filename
,
"curve_data_csv"
:
f
"{file_prefix}_curve.csv"
,
"curve_image_png"
:
f
"{file_prefix}_curve.png"
}
}
}
}
...
@@ -1885,6 +1956,23 @@ class ModelTestHandler:
...
@@ -1885,6 +1956,23 @@ class ModelTestHandler:
print
(
f
"[保存视频结果] JSON结果已保存: {json_path}"
)
print
(
f
"[保存视频结果] JSON结果已保存: {json_path}"
)
# 4. 保存曲线数据(CSV格式)和曲线图片
try
:
if
hasattr
(
self
.
training_panel
,
'saveCurveData'
)
and
hasattr
(
self
.
training_panel
,
'saveCurveImage'
):
# 保存曲线CSV数据
curve_csv_filename
=
f
"{file_prefix}_curve.csv"
curve_csv_path
=
os
.
path
.
join
(
test_results_dir
,
curve_csv_filename
)
if
self
.
training_panel
.
saveCurveData
(
curve_csv_path
):
print
(
f
"[保存视频结果] 曲线CSV数据已保存: {curve_csv_path}"
)
# 保存曲线图片
curve_image_filename
=
f
"{file_prefix}_curve.png"
curve_image_path
=
os
.
path
.
join
(
test_results_dir
,
curve_image_filename
)
if
self
.
training_panel
.
saveCurveImage
(
curve_image_path
):
print
(
f
"[保存视频结果] 曲线图片已保存: {curve_image_path}"
)
except
Exception
as
curve_error
:
print
(
f
"[保存视频结果] ⚠️ 曲线保存失败(非致命错误): {curve_error}"
)
print
(
f
"[保存视频结果] ✅ 所有测试结果已成功保存到: {test_results_dir}"
)
print
(
f
"[保存视频结果] ✅ 所有测试结果已成功保存到: {test_results_dir}"
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -1991,6 +2079,9 @@ class ModelTestHandler:
...
@@ -1991,6 +2079,9 @@ class ModelTestHandler:
f
" 原始图像: {original_filename}"
,
f
" 原始图像: {original_filename}"
,
f
" 检测结果: {result_filename}"
,
f
" 检测结果: {result_filename}"
,
f
" 测试报告: {report_filename}"
,
f
" 测试报告: {report_filename}"
,
f
" JSON结果: {json_filename}"
,
f
" 曲线数据: {file_prefix}_curve.csv"
,
f
" 曲线图片: {file_prefix}_curve.png"
,
""
,
""
,
"="
*
60
,
"="
*
60
,
]
]
...
@@ -2026,7 +2117,9 @@ class ModelTestHandler:
...
@@ -2026,7 +2117,9 @@ class ModelTestHandler:
"original_image"
:
original_filename
,
"original_image"
:
original_filename
,
"result_image"
:
result_filename
,
"result_image"
:
result_filename
,
"report"
:
report_filename
,
"report"
:
report_filename
,
"json_result"
:
json_filename
"json_result"
:
json_filename
,
"curve_data_csv"
:
f
"{file_prefix}_curve.csv"
,
"curve_image_png"
:
f
"{file_prefix}_curve.png"
}
}
}
}
...
@@ -2035,6 +2128,23 @@ class ModelTestHandler:
...
@@ -2035,6 +2128,23 @@ class ModelTestHandler:
print
(
f
"[保存图片结果] JSON结果已保存: {json_path}"
)
print
(
f
"[保存图片结果] JSON结果已保存: {json_path}"
)
# 5. 保存曲线数据(CSV格式)和曲线图片
try
:
if
hasattr
(
self
.
training_panel
,
'saveCurveData'
)
and
hasattr
(
self
.
training_panel
,
'saveCurveImage'
):
# 保存曲线CSV数据
curve_csv_filename
=
f
"{file_prefix}_curve.csv"
curve_csv_path
=
os
.
path
.
join
(
test_results_dir
,
curve_csv_filename
)
if
self
.
training_panel
.
saveCurveData
(
curve_csv_path
):
print
(
f
"[保存图片结果] 曲线CSV数据已保存: {curve_csv_path}"
)
# 保存曲线图片
curve_image_filename
=
f
"{file_prefix}_curve.png"
curve_image_path
=
os
.
path
.
join
(
test_results_dir
,
curve_image_filename
)
if
self
.
training_panel
.
saveCurveImage
(
curve_image_path
):
print
(
f
"[保存图片结果] 曲线图片已保存: {curve_image_path}"
)
except
Exception
as
curve_error
:
print
(
f
"[保存图片结果] ⚠️ 曲线保存失败(非致命错误): {curve_error}"
)
print
(
f
"[保存图片结果] ✅ 所有测试结果已成功保存到: {test_results_dir}"
)
print
(
f
"[保存图片结果] ✅ 所有测试结果已成功保存到: {test_results_dir}"
)
except
Exception
as
e
:
except
Exception
as
e
:
...
...
handlers/modelpage/model_training_handler.py
View file @
da148b05
...
@@ -165,7 +165,7 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -165,7 +165,7 @@ class ModelTrainingHandler(ModelTestHandler):
return
False
return
False
if
not
training_params
.
get
(
'save_liquid_data_path'
):
if
not
training_params
.
get
(
'save_liquid_data_path'
):
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
"参数错误"
,
"未
找到可用的数据集配置文件
"
)
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
"参数错误"
,
"未
选择数据集文件夹,请至少添加一个数据集文件夹
"
)
return
False
return
False
if
not
training_params
.
get
(
'exp_name'
):
if
not
training_params
.
get
(
'exp_name'
):
...
@@ -208,25 +208,44 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -208,25 +208,44 @@ class ModelTrainingHandler(ModelTestHandler):
f
"基础模型文件不存在
\n
文件路径: {base_model}
\n
请检查文件路径是否正确,或重新选择模型文件。"
)
f
"基础模型文件不存在
\n
文件路径: {base_model}
\n
请检查文件路径是否正确,或重新选择模型文件。"
)
return
False
return
False
if
not
os
.
path
.
exists
(
save_liquid_data_path
):
# 解析数据集文件夹列表(用分号分隔)
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
"文件错误"
,
dataset_folders
=
[
f
.
strip
()
for
f
in
save_liquid_data_path
.
split
(
';'
)
if
f
.
strip
()]
f
"数据集配置文件不存在
\n
文件路径: {save_liquid_data_path}
\n
请检查文件路径是否正确,或重新选择数据集文件。"
)
if
not
dataset_folders
:
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
"参数错误"
,
"未选择数据集文件夹,请至少添加一个数据集文件夹"
)
return
False
# 验证每个数据集文件夹是否存在
invalid_folders
=
[]
for
folder
in
dataset_folders
:
if
not
os
.
path
.
exists
(
folder
):
invalid_folders
.
append
(
folder
)
if
invalid_folders
:
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
"文件夹错误"
,
f
"以下数据集文件夹不存在:
\n\n
"
+
"
\n
"
.
join
(
invalid_folders
)
+
"
\n\n
请检查文件夹路径是否正确,或重新选择数据集文件夹。"
)
return
False
return
False
# 验证数据集
配置和
内容
# 验证数据集
文件夹
内容
validation_result
,
validation_msg
=
self
.
_validate
TrainingDataWithDetails
(
save_liquid_data_path
)
validation_result
,
validation_msg
=
self
.
_validate
DatasetFolders
(
dataset_folders
)
if
not
validation_result
:
if
not
validation_result
:
QtWidgets
.
QMessageBox
.
critical
(
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
self
.
main_window
,
"数据集验证失败"
,
"数据集验证失败"
,
f
"数据集验证失败:
\n\n
{validation_msg}
\n\n
请检查数据集
配置和文件
。"
f
"数据集验证失败:
\n\n
{validation_msg}
\n\n
请检查数据集
文件夹内容
。"
)
)
return
False
return
False
# 确认对话框
# 确认对话框
confirm_msg
=
f
"确定要开始升级模型吗?
\n\n
"
confirm_msg
=
f
"确定要开始升级模型吗?
\n\n
"
confirm_msg
+=
f
"基础模型: {os.path.basename(base_model)}
\n
"
confirm_msg
+=
f
"基础模型: {os.path.basename(base_model)}
\n
"
confirm_msg
+=
f
"数据集: {os.path.basename(save_liquid_data_path)}
\n
"
confirm_msg
+=
f
"数据集文件夹数量: {len(dataset_folders)}
\n
"
for
i
,
folder
in
enumerate
(
dataset_folders
,
1
):
confirm_msg
+=
f
" {i}. {os.path.basename(folder)}
\n
"
confirm_msg
+=
f
"图像尺寸: {training_params['imgsz']}
\n
"
confirm_msg
+=
f
"图像尺寸: {training_params['imgsz']}
\n
"
confirm_msg
+=
f
"训练轮数: {training_params['epochs']}
\n
"
confirm_msg
+=
f
"训练轮数: {training_params['epochs']}
\n
"
confirm_msg
+=
f
"批次大小: {training_params['batch']}
\n
"
confirm_msg
+=
f
"批次大小: {training_params['batch']}
\n
"
...
@@ -271,12 +290,22 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -271,12 +290,22 @@ class ModelTrainingHandler(ModelTestHandler):
def
_startTrainingWorker
(
self
,
training_params
):
def
_startTrainingWorker
(
self
,
training_params
):
"""启动训练工作线程"""
"""启动训练工作线程"""
try
:
try
:
# 检查是否已有训练在进行中
if
self
.
training_active
and
self
.
training_worker
:
QtWidgets
.
QMessageBox
.
warning
(
self
.
main_window
,
"提示"
,
"训练正在进行中,请先停止当前训练"
)
return
False
# 禁止自动下载yolo11模型
# 禁止自动下载yolo11模型
os
.
environ
[
'YOLO_AUTODOWNLOAD'
]
=
'0'
os
.
environ
[
'YOLO_AUTODOWNLOAD'
]
=
'0'
os
.
environ
[
'YOLO_OFFLINE'
]
=
'1'
os
.
environ
[
'YOLO_OFFLINE'
]
=
'1'
# 重置用户停止标记
# 重置用户停止标记
self
.
_is_user_stopped
=
False
self
.
_is_user_stopped
=
False
self
.
_is_stopping
=
False
# 标记训练是否正在停止中
# 如果面板处于"继续训练"模式,切换回"停止升级"模式
# 如果面板处于"继续训练"模式,切换回"停止升级"模式
if
hasattr
(
self
,
'training_panel'
):
if
hasattr
(
self
,
'training_panel'
):
...
@@ -288,6 +317,31 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -288,6 +317,31 @@ class ModelTrainingHandler(ModelTestHandler):
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'train_log_text'
):
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'train_log_text'
):
self
.
training_panel
.
train_log_text
.
clear
()
self
.
training_panel
.
train_log_text
.
clear
()
# 处理多数据集文件夹合并
dataset_folders
=
[
f
.
strip
()
for
f
in
training_params
[
'save_liquid_data_path'
]
.
split
(
';'
)
if
f
.
strip
()]
if
len
(
dataset_folders
)
>
1
:
self
.
_appendLog
(
"检测到多个数据集文件夹,正在合并...
\n
"
)
merged_data_yaml
=
self
.
_mergeMultipleDatasets
(
dataset_folders
,
training_params
[
'exp_name'
])
if
merged_data_yaml
:
training_params
[
'save_liquid_data_path'
]
=
merged_data_yaml
self
.
_appendLog
(
f
"数据集合并完成: {merged_data_yaml}
\n
"
)
else
:
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
"错误"
,
"数据集合并失败"
)
return
False
elif
len
(
dataset_folders
)
==
1
:
# 单个数据集文件夹,需要创建data.yaml文件
single_folder
=
dataset_folders
[
0
]
data_yaml_path
=
self
.
_createDataYamlForSingleFolder
(
single_folder
,
training_params
[
'exp_name'
])
if
data_yaml_path
:
training_params
[
'save_liquid_data_path'
]
=
data_yaml_path
self
.
_appendLog
(
f
"已为单个数据集创建配置文件: {data_yaml_path}
\n
"
)
else
:
QtWidgets
.
QMessageBox
.
critical
(
self
.
main_window
,
"错误"
,
"创建数据集配置文件失败"
)
return
False
# 禁用笔记保存和提交按钮(训练开始时)
self
.
_disableNotesButtons
()
# 更新UI状态
# 更新UI状态
if
hasattr
(
self
,
'training_panel'
):
if
hasattr
(
self
,
'training_panel'
):
if
hasattr
(
self
.
training_panel
,
'train_status_label'
):
if
hasattr
(
self
.
training_panel
,
'train_status_label'
):
...
@@ -345,38 +399,84 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -345,38 +399,84 @@ class ModelTrainingHandler(ModelTestHandler):
return
False
return
False
def
_onStopTraining
(
self
):
def
_onStopTraining
(
self
):
"""停止训练 - 优雅停止,完成当前epoch后停止"""
"""停止训练 - 根据训练状态采用不同策略"""
# 检查是否已经在停止过程中
if
getattr
(
self
,
'_is_stopping'
,
False
):
self
.
_appendLog
(
"
\n
[提示] 训练正在停止中,请耐心等待...
\n
"
)
return
True
if
self
.
training_worker
and
self
.
training_active
:
if
self
.
training_worker
and
self
.
training_active
:
self
.
_is_user_stopped
=
True
# 标记为用户手动停止
# 检查训练是否已经真正开始
self
.
training_worker
.
stop_training
()
# 设置 is_running = False,YOLO会在epoch结束时检查
training_started
=
self
.
training_worker
.
has_training_started
()
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"用户请求停止训练
\n
"
)
self
.
_appendLog
(
"正在完成当前训练轮次...
\n
"
)
self
.
_appendLog
(
"(请勿关闭程序,等待当前epoch完成)
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
# 更新状态标签
# 设置停止标记,防止重复触发
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'train_status_label'
):
self
.
_is_stopping
=
True
self
.
training_panel
.
train_status_label
.
setText
(
"正在停止训练..."
)
self
.
_is_user_stopped
=
True
# 标记为用户手动停止
self
.
training_panel
.
train_status_label
.
setStyleSheet
(
"""
self
.
training_worker
.
stop_training
()
# 设置 is_running = False
QLabel {
color: #ffffff;
background-color: #ffc107;
border: 1px solid #ffc107;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
"""
)
FontManager
.
applyToWidget
(
self
.
training_panel
.
train_status_label
,
weight
=
FontManager
.
WEIGHT_BOLD
)
# 禁用停止按钮,防止重复点击
if
hasattr
(
self
,
'training_panel'
):
self
.
training_panel
.
stop_train_btn
.
setEnabled
(
False
)
# 不立刻终止线程,让YOLO在epoch结束时自动停止
if
not
training_started
:
# 线程会在 _onTrainingFinished 中被清理
# 训练还未真正开始(仍在初始化阶段),直接取消训练
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"训练尚未开始,正在取消训练...
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
# 更新状态标签
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'train_status_label'
):
self
.
training_panel
.
train_status_label
.
setText
(
"正在取消训练..."
)
self
.
training_panel
.
train_status_label
.
setStyleSheet
(
"""
QLabel {
color: #ffffff;
background-color: #dc3545;
border: 1px solid #dc3545;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
"""
)
FontManager
.
applyToWidget
(
self
.
training_panel
.
train_status_label
,
weight
=
FontManager
.
WEIGHT_BOLD
)
# 强制终止训练线程(因为训练还未开始,可以安全终止)
if
self
.
training_worker
:
self
.
training_worker
.
terminate
()
self
.
training_worker
.
wait
(
3000
)
# 等待最多3秒
if
self
.
training_worker
.
isRunning
():
self
.
training_worker
.
kill
()
# 强制杀死线程
# 直接调用训练完成回调,恢复UI状态
self
.
_onTrainingFinished
(
False
)
else
:
# 训练已经开始,优雅停止(完成当前epoch后停止)
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"用户请求停止训练
\n
"
)
self
.
_appendLog
(
"正在完成当前训练轮次...
\n
"
)
self
.
_appendLog
(
"(请勿关闭程序,等待当前epoch完成)
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
# 更新状态标签
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'train_status_label'
):
self
.
training_panel
.
train_status_label
.
setText
(
"正在停止训练..."
)
self
.
training_panel
.
train_status_label
.
setStyleSheet
(
"""
QLabel {
color: #ffffff;
background-color: #ffc107;
border: 1px solid #ffc107;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
"""
)
FontManager
.
applyToWidget
(
self
.
training_panel
.
train_status_label
,
weight
=
FontManager
.
WEIGHT_BOLD
)
# 禁用所有训练相关按钮,防止重复点击和冲突
if
hasattr
(
self
,
'training_panel'
):
if
hasattr
(
self
.
training_panel
,
'stop_train_btn'
):
self
.
training_panel
.
stop_train_btn
.
setEnabled
(
False
)
if
hasattr
(
self
.
training_panel
,
'start_train_btn'
):
self
.
training_panel
.
start_train_btn
.
setEnabled
(
False
)
# 不立刻终止线程,让YOLO在epoch结束时自动停止
# 线程会在 _onTrainingFinished 中被清理
return
True
return
True
else
:
else
:
...
@@ -386,7 +486,8 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -386,7 +486,8 @@ class ModelTrainingHandler(ModelTestHandler):
def
_onTrainingFinished
(
self
,
success
):
def
_onTrainingFinished
(
self
,
success
):
"""训练完成回调"""
"""训练完成回调"""
try
:
try
:
# 重置停止标记
self
.
_is_stopping
=
False
self
.
training_active
=
False
self
.
training_active
=
False
if
success
:
if
success
:
...
@@ -478,21 +579,30 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -478,21 +579,30 @@ class ModelTrainingHandler(ModelTestHandler):
self
.
_appendLog
(
" 训练完成通知
\n
"
)
self
.
_appendLog
(
" 训练完成通知
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"模型升级已完成!
\n
"
)
self
.
_appendLog
(
"模型升级已完成!
\n
"
)
self
.
_appendLog
(
"新模型已保存到detection_model目录
\n
"
)
self
.
_appendLog
(
"新模型已自动添加到模型集管理
\n
"
)
self
.
_appendLog
(
"新模型已自动添加到模型集管理
\n
"
)
self
.
_appendLog
(
"请切换到【模型集管理】页面查看新模型
\n
"
)
self
.
_appendLog
(
"请切换到【模型集管理】页面查看新模型
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
# 启用笔记保存和提交按钮(训练完成后允许继续编辑笔记)
self
.
_enableNotesButtons
()
# 使用定时器延迟显示消息框,避免阻塞训练线程的清理
# 使用定时器延迟显示消息框,避免阻塞训练线程的清理
QtCore
.
QTimer
.
singleShot
(
500
,
lambda
:
QtWidgets
.
QMessageBox
.
information
(
QtCore
.
QTimer
.
singleShot
(
500
,
lambda
:
QtWidgets
.
QMessageBox
.
information
(
self
.
main_window
,
self
.
main_window
,
"升级完成"
,
"升级完成"
,
"模型升级已完成!
\n
新模型已自动添加到模型集管理。"
"模型升级已完成!
\n
新模型已
保存到detection_model目录
\n
并
自动添加到模型集管理。"
))
))
else
:
else
:
# 检查是否为用户手动停止
# 检查是否为用户手动停止
is_user_stopped
=
getattr
(
self
,
'_is_user_stopped'
,
False
)
is_user_stopped
=
getattr
(
self
,
'_is_user_stopped'
,
False
)
if
is_user_stopped
:
# 检查训练是否已经真正开始
training_started
=
False
if
self
.
training_worker
:
training_started
=
self
.
training_worker
.
has_training_started
()
if
is_user_stopped
and
training_started
:
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"训练已暂停
\n
"
)
self
.
_appendLog
(
"训练已暂停
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
...
@@ -578,6 +688,39 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -578,6 +688,39 @@ class ModelTrainingHandler(ModelTestHandler):
# 重置标记
# 重置标记
self
.
_is_user_stopped
=
False
self
.
_is_user_stopped
=
False
self
.
_is_stopping
=
False
elif
is_user_stopped
and
not
training_started
:
# 训练被取消(未真正开始),恢复初始状态
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"训练已取消
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
# 更新状态标签
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'train_status_label'
):
self
.
training_panel
.
train_status_label
.
setText
(
"训练已取消"
)
self
.
training_panel
.
train_status_label
.
setStyleSheet
(
"""
QLabel {
color: #ffffff;
background-color: #6c757d;
border: 1px solid #6c757d;
border-radius: 4px;
padding: 10px;
min-height: 40px;
}
"""
)
FontManager
.
applyToWidget
(
self
.
training_panel
.
train_status_label
,
weight
=
FontManager
.
WEIGHT_BOLD
)
# 恢复按钮状态(允许重新开始训练)
if
hasattr
(
self
,
'training_panel'
):
if
hasattr
(
self
.
training_panel
,
'start_train_btn'
):
self
.
training_panel
.
start_train_btn
.
setEnabled
(
True
)
if
hasattr
(
self
.
training_panel
,
'stop_train_btn'
):
self
.
training_panel
.
stop_train_btn
.
setEnabled
(
False
)
self
.
training_panel
.
stop_train_btn
.
setText
(
"停止升级"
)
# 恢复原始文本
# 重置标记
self
.
_is_user_stopped
=
False
self
.
_is_stopping
=
False
else
:
else
:
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
" 升级失败
\n
"
)
self
.
_appendLog
(
" 升级失败
\n
"
)
...
@@ -609,6 +752,10 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -609,6 +752,10 @@ class ModelTrainingHandler(ModelTestHandler):
self
.
training_panel
.
stop_train_btn
.
setText
(
"停止升级"
)
self
.
training_panel
.
stop_train_btn
.
setText
(
"停止升级"
)
# 重置训练停止标记
# 重置训练停止标记
self
.
training_panel
.
_is_training_stopped
=
False
self
.
training_panel
.
_is_training_stopped
=
False
# 重置停止标记
self
.
_is_user_stopped
=
False
self
.
_is_stopping
=
False
# 如果是正常完成(非用户停止),恢复按钮状态
# 如果是正常完成(非用户停止),恢复按钮状态
if
success
and
not
self
.
_is_user_stopped
:
if
success
and
not
self
.
_is_user_stopped
:
...
@@ -891,41 +1038,19 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -891,41 +1038,19 @@ class ModelTrainingHandler(ModelTestHandler):
def
_initializeTrainingPanelDefaults
(
self
,
training_panel
):
def
_initializeTrainingPanelDefaults
(
self
,
training_panel
):
"""初始化训练面板的默认值"""
"""初始化训练面板的默认值"""
try
:
try
:
# 设置默认模型路径
# 基础模型现在通过下拉菜单从模型集管理页面加载,不需要手动设置
if
hasattr
(
training_panel
,
'base_model_edit'
):
# 下拉菜单会在页面显示时自动加载并选择默认模型
project_root
=
get_project_root
()
available_models
=
[
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'train_model'
,
'2'
,
'best.pt'
),
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'train_model'
)
]
for
model_path
in
available_models
:
if
os
.
path
.
exists
(
model_path
):
if
os
.
path
.
isfile
(
model_path
):
training_panel
.
base_model_edit
.
setText
(
model_path
)
break
elif
os
.
path
.
isdir
(
model_path
):
# 查找目录中的第一个模型文件
for
root
,
dirs
,
files
in
os
.
walk
(
model_path
):
for
file
in
files
:
if
file
.
endswith
(
'.pt'
)
or
file
.
endswith
(
'.dat'
):
full_path
=
os
.
path
.
join
(
root
,
file
)
training_panel
.
base_model_edit
.
setText
(
full_path
)
break
if
training_panel
.
base_model_edit
.
text
():
break
if
training_panel
.
base_model_edit
.
text
():
break
#
设置默认数据集路径
#
数据集现在通过文件夹列表管理,可以手动添加默认数据集文件夹
if
hasattr
(
training_panel
,
'
save_liquid_data_path_edi
t'
):
if
hasattr
(
training_panel
,
'
dataset_folders_lis
t'
):
project_root
=
get_project_root
()
project_root
=
get_project_root
()
available_datasets
=
[
default_dataset_folder
=
os
.
path
.
join
(
project_root
,
'database'
,
'dataset'
)
os
.
path
.
join
(
project_root
,
'database'
,
'dataset'
,
'data.yaml'
),
if
os
.
path
.
exists
(
default_dataset_folder
)
and
os
.
path
.
isdir
(
default_dataset_folder
):
]
# 添加默认数据集文件夹(如果列表为空)
for
dataset_path
in
available_datasets
:
if
training_panel
.
dataset_folders_list
.
count
()
==
0
:
if
os
.
path
.
exists
(
dataset_path
):
training_panel
.
dataset_folders_list
.
addItem
(
default_dataset_folder
)
training_panel
.
save_liquid_data_path_edit
.
setText
(
dataset_path
)
if
hasattr
(
training_panel
,
'_updateDatasetPath'
):
break
training_panel
.
_updateDatasetPath
()
# 设置默认模型名称
# 设置默认模型名称
if
hasattr
(
training_panel
,
'exp_name_edit'
):
if
hasattr
(
training_panel
,
'exp_name_edit'
):
...
@@ -1174,7 +1299,7 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -1174,7 +1299,7 @@ class ModelTrainingHandler(ModelTestHandler):
device_value
=
device_text
device_value
=
device_text
training_params
=
{
training_params
=
{
'base_model'
:
getattr
(
panel
,
'base_model_
edit'
,
None
)
and
panel
.
base_model_edit
.
text
()
or
''
,
'base_model'
:
getattr
(
panel
,
'base_model_
combo'
,
None
)
and
panel
.
base_model_combo
.
currentData
()
or
''
,
'save_liquid_data_path'
:
getattr
(
panel
,
'save_liquid_data_path_edit'
,
None
)
and
panel
.
save_liquid_data_path_edit
.
text
()
or
''
,
'save_liquid_data_path'
:
getattr
(
panel
,
'save_liquid_data_path_edit'
,
None
)
and
panel
.
save_liquid_data_path_edit
.
text
()
or
''
,
'imgsz'
:
getattr
(
panel
,
'imgsz_spin'
,
None
)
and
panel
.
imgsz_spin
.
value
()
or
640
,
'imgsz'
:
getattr
(
panel
,
'imgsz_spin'
,
None
)
and
panel
.
imgsz_spin
.
value
()
or
640
,
'epochs'
:
getattr
(
panel
,
'epochs_spin'
,
None
)
and
panel
.
epochs_spin
.
value
()
or
100
,
'epochs'
:
getattr
(
panel
,
'epochs_spin'
,
None
)
and
panel
.
epochs_spin
.
value
()
or
100
,
...
@@ -1241,34 +1366,39 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -1241,34 +1366,39 @@ class ModelTrainingHandler(ModelTestHandler):
# 路径有效,更新为绝对路径
# 路径有效,更新为绝对路径
training_params
[
'base_model'
]
=
base_model
training_params
[
'base_model'
]
=
base_model
# 修正数据集路径
# 修正数据集路径
(支持多文件夹,用分号分隔)
save_liquid_data_path
=
training_params
.
get
(
'save_liquid_data_path'
,
''
)
save_liquid_data_path
=
training_params
.
get
(
'save_liquid_data_path'
,
''
)
# 如果是相对路径,转换为绝对路径
if
save_liquid_data_path
:
if
save_liquid_data_path
and
not
os
.
path
.
isabs
(
save_liquid_data_path
):
# 解析多个文件夹路径
project_root
=
get_project_root
()
folders
=
[
f
.
strip
()
for
f
in
save_liquid_data_path
.
split
(
';'
)
if
f
.
strip
()]
save_liquid_data_path
=
os
.
path
.
join
(
project_root
,
save_liquid_data_path
)
fixed_folders
=
[]
if
not
save_liquid_data_path
or
not
os
.
path
.
exists
(
save_liquid_data_path
):
for
folder
in
folders
:
# 尝试使用配置文件中的默认路径
# 如果是相对路径,转换为绝对路径
if
self
.
train_config
and
'default_parameters'
in
self
.
train_config
:
if
not
os
.
path
.
isabs
(
folder
):
default_data
=
self
.
train_config
[
'default_parameters'
]
.
get
(
'dataset_path'
,
''
)
if
default_data
and
os
.
path
.
exists
(
default_data
):
training_params
[
'save_liquid_data_path'
]
=
default_data
else
:
# 使用项目中可用的数据集
project_root
=
get_project_root
()
project_root
=
get_project_root
()
available_datasets
=
[
folder
=
os
.
path
.
join
(
project_root
,
folder
)
os
.
path
.
join
(
project_root
,
'database'
,
'dataset'
,
'data_template_1.yaml'
),
os
.
path
.
join
(
project_root
,
'database'
,
'dataset'
,
'data.yaml'
),
# 只保留存在的文件夹
]
if
os
.
path
.
exists
(
folder
)
and
os
.
path
.
isdir
(
folder
):
for
dataset_path
in
available_datasets
:
fixed_folders
.
append
(
folder
)
if
os
.
path
.
exists
(
dataset_path
):
training_params
[
'save_liquid_data_path'
]
=
dataset_path
# 更新为绝对路径列表
break
if
fixed_folders
:
training_params
[
'save_liquid_data_path'
]
=
';'
.
join
(
fixed_folders
)
else
:
# 如果没有有效文件夹,尝试使用默认数据集文件夹
project_root
=
get_project_root
()
default_folder
=
os
.
path
.
join
(
project_root
,
'database'
,
'dataset'
)
if
os
.
path
.
exists
(
default_folder
)
and
os
.
path
.
isdir
(
default_folder
):
training_params
[
'save_liquid_data_path'
]
=
default_folder
else
:
else
:
# 路径有效,更新为绝对路径
# 没有指定数据集,使用默认数据集文件夹
training_params
[
'save_liquid_data_path'
]
=
save_liquid_data_path
project_root
=
get_project_root
()
default_folder
=
os
.
path
.
join
(
project_root
,
'database'
,
'dataset'
)
if
os
.
path
.
exists
(
default_folder
)
and
os
.
path
.
isdir
(
default_folder
):
training_params
[
'save_liquid_data_path'
]
=
default_folder
return
training_params
return
training_params
...
@@ -1360,6 +1490,78 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -1360,6 +1490,78 @@ class ModelTrainingHandler(ModelTestHandler):
except
Exception
as
e
:
except
Exception
as
e
:
return
False
,
f
"验证过程出错: {str(e)}"
return
False
,
f
"验证过程出错: {str(e)}"
def
_validateDatasetFolders
(
self
,
dataset_folders
):
"""
验证多个数据集文件夹
Args:
dataset_folders: 数据集文件夹路径列表
Returns:
tuple: (是否有效, 错误消息)
"""
try
:
if
not
dataset_folders
:
return
False
,
"数据集文件夹列表为空"
# 检查每个文件夹
total_train_images
=
0
total_val_images
=
0
total_labels
=
0
image_extensions
=
[
'.jpg'
,
'.jpeg'
,
'.png'
,
'.bmp'
,
'.tif'
,
'.tiff'
]
for
folder
in
dataset_folders
:
if
not
os
.
path
.
exists
(
folder
):
return
False
,
f
"文件夹不存在: {folder}"
if
not
os
.
path
.
isdir
(
folder
):
return
False
,
f
"路径不是文件夹: {folder}"
# 检查文件夹结构(YOLO格式)
# 期望结构: folder/images/train, folder/images/val, folder/labels/train, folder/labels/val
images_train_dir
=
os
.
path
.
join
(
folder
,
'images'
,
'train'
)
images_val_dir
=
os
.
path
.
join
(
folder
,
'images'
,
'val'
)
labels_train_dir
=
os
.
path
.
join
(
folder
,
'labels'
,
'train'
)
labels_val_dir
=
os
.
path
.
join
(
folder
,
'labels'
,
'val'
)
# 检查是否存在训练图片目录
if
os
.
path
.
exists
(
images_train_dir
):
train_count
=
sum
(
1
for
f
in
os
.
listdir
(
images_train_dir
)
if
any
(
f
.
lower
()
.
endswith
(
ext
)
for
ext
in
image_extensions
))
total_train_images
+=
train_count
# 检查是否存在验证图片目录
if
os
.
path
.
exists
(
images_val_dir
):
val_count
=
sum
(
1
for
f
in
os
.
listdir
(
images_val_dir
)
if
any
(
f
.
lower
()
.
endswith
(
ext
)
for
ext
in
image_extensions
))
total_val_images
+=
val_count
# 检查标签文件
if
os
.
path
.
exists
(
labels_train_dir
):
label_count
=
sum
(
1
for
f
in
os
.
listdir
(
labels_train_dir
)
if
f
.
lower
()
.
endswith
(
'.txt'
))
total_labels
+=
label_count
# 验证是否有足够的数据
if
total_train_images
==
0
and
total_val_images
==
0
:
return
False
,
f
"所有数据集文件夹中都没有找到图片文件
\n
请确保文件夹包含 images/train 或 images/val 子目录"
if
total_train_images
==
0
:
return
False
,
f
"未找到训练图片
\n
请确保至少一个文件夹包含 images/train 子目录及图片文件"
# 返回验证结果
msg
=
f
"数据集验证通过
\n
"
msg
+=
f
"总训练图片: {total_train_images} 张
\n
"
if
total_val_images
>
0
:
msg
+=
f
"总验证图片: {total_val_images} 张
\n
"
if
total_labels
>
0
:
msg
+=
f
"总标注文件: {total_labels} 个"
return
True
,
msg
except
Exception
as
e
:
return
False
,
f
"验证过程出错: {str(e)}"
def
_validateTrainingData
(
self
,
save_liquid_data_path
):
def
_validateTrainingData
(
self
,
save_liquid_data_path
):
"""验证训练数据(简化版,用于向后兼容)"""
"""验证训练数据(简化版,用于向后兼容)"""
result
,
_
=
self
.
_validateTrainingDataWithDetails
(
save_liquid_data_path
)
result
,
_
=
self
.
_validateTrainingDataWithDetails
(
save_liquid_data_path
)
...
@@ -1554,15 +1756,15 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -1554,15 +1756,15 @@ class ModelTrainingHandler(ModelTestHandler):
)
)
return
return
#
改为从QComboBox获取数据
#
从QLineEdit获取测试文件路径(浏览选择的文件)
test_file_path
=
self
.
training_panel
.
test_file_input
.
currentData
()
or
""
test_file_path
=
self
.
training_panel
.
test_file_input
.
text
()
.
strip
()
test_file_display
=
self
.
training_panel
.
test_file_input
.
currentText
()
test_file_display
=
os
.
path
.
basename
(
test_file_path
)
if
test_file_path
else
""
if
not
test_file_path
:
if
not
test_file_path
:
QtWidgets
.
QMessageBox
.
warning
(
QtWidgets
.
QMessageBox
.
warning
(
self
.
training_panel
,
self
.
training_panel
,
"提示"
,
"提示"
,
"请先选择测试文件"
'请点击"浏览..."按钮选择测试文件'
)
)
return
return
...
@@ -2183,11 +2385,21 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2183,11 +2385,21 @@ class ModelTrainingHandler(ModelTestHandler):
'training_date'
:
self
.
_getCurrentTimestamp
()
'training_date'
:
self
.
_getCurrentTimestamp
()
}
}
# 获取训练笔记
training_notes
=
self
.
_getTrainingNotes
()
# 将模型移动到detection_model目录
detection_model_path
=
self
.
_moveModelToDetectionDir
(
best_model_path
,
model_name
,
weights_dir
,
training_notes
)
if
detection_model_path
:
# 更新模型参数中的路径
model_params
[
'path'
]
=
detection_model_path
self
.
_appendLog
(
f
"模型已移动到detection_model目录: {os.path.basename(detection_model_path)}
\n
"
)
# 保存到模型集配置
# 保存到模型集配置
self
.
_saveModelToConfig
(
model_name
,
model_params
)
self
.
_saveModelToConfig
(
model_name
,
model_params
)
self
.
_appendLog
(
f
"新模型已添加到模型集: {model_name}
\n
"
)
self
.
_appendLog
(
f
"新模型已添加到模型集: {model_name}
\n
"
)
self
.
_appendLog
(
f
" 路径: {
best_model_path
}
\n
"
)
self
.
_appendLog
(
f
" 路径: {
model_params['path']
}
\n
"
)
self
.
_appendLog
(
f
" 大小: {model_size}
\n
"
)
self
.
_appendLog
(
f
" 大小: {model_size}
\n
"
)
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -2195,8 +2407,216 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2195,8 +2407,216 @@ class ModelTrainingHandler(ModelTestHandler):
import
traceback
import
traceback
traceback
.
print_exc
()
traceback
.
print_exc
()
def
_getTrainingNotes
(
self
):
"""获取训练页面的笔记内容"""
try
:
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'getTrainingNotes'
):
notes
=
self
.
training_panel
.
getTrainingNotes
()
if
notes
:
self
.
_appendLog
(
f
"[笔记] 获取到训练笔记,长度: {len(notes)} 字符
\n
"
)
return
notes
else
:
self
.
_appendLog
(
"[笔记] 未输入训练笔记
\n
"
)
return
""
else
:
self
.
_appendLog
(
"[笔记] 无法获取训练页面笔记接口
\n
"
)
return
""
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 获取训练笔记失败: {str(e)}
\n
"
)
return
""
def
_clearTrainingNotes
(
self
):
"""清空训练页面的笔记内容"""
try
:
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'clearTrainingNotes'
):
self
.
training_panel
.
clearTrainingNotes
()
self
.
_appendLog
(
"[笔记] 训练笔记已清空
\n
"
)
else
:
self
.
_appendLog
(
"[笔记] 无法获取训练页面笔记清空接口
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 清空训练笔记失败: {str(e)}
\n
"
)
def
_enableNotesButtons
(
self
):
"""启用训练页面的笔记保存和提交按钮"""
try
:
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'enableNotesButtons'
):
self
.
training_panel
.
enableNotesButtons
()
self
.
_appendLog
(
"[笔记] 笔记保存和提交按钮已启用
\n
"
)
else
:
self
.
_appendLog
(
"[笔记] 无法获取训练页面笔记按钮接口
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 启用笔记按钮失败: {str(e)}
\n
"
)
def
_disableNotesButtons
(
self
):
"""禁用训练页面的笔记保存和提交按钮"""
try
:
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'disableNotesButtons'
):
self
.
training_panel
.
disableNotesButtons
()
self
.
_appendLog
(
"[笔记] 笔记保存和提交按钮已禁用
\n
"
)
else
:
self
.
_appendLog
(
"[笔记] 无法获取训练页面笔记按钮接口
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 禁用笔记按钮失败: {str(e)}
\n
"
)
def
saveNotesToLatestModel
(
self
,
notes
):
"""保存笔记到最新训练的模型目录"""
try
:
if
not
notes
or
not
notes
.
strip
():
self
.
_appendLog
(
"[笔记] 笔记内容为空,无需保存
\n
"
)
return
False
# 获取最新的模型目录
latest_model_dir
=
self
.
_getLatestModelDirectory
()
if
not
latest_model_dir
:
self
.
_appendLog
(
"[ERROR] 无法找到最新的模型目录
\n
"
)
return
False
# 保存笔记文件
notes_file
=
os
.
path
.
join
(
latest_model_dir
,
'training_notes.txt'
)
try
:
with
open
(
notes_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
# 添加时间戳和更新信息
f
.
write
(
f
"训练笔记 - {os.path.basename(latest_model_dir)}
\n
"
)
f
.
write
(
f
"最后更新: {self._getCurrentTimestamp()}
\n
"
)
f
.
write
(
"="
*
50
+
"
\n\n
"
)
f
.
write
(
notes
)
self
.
_appendLog
(
f
"[笔记] 笔记已保存到: {notes_file}
\n
"
)
return
True
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 保存笔记文件失败: {str(e)}
\n
"
)
return
False
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 保存笔记到最新模型失败: {str(e)}
\n
"
)
import
traceback
traceback
.
print_exc
()
return
False
def
_getLatestModelDirectory
(
self
):
"""获取最新的detection_model目录"""
try
:
project_root
=
get_project_root
()
detection_model_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'detection_model'
)
if
not
os
.
path
.
exists
(
detection_model_dir
):
return
None
# 获取所有数字目录
digit_dirs
=
[]
for
item
in
os
.
listdir
(
detection_model_dir
):
item_path
=
os
.
path
.
join
(
detection_model_dir
,
item
)
if
os
.
path
.
isdir
(
item_path
)
and
item
.
isdigit
():
digit_dirs
.
append
(
int
(
item
))
if
not
digit_dirs
:
return
None
# 返回最大数字的目录
latest_id
=
max
(
digit_dirs
)
latest_dir
=
os
.
path
.
join
(
detection_model_dir
,
str
(
latest_id
))
self
.
_appendLog
(
f
"[笔记] 找到最新模型目录: detection_model/{latest_id}
\n
"
)
return
latest_dir
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 获取最新模型目录失败: {str(e)}
\n
"
)
return
None
def
_moveModelToDetectionDir
(
self
,
model_path
,
model_name
,
weights_dir
,
training_notes
=
""
):
"""将训练完成的模型移动到detection_model目录"""
try
:
import
shutil
from
pathlib
import
Path
self
.
_appendLog
(
f
"
\n
开始移动模型到detection_model目录...
\n
"
)
# 获取项目根目录
project_root
=
get_project_root
()
detection_model_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'detection_model'
)
# 确保detection_model目录存在
os
.
makedirs
(
detection_model_dir
,
exist_ok
=
True
)
# 获取下一个可用的数字ID
existing_dirs
=
[]
for
item
in
os
.
listdir
(
detection_model_dir
):
item_path
=
os
.
path
.
join
(
detection_model_dir
,
item
)
if
os
.
path
.
isdir
(
item_path
)
and
item
.
isdigit
():
existing_dirs
.
append
(
int
(
item
))
next_id
=
max
(
existing_dirs
)
+
1
if
existing_dirs
else
1
target_model_dir
=
os
.
path
.
join
(
detection_model_dir
,
str
(
next_id
))
self
.
_appendLog
(
f
" 目标目录: {target_model_dir}
\n
"
)
# 创建目标目录结构
os
.
makedirs
(
target_model_dir
,
exist_ok
=
True
)
target_weights_dir
=
os
.
path
.
join
(
target_model_dir
,
'weights'
)
os
.
makedirs
(
target_weights_dir
,
exist_ok
=
True
)
# 移动整个weights目录的内容
if
weights_dir
and
os
.
path
.
exists
(
weights_dir
):
self
.
_appendLog
(
f
" 复制weights目录内容...
\n
"
)
# 复制所有文件到目标weights目录
for
filename
in
os
.
listdir
(
weights_dir
):
source_file
=
os
.
path
.
join
(
weights_dir
,
filename
)
target_file
=
os
.
path
.
join
(
target_weights_dir
,
filename
)
if
os
.
path
.
isfile
(
source_file
):
shutil
.
copy2
(
source_file
,
target_file
)
self
.
_appendLog
(
f
" 复制: {filename}
\n
"
)
# 复制训练目录的其他文件(如config.yaml, results.csv等)
train_exp_dir
=
os
.
path
.
dirname
(
weights_dir
)
if
os
.
path
.
exists
(
train_exp_dir
):
for
filename
in
os
.
listdir
(
train_exp_dir
):
if
filename
!=
'weights'
:
# 跳过weights目录
source_file
=
os
.
path
.
join
(
train_exp_dir
,
filename
)
target_file
=
os
.
path
.
join
(
target_model_dir
,
filename
)
if
os
.
path
.
isfile
(
source_file
):
shutil
.
copy2
(
source_file
,
target_file
)
self
.
_appendLog
(
f
" 复制配置: {filename}
\n
"
)
# 保存训练笔记(如果有)
if
training_notes
:
notes_file
=
os
.
path
.
join
(
target_model_dir
,
'training_notes.txt'
)
try
:
with
open
(
notes_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
# 添加时间戳和模型信息
f
.
write
(
f
"训练笔记 - {model_name}
\n
"
)
f
.
write
(
f
"训练时间: {self._getCurrentTimestamp()}
\n
"
)
f
.
write
(
f
"模型ID: {next_id}
\n
"
)
f
.
write
(
"="
*
50
+
"
\n\n
"
)
f
.
write
(
training_notes
)
self
.
_appendLog
(
f
" 保存训练笔记: training_notes.txt
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
" 保存训练笔记失败: {str(e)}
\n
"
)
# 确定最终的模型文件路径
model_filename
=
os
.
path
.
basename
(
model_path
)
final_model_path
=
os
.
path
.
join
(
target_weights_dir
,
model_filename
)
if
os
.
path
.
exists
(
final_model_path
):
self
.
_appendLog
(
f
"✅ 模型已成功移动到detection_model/{next_id}/weights/{model_filename}
\n
"
)
if
training_notes
:
self
.
_appendLog
(
f
"✅ 训练笔记已保存到detection_model/{next_id}/training_notes.txt
\n
"
)
return
final_model_path
else
:
self
.
_appendLog
(
f
"❌ 模型移动失败,目标文件不存在: {final_model_path}
\n
"
)
return
None
except
Exception
as
e
:
self
.
_appendLog
(
f
"❌ [ERROR] 移动模型到detection_model失败: {str(e)}
\n
"
)
import
traceback
traceback
.
print_exc
()
return
None
def
_saveModelToConfig
(
self
,
model_name
,
model_params
):
def
_saveModelToConfig
(
self
,
model_name
,
model_params
):
"""保存模型配置文件(模型已经在
trai
n_model目录中)"""
"""保存模型配置文件(模型已经在
detectio
n_model目录中)"""
try
:
try
:
from
pathlib
import
Path
from
pathlib
import
Path
import
yaml
import
yaml
...
@@ -2205,7 +2625,7 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2205,7 +2625,7 @@ class ModelTrainingHandler(ModelTestHandler):
self
.
_appendLog
(
f
" 模型名称: {model_name}
\n
"
)
self
.
_appendLog
(
f
" 模型名称: {model_name}
\n
"
)
self
.
_appendLog
(
f
" 模型路径: {model_params['path']}
\n
"
)
self
.
_appendLog
(
f
" 模型路径: {model_params['path']}
\n
"
)
# 获取模型所在目录(
应该已经在trai
n_model/{数字ID}/weights/中)
# 获取模型所在目录(
现在应该在detectio
n_model/{数字ID}/weights/中)
model_path
=
model_params
[
'path'
]
model_path
=
model_params
[
'path'
]
if
not
os
.
path
.
exists
(
model_path
):
if
not
os
.
path
.
exists
(
model_path
):
raise
FileNotFoundError
(
f
"模型文件不存在: {model_path}"
)
raise
FileNotFoundError
(
f
"模型文件不存在: {model_path}"
)
...
@@ -2254,7 +2674,8 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2254,7 +2674,8 @@ class ModelTrainingHandler(ModelTestHandler):
yaml
.
dump
(
model_params
,
f
,
allow_unicode
=
True
,
default_flow_style
=
False
)
yaml
.
dump
(
model_params
,
f
,
allow_unicode
=
True
,
default_flow_style
=
False
)
# 输出总结信息
# 输出总结信息
self
.
_appendLog
(
f
"
\n
✅ 模型配置已保存: {model_dir}
\n
"
)
model_id
=
os
.
path
.
basename
(
model_dir
)
self
.
_appendLog
(
f
"
\n
✅ 模型配置已保存到detection_model/{model_id}/
\n
"
)
self
.
_appendLog
(
f
" 找到的文件:
\n
"
)
self
.
_appendLog
(
f
" 找到的文件:
\n
"
)
for
info
in
model_files_found
:
for
info
in
model_files_found
:
self
.
_appendLog
(
f
" - {info}
\n
"
)
self
.
_appendLog
(
f
" - {info}
\n
"
)
...
@@ -2262,6 +2683,7 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2262,6 +2683,7 @@ class ModelTrainingHandler(ModelTestHandler):
for
log_file
in
log_files_found
:
for
log_file
in
log_files_found
:
self
.
_appendLog
(
f
" - 训练日志: {log_file}
\n
"
)
self
.
_appendLog
(
f
" - 训练日志: {log_file}
\n
"
)
self
.
_appendLog
(
f
" - 配置文件: config.yaml
\n
"
)
self
.
_appendLog
(
f
" - 配置文件: config.yaml
\n
"
)
self
.
_appendLog
(
f
"
\n
🎉 模型已成功保存到detection_model目录,可在模型集管理中查看!
\n
"
)
except
Exception
as
e
:
except
Exception
as
e
:
self
.
_appendLog
(
f
"❌ [ERROR] 保存模型配置失败: {str(e)}
\n
"
)
self
.
_appendLog
(
f
"❌ [ERROR] 保存模型配置失败: {str(e)}
\n
"
)
...
@@ -2652,3 +3074,178 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2652,3 +3074,178 @@ class ModelTrainingHandler(ModelTestHandler):
return
"未知"
return
"未知"
except
:
except
:
return
"未知"
return
"未知"
def
_mergeMultipleDatasets
(
self
,
dataset_folders
,
exp_name
):
"""
合并多个数据集文件夹为统一的训练配置
Args:
dataset_folders: 数据集文件夹路径列表
exp_name: 实验名称
Returns:
str: 合并后的data.yaml文件路径,失败返回None
"""
try
:
import
shutil
import
tempfile
from
pathlib
import
Path
# 创建临时合并目录
project_root
=
get_project_root
()
temp_dataset_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'temp_datasets'
,
exp_name
)
# 如果目录已存在,先删除
if
os
.
path
.
exists
(
temp_dataset_dir
):
shutil
.
rmtree
(
temp_dataset_dir
)
# 创建合并后的目录结构
merged_images_train
=
os
.
path
.
join
(
temp_dataset_dir
,
'images'
,
'train'
)
merged_images_val
=
os
.
path
.
join
(
temp_dataset_dir
,
'images'
,
'val'
)
merged_labels_train
=
os
.
path
.
join
(
temp_dataset_dir
,
'labels'
,
'train'
)
merged_labels_val
=
os
.
path
.
join
(
temp_dataset_dir
,
'labels'
,
'val'
)
os
.
makedirs
(
merged_images_train
,
exist_ok
=
True
)
os
.
makedirs
(
merged_images_val
,
exist_ok
=
True
)
os
.
makedirs
(
merged_labels_train
,
exist_ok
=
True
)
os
.
makedirs
(
merged_labels_val
,
exist_ok
=
True
)
self
.
_appendLog
(
f
"创建合并目录: {temp_dataset_dir}
\n
"
)
# 合并所有数据集
total_train_images
=
0
total_val_images
=
0
total_train_labels
=
0
total_val_labels
=
0
for
i
,
folder
in
enumerate
(
dataset_folders
):
self
.
_appendLog
(
f
"正在合并数据集 {i+1}/{len(dataset_folders)}: {os.path.basename(folder)}
\n
"
)
# 检查源目录结构
src_images_train
=
os
.
path
.
join
(
folder
,
'images'
,
'train'
)
src_images_val
=
os
.
path
.
join
(
folder
,
'images'
,
'val'
)
src_labels_train
=
os
.
path
.
join
(
folder
,
'labels'
,
'train'
)
src_labels_val
=
os
.
path
.
join
(
folder
,
'labels'
,
'val'
)
# 复制训练图片
if
os
.
path
.
exists
(
src_images_train
):
for
filename
in
os
.
listdir
(
src_images_train
):
if
filename
.
lower
()
.
endswith
((
'.jpg'
,
'.jpeg'
,
'.png'
,
'.bmp'
,
'.tif'
,
'.tiff'
)):
src_file
=
os
.
path
.
join
(
src_images_train
,
filename
)
# 添加前缀避免文件名冲突
dst_filename
=
f
"ds{i+1}_{filename}"
dst_file
=
os
.
path
.
join
(
merged_images_train
,
dst_filename
)
shutil
.
copy2
(
src_file
,
dst_file
)
total_train_images
+=
1
# 复制验证图片
if
os
.
path
.
exists
(
src_images_val
):
for
filename
in
os
.
listdir
(
src_images_val
):
if
filename
.
lower
()
.
endswith
((
'.jpg'
,
'.jpeg'
,
'.png'
,
'.bmp'
,
'.tif'
,
'.tiff'
)):
src_file
=
os
.
path
.
join
(
src_images_val
,
filename
)
# 添加前缀避免文件名冲突
dst_filename
=
f
"ds{i+1}_{filename}"
dst_file
=
os
.
path
.
join
(
merged_images_val
,
dst_filename
)
shutil
.
copy2
(
src_file
,
dst_file
)
total_val_images
+=
1
# 复制训练标签
if
os
.
path
.
exists
(
src_labels_train
):
for
filename
in
os
.
listdir
(
src_labels_train
):
if
filename
.
lower
()
.
endswith
(
'.txt'
):
src_file
=
os
.
path
.
join
(
src_labels_train
,
filename
)
# 添加前缀避免文件名冲突,保持与图片文件名对应
dst_filename
=
f
"ds{i+1}_{filename}"
dst_file
=
os
.
path
.
join
(
merged_labels_train
,
dst_filename
)
shutil
.
copy2
(
src_file
,
dst_file
)
total_train_labels
+=
1
# 复制验证标签
if
os
.
path
.
exists
(
src_labels_val
):
for
filename
in
os
.
listdir
(
src_labels_val
):
if
filename
.
lower
()
.
endswith
(
'.txt'
):
src_file
=
os
.
path
.
join
(
src_labels_val
,
filename
)
# 添加前缀避免文件名冲突,保持与图片文件名对应
dst_filename
=
f
"ds{i+1}_{filename}"
dst_file
=
os
.
path
.
join
(
merged_labels_val
,
dst_filename
)
shutil
.
copy2
(
src_file
,
dst_file
)
total_val_labels
+=
1
self
.
_appendLog
(
f
"合并完成: 训练图片 {total_train_images} 张, 验证图片 {total_val_images} 张
\n
"
)
self
.
_appendLog
(
f
"合并完成: 训练标签 {total_train_labels} 个, 验证标签 {total_val_labels} 个
\n
"
)
# 创建data.yaml配置文件
data_yaml_path
=
os
.
path
.
join
(
temp_dataset_dir
,
'data.yaml'
)
data_config
=
{
'train'
:
os
.
path
.
join
(
temp_dataset_dir
,
'images'
,
'train'
),
'val'
:
os
.
path
.
join
(
temp_dataset_dir
,
'images'
,
'val'
),
'nc'
:
1
,
# 类别数量,液位检测通常是单类别
'names'
:
[
'liquid_level'
]
# 类别名称
}
with
open
(
data_yaml_path
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
yaml
.
dump
(
data_config
,
f
,
default_flow_style
=
False
,
allow_unicode
=
True
)
self
.
_appendLog
(
f
"创建配置文件: {data_yaml_path}
\n
"
)
return
data_yaml_path
except
Exception
as
e
:
self
.
_appendLog
(
f
"数据集合并失败: {str(e)}
\n
"
)
import
traceback
self
.
_appendLog
(
traceback
.
format_exc
())
return
None
def
_createDataYamlForSingleFolder
(
self
,
dataset_folder
,
exp_name
):
"""
为单个数据集文件夹创建data.yaml配置文件
Args:
dataset_folder: 数据集文件夹路径
exp_name: 实验名称
Returns:
str: data.yaml文件路径,失败返回None
"""
try
:
# 检查是否已经有data.yaml文件
existing_yaml
=
os
.
path
.
join
(
dataset_folder
,
'data.yaml'
)
if
os
.
path
.
exists
(
existing_yaml
):
self
.
_appendLog
(
f
"使用现有配置文件: {existing_yaml}
\n
"
)
return
existing_yaml
# 创建临时配置文件
project_root
=
get_project_root
()
temp_config_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'temp_configs'
)
os
.
makedirs
(
temp_config_dir
,
exist_ok
=
True
)
data_yaml_path
=
os
.
path
.
join
(
temp_config_dir
,
f
"{exp_name}_data.yaml"
)
# 检查数据集目录结构
images_train_dir
=
os
.
path
.
join
(
dataset_folder
,
'images'
,
'train'
)
images_val_dir
=
os
.
path
.
join
(
dataset_folder
,
'images'
,
'val'
)
if
not
os
.
path
.
exists
(
images_train_dir
):
self
.
_appendLog
(
f
"错误: 训练图片目录不存在: {images_train_dir}
\n
"
)
return
None
# 创建data.yaml配置
data_config
=
{
'train'
:
images_train_dir
,
'val'
:
images_val_dir
if
os
.
path
.
exists
(
images_val_dir
)
else
images_train_dir
,
'nc'
:
1
,
# 类别数量,液位检测通常是单类别
'names'
:
[
'liquid_level'
]
# 类别名称
}
with
open
(
data_yaml_path
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
yaml
.
dump
(
data_config
,
f
,
default_flow_style
=
False
,
allow_unicode
=
True
)
self
.
_appendLog
(
f
"创建配置文件: {data_yaml_path}
\n
"
)
return
data_yaml_path
except
Exception
as
e
:
self
.
_appendLog
(
f
"创建配置文件失败: {str(e)}
\n
"
)
import
traceback
self
.
_appendLog
(
traceback
.
format_exc
())
return
None
handlers/modelpage/model_trainingworker_handler.py
View file @
da148b05
...
@@ -69,6 +69,7 @@ class TrainingWorker(QThread):
...
@@ -69,6 +69,7 @@ class TrainingWorker(QThread):
super
()
.
__init__
()
super
()
.
__init__
()
self
.
training_params
=
training_params
self
.
training_params
=
training_params
self
.
is_running
=
True
self
.
is_running
=
True
self
.
training_actually_started
=
False
# 标记训练是否已经真正开始(第一个epoch开始)
self
.
train_config
=
None
self
.
train_config
=
None
# 调试信息:显示传入的训练参数
# 调试信息:显示传入的训练参数
...
@@ -522,6 +523,8 @@ class TrainingWorker(QThread):
...
@@ -522,6 +523,8 @@ class TrainingWorker(QThread):
def
on_train_start
(
trainer
):
def
on_train_start
(
trainer
):
"""训练开始回调 - 只输出到终端,不发送到UI"""
"""训练开始回调 - 只输出到终端,不发送到UI"""
# 标记训练已经真正开始
self
.
training_actually_started
=
True
# 记录开始时间
# 记录开始时间
epoch_start_time
[
0
]
=
time
.
time
()
epoch_start_time
[
0
]
=
time
.
time
()
# 不发送任何格式化消息到UI,让LogCapture直接捕获原生输出
# 不发送任何格式化消息到UI,让LogCapture直接捕获原生输出
...
@@ -1264,3 +1267,7 @@ class TrainingWorker(QThread):
...
@@ -1264,3 +1267,7 @@ class TrainingWorker(QThread):
def
stop_training
(
self
):
def
stop_training
(
self
):
"""停止训练"""
"""停止训练"""
self
.
is_running
=
False
self
.
is_running
=
False
def
has_training_started
(
self
):
"""检查训练是否已经真正开始"""
return
self
.
training_actually_started
handlers/videopage/channelpanel_handler.py
View file @
da148b05
...
@@ -1900,6 +1900,10 @@ class ChannelPanelHandler:
...
@@ -1900,6 +1900,10 @@ class ChannelPanelHandler:
def
_initConfigFileWatcher
(
self
):
def
_initConfigFileWatcher
(
self
):
"""初始化配置文件监控器"""
"""初始化配置文件监控器"""
try
:
try
:
# 临时禁用配置文件监控器以解决 QWidget 创建顺序问题
print
(
f
"[ConfigWatcher] 配置文件监控器已禁用(避免 QWidget 创建顺序问题)"
)
return
# 获取配置文件路径
# 获取配置文件路径
project_root
=
get_project_root
()
project_root
=
get_project_root
()
config_path
=
os
.
path
.
join
(
project_root
,
'database'
,
'config'
,
'default_config.yaml'
)
config_path
=
os
.
path
.
join
(
project_root
,
'database'
,
'config'
,
'default_config.yaml'
)
...
@@ -1928,7 +1932,12 @@ class ChannelPanelHandler:
...
@@ -1928,7 +1932,12 @@ class ChannelPanelHandler:
print
(
f
"🔄 [ConfigWatcher] 检测到配置文件变化: {path}"
)
print
(
f
"🔄 [ConfigWatcher] 检测到配置文件变化: {path}"
)
# 延迟一小段时间,确保文件写入完成
# 延迟一小段时间,确保文件写入完成
QtCore
.
QTimer
.
singleShot
(
100
,
self
.
_reloadChannelConfig
)
# 修复:检查 QApplication 是否存在
if
QtWidgets
.
QApplication
.
instance
()
is
not
None
:
QtCore
.
QTimer
.
singleShot
(
100
,
self
.
_reloadChannelConfig
)
else
:
# 如果没有 QApplication,直接调用重载函数
self
.
_reloadChannelConfig
()
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"[ConfigWatcher] 处理配置文件变化失败: {e}"
)
print
(
f
"[ConfigWatcher] 处理配置文件变化失败: {e}"
)
...
...
handlers/videopage/curvepanel_handler.py
View file @
da148b05
...
@@ -17,7 +17,7 @@ import csv
...
@@ -17,7 +17,7 @@ import csv
import
datetime
import
datetime
import
numpy
as
np
import
numpy
as
np
from
qtpy
import
QtWidgets
,
QtCore
from
qtpy
import
QtWidgets
,
QtCore
from
PyQt5.QtCore
import
QThread
,
pyqtSignal
from
qtpy.QtCore
import
QThread
,
Signal
as
pyqtSignal
# 导入图标工具
# 导入图标工具
try
:
try
:
...
@@ -930,7 +930,7 @@ class CurvePanelHandler:
...
@@ -930,7 +930,7 @@ class CurvePanelHandler:
QtWidgets
.
QApplication
.
processEvents
()
QtWidgets
.
QApplication
.
processEvents
()
# 🔥 延迟关闭进度条,确保用户能看到(至少显示500ms)
# 🔥 延迟关闭进度条,确保用户能看到(至少显示500ms)
from
PyQt5
.QtCore
import
QTimer
from
qtpy
.QtCore
import
QTimer
QTimer
.
singleShot
(
500
,
progress_dialog
.
close
)
QTimer
.
singleShot
(
500
,
progress_dialog
.
close
)
print
(
f
"✅ [进度条] 将在500ms后关闭"
)
print
(
f
"✅ [进度条] 将在500ms后关闭"
)
...
...
handlers/videopage/historypanel_handler.py
View file @
da148b05
...
@@ -11,7 +11,7 @@ widgets/videopage/historypanel.py (HistoryPanel)
...
@@ -11,7 +11,7 @@ widgets/videopage/historypanel.py (HistoryPanel)
-
-
"""
"""
from
PyQt5
import
QtWidgets
,
QtCore
from
qtpy
import
QtWidgets
,
QtCore
class
HistoryPanelHandler
:
class
HistoryPanelHandler
:
...
@@ -80,8 +80,8 @@ class HistoryPanelHandler:
...
@@ -80,8 +80,8 @@ class HistoryPanelHandler:
return
return
import
os
import
os
from
PyQt5
.QtMultimedia
import
QMediaPlayer
from
qtpy
.QtMultimedia
import
QMediaPlayer
from
PyQt5
import
QtWidgets
from
qtpy
import
QtWidgets
player
=
self
.
history_panel
.
_media_player
player
=
self
.
history_panel
.
_media_player
print
(
f
"[HistoryPanelHandler] 播放器: {player}"
)
print
(
f
"[HistoryPanelHandler] 播放器: {player}"
)
...
@@ -158,7 +158,7 @@ class HistoryPanelHandler:
...
@@ -158,7 +158,7 @@ class HistoryPanelHandler:
if
not
self
.
history_panel
or
not
hasattr
(
self
.
history_panel
,
'play_pause_button'
):
if
not
self
.
history_panel
or
not
hasattr
(
self
.
history_panel
,
'play_pause_button'
):
return
return
from
PyQt5
.QtMultimedia
import
QMediaPlayer
from
qtpy
.QtMultimedia
import
QMediaPlayer
from
widgets.style_manager
import
newIcon
from
widgets.style_manager
import
newIcon
button
=
self
.
history_panel
.
play_pause_button
button
=
self
.
history_panel
.
play_pause_button
...
...
widgets/modelpage/modelset_page.py
View file @
da148b05
...
@@ -699,7 +699,13 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -699,7 +699,13 @@ class ModelSetPage(QtWidgets.QWidget):
menu
.
addSeparator
()
menu
.
addSeparator
()
# 5. 删除模型(有实际功能)
# 5. 查看模型信息(新功能)
action_view_info
=
menu
.
addAction
(
"查看模型信息"
)
action_view_info
.
triggered
.
connect
(
lambda
:
self
.
viewModelInfo
(
model_name
))
menu
.
addSeparator
()
# 6. 删除模型(有实际功能)
action_delete
=
menu
.
addAction
(
"删除模型"
)
action_delete
=
menu
.
addAction
(
"删除模型"
)
action_delete
.
triggered
.
connect
(
lambda
:
self
.
deleteModel
(
model_name
))
action_delete
.
triggered
.
connect
(
lambda
:
self
.
deleteModel
(
model_name
))
...
@@ -734,6 +740,9 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -734,6 +740,9 @@ class ModelSetPage(QtWidgets.QWidget):
self
.
defaultModelChanged
.
emit
(
model_name
)
self
.
defaultModelChanged
.
emit
(
model_name
)
self
.
setDefaultRequested
.
emit
(
model_name
)
self
.
setDefaultRequested
.
emit
(
model_name
)
self
.
_updateStats
()
self
.
_updateStats
()
# 通知训练页面刷新基础模型列表
self
.
_notifyTrainingPageRefresh
()
def
updateModelParams
(
self
,
model_name
):
def
updateModelParams
(
self
,
model_name
):
"""更新模型参数显示(已移除右侧参数显示,保留方法以保持兼容性)"""
"""更新模型参数显示(已移除右侧参数显示,保留方法以保持兼容性)"""
...
@@ -1007,7 +1016,8 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -1007,7 +1016,8 @@ class ModelSetPage(QtWidgets.QWidget):
self
.
_updateModelOrder
()
self
.
_updateModelOrder
()
# 通知训练页面刷新基础模型列表
self
.
_notifyTrainingPageRefresh
()
except
Exception
as
e
:
except
Exception
as
e
:
showCritical
(
self
,
"删除失败"
,
f
"删除模型时发生错误: {e}"
)
showCritical
(
self
,
"删除失败"
,
f
"删除模型时发生错误: {e}"
)
...
@@ -1039,6 +1049,216 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -1039,6 +1049,216 @@ class ModelSetPage(QtWidgets.QWidget):
except
Exception
as
e
:
except
Exception
as
e
:
showCritical
(
self
,
"操作失败"
,
f
"添加至检测模型时发生错误: {e}"
)
showCritical
(
self
,
"操作失败"
,
f
"添加至检测模型时发生错误: {e}"
)
def
viewModelInfo
(
self
,
model_name
):
"""查看模型信息(来源和训练指标)"""
try
:
# 获取模型参数
if
model_name
not
in
self
.
_model_params
:
showWarning
(
self
,
"错误"
,
f
"未找到模型 '{model_name}' 的信息"
)
return
model_params
=
self
.
_model_params
[
model_name
]
model_path
=
model_params
.
get
(
'path'
,
''
)
if
not
model_path
or
not
os
.
path
.
exists
(
model_path
):
showWarning
(
self
,
"错误"
,
f
"模型文件不存在: {model_path}"
)
return
# 读取模型配置和训练指标
model_info
=
self
.
_readModelTrainingInfo
(
model_path
,
model_name
)
# 创建并显示信息对话框
self
.
_showModelInfoDialog
(
model_name
,
model_params
,
model_info
)
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
showCritical
(
self
,
"错误"
,
f
"查看模型信息时发生错误: {e}"
)
def
_readModelTrainingInfo
(
self
,
model_path
,
model_name
):
"""读取模型的训练配置和指标信息"""
info
=
{
'config'
:
{},
'metrics'
:
{},
'training_date'
:
None
,
'source'
:
'未知'
}
try
:
from
pathlib
import
Path
model_file
=
Path
(
model_path
)
# 1. 确定模型所在目录
if
model_file
.
is_file
():
model_dir
=
model_file
.
parent
else
:
model_dir
=
model_file
# 2. 尝试读取config.yaml(训练配置)
config_file
=
model_dir
/
'config.yaml'
if
not
config_file
.
exists
():
# 尝试上一级目录
config_file
=
model_dir
.
parent
/
'config.yaml'
if
config_file
.
exists
():
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
info
[
'config'
]
=
yaml
.
safe_load
(
f
)
or
{}
info
[
'training_date'
]
=
info
[
'config'
]
.
get
(
'training_date'
,
'未知'
)
info
[
'source'
]
=
'训练生成'
# 3. 尝试读取results.csv(训练指标)
results_file
=
model_dir
/
'results.csv'
if
not
results_file
.
exists
():
# 尝试train子目录
results_file
=
model_dir
.
parent
/
'train'
/
'results.csv'
if
results_file
.
exists
():
import
pandas
as
pd
try
:
df
=
pd
.
read_csv
(
results_file
)
if
len
(
df
)
>
0
:
# 获取最后一行(最终指标)
last_row
=
df
.
iloc
[
-
1
]
info
[
'metrics'
]
=
{
'epoch'
:
int
(
last_row
.
get
(
'epoch'
,
0
))
if
'epoch'
in
last_row
else
len
(
df
),
'train_loss'
:
float
(
last_row
.
get
(
'train/box_loss'
,
0
))
if
'train/box_loss'
in
last_row
else
None
,
'val_loss'
:
float
(
last_row
.
get
(
'val/box_loss'
,
0
))
if
'val/box_loss'
in
last_row
else
None
,
'precision'
:
float
(
last_row
.
get
(
'metrics/precision(B)'
,
0
))
if
'metrics/precision(B)'
in
last_row
else
None
,
'recall'
:
float
(
last_row
.
get
(
'metrics/recall(B)'
,
0
))
if
'metrics/recall(B)'
in
last_row
else
None
,
'mAP50'
:
float
(
last_row
.
get
(
'metrics/mAP50(B)'
,
0
))
if
'metrics/mAP50(B)'
in
last_row
else
None
,
'mAP50-95'
:
float
(
last_row
.
get
(
'metrics/mAP50-95(B)'
,
0
))
if
'metrics/mAP50-95(B)'
in
last_row
else
None
,
}
except
Exception
as
e
:
print
(
f
"读取results.csv失败: {e}"
)
# 4. 如果没有配置文件,从文件路径判断来源
if
not
info
[
'config'
]:
if
'train_model'
in
str
(
model_path
):
info
[
'source'
]
=
'本地训练'
elif
'detection_model'
in
str
(
model_path
):
info
[
'source'
]
=
'检测模型'
else
:
info
[
'source'
]
=
'导入模型'
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
print
(
f
"读取模型训练信息失败: {e}"
)
return
info
def
_showModelInfoDialog
(
self
,
model_name
,
model_params
,
model_info
):
"""显示模型信息对话框(简化版 - 只显示训练配置)"""
# 创建对话框
dialog
=
QtWidgets
.
QDialog
(
self
)
dialog
.
setWindowTitle
(
f
"模型升级配置 - {model_name}"
)
dialog
.
setMinimumWidth
(
450
)
dialog
.
setMinimumHeight
(
400
)
layout
=
QtWidgets
.
QVBoxLayout
(
dialog
)
layout
.
setSpacing
(
15
)
layout
.
setContentsMargins
(
20
,
20
,
20
,
20
)
# 标题
title_label
=
QtWidgets
.
QLabel
(
f
"<h3>{model_name}</h3>"
)
layout
.
addWidget
(
title_label
)
# 创建表单布局显示训练配置
form_widget
=
QtWidgets
.
QWidget
()
form_layout
=
QtWidgets
.
QFormLayout
(
form_widget
)
form_layout
.
setSpacing
(
10
)
form_layout
.
setContentsMargins
(
10
,
10
,
10
,
10
)
form_layout
.
setLabelAlignment
(
Qt
.
AlignRight
|
Qt
.
AlignVCenter
)
# 设置表单样式
form_widget
.
setStyleSheet
(
"""
QWidget {
background-color: #f5f5f5;
border: 1px solid #ddd;
border-radius: 4px;
}
QLabel {
font-size: 10pt;
padding: 3px;
}
"""
)
# 从配置中提取训练参数
config
=
model_info
.
get
(
'config'
,
{})
# 基础信息
info_label
=
QtWidgets
.
QLabel
(
f
"<b>模型来源:</b>{model_info['source']}"
)
form_layout
.
addRow
(
""
,
info_label
)
if
model_info
.
get
(
'training_date'
):
date_label
=
QtWidgets
.
QLabel
(
f
"<b>训练日期:</b>{model_info['training_date']}"
)
form_layout
.
addRow
(
""
,
date_label
)
# 添加分隔线
separator1
=
QtWidgets
.
QFrame
()
separator1
.
setFrameShape
(
QtWidgets
.
QFrame
.
HLine
)
separator1
.
setFrameShadow
(
QtWidgets
.
QFrame
.
Sunken
)
form_layout
.
addRow
(
separator1
)
# 训练配置参数(仿照训练页面的格式)
epochs
=
config
.
get
(
'epochs'
,
'未知'
)
batch_size
=
config
.
get
(
'batch'
,
config
.
get
(
'batch_size'
,
'未知'
))
imgsz
=
config
.
get
(
'imgsz'
,
'未知'
)
workers
=
config
.
get
(
'workers'
,
'未知'
)
device
=
config
.
get
(
'device'
,
'未知'
)
optimizer
=
config
.
get
(
'optimizer'
,
'未知'
)
form_layout
.
addRow
(
"训练轮数:"
,
QtWidgets
.
QLabel
(
f
"<b>{epochs} 轮</b>"
))
form_layout
.
addRow
(
"批次大小:"
,
QtWidgets
.
QLabel
(
f
"<b>{batch_size}</b>"
))
form_layout
.
addRow
(
"图像尺寸:"
,
QtWidgets
.
QLabel
(
f
"<b>{imgsz} px</b>"
))
form_layout
.
addRow
(
"Workers:"
,
QtWidgets
.
QLabel
(
f
"<b>{workers} 线程</b>"
))
form_layout
.
addRow
(
"训练设备:"
,
QtWidgets
.
QLabel
(
f
"<b>{device}</b>"
))
form_layout
.
addRow
(
"优化器:"
,
QtWidgets
.
QLabel
(
f
"<b>{optimizer}</b>"
))
# 如果有训练指标,显示最终性能
metrics
=
model_info
.
get
(
'metrics'
,
{})
if
metrics
:
separator2
=
QtWidgets
.
QFrame
()
separator2
.
setFrameShape
(
QtWidgets
.
QFrame
.
HLine
)
separator2
.
setFrameShadow
(
QtWidgets
.
QFrame
.
Sunken
)
form_layout
.
addRow
(
separator2
)
mAP50
=
metrics
.
get
(
'mAP50'
)
mAP5095
=
metrics
.
get
(
'mAP50-95'
)
if
mAP50
is
not
None
:
form_layout
.
addRow
(
"mAP@0.5:"
,
QtWidgets
.
QLabel
(
f
"<b style='color: #28a745;'>{mAP50:.4f}</b>"
))
if
mAP5095
is
not
None
:
form_layout
.
addRow
(
"mAP@0.5:0.95:"
,
QtWidgets
.
QLabel
(
f
"<b style='color: #28a745;'>{mAP5095:.4f}</b>"
))
layout
.
addWidget
(
form_widget
)
# 如果没有训练配置,显示提示
if
not
config
:
no_config_label
=
QtWidgets
.
QLabel
(
"<i>该模型没有训练配置信息<br>"
"可能是外部导入的模型</i>"
)
no_config_label
.
setAlignment
(
Qt
.
AlignCenter
)
no_config_label
.
setStyleSheet
(
"color: #999; padding: 20px;"
)
layout
.
addWidget
(
no_config_label
)
layout
.
addStretch
()
# 按钮
button_layout
=
QtWidgets
.
QHBoxLayout
()
button_layout
.
addStretch
()
close_btn
=
QtWidgets
.
QPushButton
(
"关闭"
)
close_btn
.
setMinimumWidth
(
100
)
close_btn
.
clicked
.
connect
(
dialog
.
accept
)
button_layout
.
addWidget
(
close_btn
)
layout
.
addLayout
(
button_layout
)
# 显示对话框
dialog
.
exec_
()
def
_moveModelToDetection
(
self
,
model_name
,
source_path
):
def
_moveModelToDetection
(
self
,
model_name
,
source_path
):
"""执行模型移动操作"""
"""执行模型移动操作"""
try
:
try
:
...
@@ -1149,10 +1369,25 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -1149,10 +1369,25 @@ class ModelSetPage(QtWidgets.QWidget):
self
.
_updateModelOrder
()
self
.
_updateModelOrder
()
# 通知训练页面刷新基础模型列表
self
.
_notifyTrainingPageRefresh
()
except
Exception
as
e
:
except
Exception
as
e
:
import
traceback
import
traceback
traceback
.
print_exc
()
traceback
.
print_exc
()
def
_notifyTrainingPageRefresh
(
self
):
"""通知训练页面刷新基础模型列表"""
try
:
# 尝试获取训练页面并刷新其基础模型列表
if
self
.
_parent
and
hasattr
(
self
.
_parent
,
'trainingPage'
):
training_page
=
self
.
_parent
.
trainingPage
if
hasattr
(
training_page
,
'_loadBaseModelOptions'
):
training_page
.
_loadBaseModelOptions
()
print
(
"[模型集管理] 已通知训练页面刷新基础模型列表"
)
except
Exception
as
e
:
print
(
f
"[模型集管理] 通知训练页面刷新失败: {e}"
)
def
_loadConfigFile
(
self
):
def
_loadConfigFile
(
self
):
"""加载配置文件"""
"""加载配置文件"""
try
:
try
:
...
@@ -1208,50 +1443,122 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -1208,50 +1443,122 @@ class ModelSetPage(QtWidgets.QWidget):
return
models
return
models
def
_scanModelDirectory
(
self
):
def
_scanModelDirectory
(
self
):
"""扫描模型目录获取所有模型文件"""
"""扫描模型目录获取所有模型文件
(优先detection_model)
"""
models
=
[]
models
=
[]
try
:
try
:
# 获取模型目录路径
# 获取模型目录路径
current_dir
=
Path
(
__file__
)
.
parent
.
parent
.
parent
current_dir
=
Path
(
__file__
)
.
parent
.
parent
.
parent
# 扫描多个模型目录
# 扫描多个模型目录
(优先detection_model)
model_dirs
=
[
model_dirs
=
[
(
current_dir
/
"database"
/
"model"
/
"detection_model"
,
"检测模型"
),
(
current_dir
/
"database"
/
"model"
/
"detection_model"
,
"检测模型"
,
True
),
# 优先级最高
(
current_dir
/
"database"
/
"model"
/
"train_model"
,
"训练模型"
),
(
current_dir
/
"database"
/
"model"
/
"train_model"
,
"训练模型"
,
False
),
(
current_dir
/
"database"
/
"model"
/
"test_model"
,
"测试模型"
)
(
current_dir
/
"database"
/
"model"
/
"test_model"
,
"测试模型"
,
False
)
]
]
for
model_dir
,
dir_type
in
model_dirs
:
for
model_dir
,
dir_type
,
is_primary
in
model_dirs
:
if
not
model_dir
.
exists
():
if
not
model_dir
.
exists
():
continue
continue
# 遍历所有子目录,并按目录名排序(确保模型1最先)
# 遍历所有子目录,按数字排序(降序,最新的在前)
sorted_subdirs
=
sorted
(
model_dir
.
iterdir
(),
key
=
lambda
x
:
x
.
name
if
x
.
is_dir
()
else
''
)
all_subdirs
=
[
d
for
d
in
model_dir
.
iterdir
()
if
d
.
is_dir
()]
digit_subdirs
=
[
d
for
d
in
all_subdirs
if
d
.
name
.
isdigit
()]
sorted_subdirs
=
sorted
(
digit_subdirs
,
key
=
lambda
x
:
int
(
x
.
name
),
reverse
=
True
)
for
subdir
in
sorted_subdirs
:
for
subdir
in
sorted_subdirs
:
if
subdir
.
is_dir
():
# 检查是否有weights子目录
# 查找 .dat 文件(优先)
weights_dir
=
subdir
/
"weights"
for
model_file
in
sorted
(
subdir
.
glob
(
"*.dat"
)):
search_dir
=
weights_dir
if
weights_dir
.
exists
()
else
subdir
models
.
append
({
'name'
:
f
"{dir_type}-{subdir.name}-{model_file.stem}"
,
# 尝试读取config.yaml获取模型名称
'path'
:
str
(
model_file
),
config_file
=
subdir
/
"config.yaml"
'subdir'
:
subdir
.
name
,
model_display_name
=
None
'source'
:
'scan'
,
if
config_file
.
exists
():
'format'
:
'dat'
,
try
:
'model_type'
:
dir_type
import
yaml
})
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
config_data
=
yaml
.
safe_load
(
f
)
if
config_data
and
'name'
in
config_data
:
model_display_name
=
config_data
[
'name'
]
except
Exception
:
pass
# 如果没有配置文件中的名称,使用默认命名
if
not
model_display_name
:
if
is_primary
:
model_display_name
=
f
"模型-{subdir.name}"
else
:
model_display_name
=
f
"{dir_type}-{subdir.name}"
# 按优先级查找模型文件:best > last > epoch1
selected_model
=
None
# 优先级1: best模型(.dat优先)
for
ext
in
[
'.dat'
,
''
]:
# 无扩展名的也考虑
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'best.'
):
if
ext
==
''
and
'.'
in
file
.
name
[
5
:]:
# 有其他扩展名
continue
if
ext
!=
''
and
not
file
.
name
.
endswith
(
ext
):
continue
if
ext
==
''
or
file
.
name
.
endswith
(
ext
):
selected_model
=
file
break
if
selected_model
:
break
# 优先级2: last模型
if
not
selected_model
:
for
ext
in
[
'.dat'
,
''
]:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'last.'
):
if
ext
==
''
and
'.'
in
file
.
name
[
5
:]:
continue
if
ext
!=
''
and
not
file
.
name
.
endswith
(
ext
):
continue
if
ext
==
''
or
file
.
name
.
endswith
(
ext
):
selected_model
=
file
break
if
selected_model
:
break
# 优先级3: epoch1模型
if
not
selected_model
:
for
ext
in
[
'.dat'
,
''
]:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'epoch1.'
):
if
ext
==
''
and
'.'
in
file
.
name
[
7
:]:
continue
if
ext
!=
''
and
not
file
.
name
.
endswith
(
ext
):
continue
if
ext
==
''
or
file
.
name
.
endswith
(
ext
):
selected_model
=
file
break
if
selected_model
:
break
# 如果找到了模型文件,添加到列表
if
selected_model
:
# 获取文件格式
file_ext
=
selected_model
.
suffix
.
lstrip
(
'.'
)
if
not
file_ext
:
# 处理无扩展名的情况
if
'.'
in
selected_model
.
name
:
file_ext
=
selected_model
.
name
.
split
(
'.'
)[
-
1
]
else
:
file_ext
=
'dat'
# 默认为dat格式
# 然后查找 .pt 文件
models
.
append
({
for
model_file
in
sorted
(
subdir
.
glob
(
"*.pt"
)):
'name'
:
model_display_name
,
models
.
append
({
'path'
:
str
(
selected_model
),
'name'
:
f
"{dir_type}-{subdir.name}-{model_file.stem}"
,
'subdir'
:
subdir
.
name
,
'path'
:
str
(
model_file
)
,
'source'
:
'scan'
,
'subdir'
:
subdir
.
name
,
'format'
:
file_ext
,
'source'
:
'scan'
,
'model_type'
:
dir_type
,
'format'
:
'pt'
,
'is_primary'
:
is_primary
,
# 标记是否为主要模型目录
'model_type'
:
dir_typ
e
'file_name'
:
selected_model
.
nam
e
})
})
except
Exception
as
e
:
except
Exception
as
e
:
import
traceback
import
traceback
...
@@ -1260,29 +1567,47 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -1260,29 +1567,47 @@ class ModelSetPage(QtWidgets.QWidget):
return
models
return
models
def
_mergeModelInfo
(
self
,
channel_models
,
scanned_models
):
def
_mergeModelInfo
(
self
,
channel_models
,
scanned_models
):
"""合并模型信息,避免重复"""
"""合并模型信息,避免重复
(优先detection_model)
"""
all_models
=
[]
all_models
=
[]
seen_paths
=
set
()
seen_paths
=
set
()
# 优先添加配置文件中的通道模型
# 首先添加detection_model中的主要模型(优先级最高)
primary_models
=
[
m
for
m
in
scanned_models
if
m
.
get
(
'is_primary'
,
False
)]
for
model
in
primary_models
:
path
=
model
[
'path'
]
if
path
not
in
seen_paths
:
all_models
.
append
(
model
)
seen_paths
.
add
(
path
)
# 然后添加配置文件中的通道模型
for
model
in
channel_models
:
for
model
in
channel_models
:
path
=
model
[
'path'
]
path
=
model
[
'path'
]
if
path
not
in
seen_paths
:
if
path
not
in
seen_paths
:
all_models
.
append
(
model
)
all_models
.
append
(
model
)
seen_paths
.
add
(
path
)
seen_paths
.
add
(
path
)
# 再添加扫描到的模型(跳过已存在的)
# 最后添加其他扫描到的模型(跳过已存在的)
for
model
in
scanned_models
:
other_models
=
[
m
for
m
in
scanned_models
if
not
m
.
get
(
'is_primary'
,
False
)]
for
model
in
other_models
:
path
=
model
[
'path'
]
path
=
model
[
'path'
]
if
path
not
in
seen_paths
:
if
path
not
in
seen_paths
:
all_models
.
append
(
model
)
all_models
.
append
(
model
)
seen_paths
.
add
(
path
)
seen_paths
.
add
(
path
)
# 确保有一个默认模型:
如果没有默认模型,将第一个模型设为默认
# 确保有一个默认模型:
优先选择detection_model中的第一个模型
has_default
=
any
(
model
.
get
(
'is_default'
,
False
)
for
model
in
all_models
)
has_default
=
any
(
model
.
get
(
'is_default'
,
False
)
for
model
in
all_models
)
if
not
has_default
and
len
(
all_models
)
>
0
:
if
not
has_default
and
len
(
all_models
)
>
0
:
all_models
[
0
][
'is_default'
]
=
True
# 优先选择detection_model中的模型作为默认
pass
primary_model_found
=
False
for
model
in
all_models
:
if
model
.
get
(
'is_primary'
,
False
):
model
[
'is_default'
]
=
True
primary_model_found
=
True
break
# 如果没有detection_model中的模型,选择第一个
if
not
primary_model_found
:
all_models
[
0
][
'is_default'
]
=
True
return
all_models
return
all_models
...
...
widgets/modelpage/training_page.py
View file @
da148b05
...
@@ -11,6 +11,14 @@ from pathlib import Path
...
@@ -11,6 +11,14 @@ from pathlib import Path
from
qtpy
import
QtWidgets
,
QtCore
,
QtGui
from
qtpy
import
QtWidgets
,
QtCore
,
QtGui
from
qtpy.QtCore
import
Qt
from
qtpy.QtCore
import
Qt
# 尝试导入 PyQtGraph 用于曲线显示
try
:
import
pyqtgraph
as
pg
PYQTGRAPH_AVAILABLE
=
True
except
ImportError
:
pg
=
None
PYQTGRAPH_AVAILABLE
=
False
# 导入图标工具函数
# 导入图标工具函数
try
:
try
:
from
..icons
import
newIcon
,
newButton
from
..icons
import
newIcon
,
newButton
...
@@ -36,11 +44,11 @@ except (ImportError, ValueError):
...
@@ -36,11 +44,11 @@ except (ImportError, ValueError):
# 导入样式管理器和响应式布局
# 导入样式管理器和响应式布局
try
:
try
:
from
..style_manager
import
FontManager
,
BackgroundStyleManager
from
..style_manager
import
FontManager
,
BackgroundStyleManager
,
TextButtonStyleManager
from
..responsive_layout
import
ResponsiveLayout
,
scale_w
,
scale_h
from
..responsive_layout
import
ResponsiveLayout
,
scale_w
,
scale_h
except
(
ImportError
,
ValueError
):
except
(
ImportError
,
ValueError
):
try
:
try
:
from
widgets.style_manager
import
FontManager
,
BackgroundStyleManager
from
widgets.style_manager
import
FontManager
,
BackgroundStyleManager
,
TextButtonStyleManager
from
widgets.responsive_layout
import
ResponsiveLayout
,
scale_w
,
scale_h
from
widgets.responsive_layout
import
ResponsiveLayout
,
scale_w
,
scale_h
except
ImportError
:
except
ImportError
:
try
:
try
:
...
@@ -48,7 +56,7 @@ except (ImportError, ValueError):
...
@@ -48,7 +56,7 @@ except (ImportError, ValueError):
from
pathlib
import
Path
from
pathlib
import
Path
project_root
=
Path
(
__file__
)
.
parent
.
parent
.
parent
project_root
=
Path
(
__file__
)
.
parent
.
parent
.
parent
sys
.
path
.
insert
(
0
,
str
(
project_root
))
sys
.
path
.
insert
(
0
,
str
(
project_root
))
from
widgets.style_manager
import
FontManager
,
BackgroundStyleManager
from
widgets.style_manager
import
FontManager
,
BackgroundStyleManager
,
TextButtonStyleManager
from
widgets.responsive_layout
import
ResponsiveLayout
,
scale_w
,
scale_h
from
widgets.responsive_layout
import
ResponsiveLayout
,
scale_w
,
scale_h
except
ImportError
:
except
ImportError
:
# 如果导入失败,创建一个简单的替代类
# 如果导入失败,创建一个简单的替代类
...
@@ -96,8 +104,9 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -96,8 +104,9 @@ class TrainingPage(QtWidgets.QWidget):
# 🔥 连接模板按钮组信号
# 🔥 连接模板按钮组信号
self
.
template_button_group
.
buttonClicked
.
connect
(
self
.
_onTemplateChecked
)
self
.
template_button_group
.
buttonClicked
.
connect
(
self
.
_onTemplateChecked
)
self
.
_loadBaseModelOptions
()
# 🔥 加载基础模型选项
self
.
_loadTestModelOptions
()
# 加载测试模型选项
self
.
_loadTestModelOptions
()
# 加载测试模型选项
self
.
_loadTestFileList
()
# 🔥 加载测试文件列表
# self._loadTestFileList() # 🔥 不再需要加载测试文件列表(改用浏览方式)
def
_increaseFontSize
(
self
):
def
_increaseFontSize
(
self
):
"""增加日志字体大小"""
"""增加日志字体大小"""
...
@@ -251,6 +260,10 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -251,6 +260,10 @@ class TrainingPage(QtWidgets.QWidget):
# 将视频面板添加到显示布局中
# 将视频面板添加到显示布局中
display_layout
.
addWidget
(
self
.
video_panel
)
display_layout
.
addWidget
(
self
.
video_panel
)
# === 添加曲线显示面板 ===
self
.
_createCurvePanel
()
display_layout
.
addWidget
(
self
.
curve_panel
)
# 🔥 设置整体窗口的最小尺寸 - 使用响应式布局
# 🔥 设置整体窗口的最小尺寸 - 使用响应式布局
min_w
,
min_h
=
scale_w
(
1000
),
scale_h
(
700
)
min_w
,
min_h
=
scale_w
(
1000
),
scale_h
(
700
)
self
.
setMinimumSize
(
min_w
,
min_h
)
self
.
setMinimumSize
(
min_w
,
min_h
)
...
@@ -311,24 +324,38 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -311,24 +324,38 @@ class TrainingPage(QtWidgets.QWidget):
test_file_label
.
setStyleSheet
(
"color: #495057; font-weight: bold;"
)
test_file_label
.
setStyleSheet
(
"color: #495057; font-weight: bold;"
)
test_file_layout
.
addWidget
(
test_file_label
)
test_file_layout
.
addWidget
(
test_file_label
)
self
.
test_file_input
=
QtWidgets
.
QComboBox
()
# 测试文件路径输入框和浏览按钮
test_file_input_layout
=
QtWidgets
.
QHBoxLayout
()
test_file_input_layout
.
setSpacing
(
scale_spacing
(
5
))
# 文件路径输入框(可编辑)
self
.
test_file_input
=
QtWidgets
.
QLineEdit
()
self
.
test_file_input
.
setPlaceholderText
(
"选择测试图片或视频文件..."
)
FontManager
.
applyToWidget
(
self
.
test_file_input
)
FontManager
.
applyToWidget
(
self
.
test_file_input
)
self
.
test_file_input
.
setStyleSheet
(
"""
self
.
test_file_input
.
setStyleSheet
(
"""
Q
ComboBox
{
Q
LineEdit
{
padding: 6px 10px;
padding: 6px 10px;
border: 1px solid #ced4da;
border: 1px solid #ced4da;
border-radius: 4px;
border-radius: 4px;
background-color: white;
background-color: white;
}
}
Q
ComboBox
:focus {
Q
LineEdit
:focus {
border-color: #0078d7;
border-color: #0078d7;
outline: none;
outline: none;
}
}
QComboBox::drop-down {
border: none;
}
"""
)
"""
)
test_file_layout
.
addWidget
(
self
.
test_file_input
)
test_file_input_layout
.
addWidget
(
self
.
test_file_input
)
# 浏览按钮(使用全局样式管理器)
self
.
test_file_browse_btn
=
TextButtonStyleManager
.
createStandardButton
(
"浏览..."
,
parent
=
self
,
slot
=
self
.
_browseTestFile
)
self
.
test_file_browse_btn
.
setMinimumWidth
(
scale_w
(
60
))
test_file_input_layout
.
addWidget
(
self
.
test_file_browse_btn
)
test_file_layout
.
addLayout
(
test_file_input_layout
)
right_control_layout
.
addLayout
(
test_file_layout
)
right_control_layout
.
addLayout
(
test_file_layout
)
# 添加垂直间距
# 添加垂直间距
...
@@ -337,14 +364,32 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -337,14 +364,32 @@ class TrainingPage(QtWidgets.QWidget):
test_button_layout
=
QtWidgets
.
QHBoxLayout
()
test_button_layout
=
QtWidgets
.
QHBoxLayout
()
ResponsiveLayout
.
apply_to_layout
(
test_button_layout
,
base_spacing
=
10
,
base_margins
=
0
)
ResponsiveLayout
.
apply_to_layout
(
test_button_layout
,
base_spacing
=
10
,
base_margins
=
0
)
self
.
start_annotation_btn
=
QtWidgets
.
QPushButton
(
"开始标注"
)
# 使用全局样式管理器创建测试按钮
self
.
start_annotation_btn
=
TextButtonStyleManager
.
createStandardButton
(
"开始标注"
,
parent
=
self
)
self
.
start_annotation_btn
.
setMinimumWidth
(
scale_w
(
80
))
self
.
start_annotation_btn
.
setMinimumWidth
(
scale_w
(
80
))
test_button_layout
.
addWidget
(
self
.
start_annotation_btn
)
test_button_layout
.
addWidget
(
self
.
start_annotation_btn
)
self
.
start_test_btn
=
QtWidgets
.
QPushButton
(
"开始测试"
)
self
.
start_test_btn
=
TextButtonStyleManager
.
createStandardButton
(
"开始测试"
,
parent
=
self
)
self
.
start_test_btn
.
setMinimumWidth
(
scale_w
(
80
))
self
.
start_test_btn
.
setMinimumWidth
(
scale_w
(
80
))
test_button_layout
.
addWidget
(
self
.
start_test_btn
)
test_button_layout
.
addWidget
(
self
.
start_test_btn
)
# 查看曲线按钮(使用全局样式管理器)
self
.
view_curve_btn
=
TextButtonStyleManager
.
createStandardButton
(
"查看曲线"
,
parent
=
self
,
slot
=
self
.
_onViewCurveClicked
)
self
.
view_curve_btn
.
setMinimumWidth
(
scale_w
(
80
))
self
.
view_curve_btn
.
setEnabled
(
False
)
# 初始状态禁用
self
.
view_curve_btn
.
setToolTip
(
"测试完成后可查看曲线结果"
)
test_button_layout
.
addWidget
(
self
.
view_curve_btn
)
# 将按钮布局添加到主布局
# 将按钮布局添加到主布局
right_control_layout
.
addLayout
(
test_button_layout
)
right_control_layout
.
addLayout
(
test_button_layout
)
...
@@ -389,29 +434,37 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -389,29 +434,37 @@ class TrainingPage(QtWidgets.QWidget):
FontManager
.
applyToWidget
(
self
.
auto_scroll_check
)
FontManager
.
applyToWidget
(
self
.
auto_scroll_check
)
log_header
.
addWidget
(
self
.
auto_scroll_check
)
log_header
.
addWidget
(
self
.
auto_scroll_check
)
# 清空日志按钮(使用系统默认样式 + 响应式布局)
# 清空日志按钮(使用全局样式管理器)
clear_log_btn
=
QtWidgets
.
QPushButton
(
"清空日志"
)
clear_log_btn
=
TextButtonStyleManager
.
createStandardButton
(
"清空日志"
,
parent
=
self
,
slot
=
self
.
clearLog
)
clear_log_btn
.
setMinimumWidth
(
scale_w
(
60
))
clear_log_btn
.
setMinimumWidth
(
scale_w
(
60
))
# 🔥 添加字体大小调整按钮
# 🔥 添加字体大小调整按钮
font_size_label
=
QtWidgets
.
QLabel
(
"字体:"
)
font_size_label
=
QtWidgets
.
QLabel
(
"字体:"
)
# 使用系统默认样式
FontManager
.
applyToWidget
(
font_size_label
)
FontManager
.
applyToWidget
(
font_size_label
)
log_header
.
addWidget
(
font_size_label
)
log_header
.
addWidget
(
font_size_label
)
self
.
font_decrease_btn
=
QtWidgets
.
QPushButton
(
"-"
)
# 字体调整按钮(使用全局样式管理器)
self
.
font_decrease_btn
=
TextButtonStyleManager
.
createStandardButton
(
"-"
,
parent
=
self
,
slot
=
self
.
_decreaseFontSize
)
btn_size
=
scale_w
(
20
)
# 响应式按钮尺寸
btn_size
=
scale_w
(
20
)
# 响应式按钮尺寸
self
.
font_decrease_btn
.
setFixedSize
(
btn_size
,
btn_size
)
self
.
font_decrease_btn
.
setFixedSize
(
btn_size
,
btn_size
)
# 使用系统默认样式
self
.
font_decrease_btn
.
clicked
.
connect
(
self
.
_decreaseFontSize
)
log_header
.
addWidget
(
self
.
font_decrease_btn
)
log_header
.
addWidget
(
self
.
font_decrease_btn
)
self
.
font_increase_btn
=
QtWidgets
.
QPushButton
(
"+"
)
self
.
font_increase_btn
=
TextButtonStyleManager
.
createStandardButton
(
"+"
,
parent
=
self
,
slot
=
self
.
_increaseFontSize
)
self
.
font_increase_btn
.
setFixedSize
(
btn_size
,
btn_size
)
self
.
font_increase_btn
.
setFixedSize
(
btn_size
,
btn_size
)
# 使用系统默认样式
self
.
font_increase_btn
.
clicked
.
connect
(
self
.
_increaseFontSize
)
log_header
.
addWidget
(
self
.
font_increase_btn
)
log_header
.
addWidget
(
self
.
font_increase_btn
)
clear_log_btn
.
clicked
.
connect
(
self
.
clearLog
)
log_header
.
addWidget
(
clear_log_btn
)
log_header
.
addWidget
(
clear_log_btn
)
right_layout
.
addLayout
(
log_header
)
right_layout
.
addLayout
(
log_header
)
...
@@ -467,6 +520,199 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -467,6 +520,199 @@ class TrainingPage(QtWidgets.QWidget):
pass
pass
def
_createCurvePanel
(
self
):
"""创建曲线显示面板"""
self
.
curve_panel
=
QtWidgets
.
QWidget
()
curve_layout
=
QtWidgets
.
QVBoxLayout
(
self
.
curve_panel
)
curve_layout
.
setContentsMargins
(
5
,
5
,
5
,
5
)
curve_layout
.
setSpacing
(
5
)
# 曲线面板标题
curve_title_layout
=
QtWidgets
.
QHBoxLayout
()
curve_title
=
QtWidgets
.
QLabel
(
"测试结果曲线"
)
curve_title
.
setStyleSheet
(
"color: #495057; font-weight: bold; font-size: 12pt;"
)
FontManager
.
applyToWidget
(
curve_title
,
weight
=
FontManager
.
WEIGHT_BOLD
)
curve_title_layout
.
addWidget
(
curve_title
)
# 添加清空曲线按钮(使用全局样式管理器)
self
.
clear_curve_btn
=
TextButtonStyleManager
.
createStandardButton
(
"清空曲线"
,
parent
=
self
,
slot
=
self
.
_clearCurve
)
self
.
clear_curve_btn
.
setMinimumWidth
(
scale_w
(
80
))
curve_title_layout
.
addStretch
()
curve_title_layout
.
addWidget
(
self
.
clear_curve_btn
)
curve_layout
.
addLayout
(
curve_title_layout
)
# 检查 PyQtGraph 是否可用
if
PYQTGRAPH_AVAILABLE
:
# 创建 PyQtGraph 绘图控件
self
.
curve_plot_widget
=
pg
.
PlotWidget
()
self
.
curve_plot_widget
.
setBackground
(
'#ffffff'
)
self
.
curve_plot_widget
.
showGrid
(
x
=
True
,
y
=
True
,
alpha
=
0.3
)
# 设置坐标轴标签
self
.
curve_plot_widget
.
setLabel
(
'left'
,
'液位高度'
,
units
=
'mm'
)
self
.
curve_plot_widget
.
setLabel
(
'bottom'
,
'帧序号'
)
self
.
curve_plot_widget
.
setTitle
(
'液位检测曲线'
,
color
=
'#495057'
,
size
=
'12pt'
)
# 添加图例
self
.
curve_plot_widget
.
addLegend
()
# 存储曲线数据
self
.
curve_data_x
=
[]
# X轴数据(帧序号)
self
.
curve_data_y
=
[]
# Y轴数据(液位高度)
self
.
curve_line
=
None
# 曲线对象
curve_layout
.
addWidget
(
self
.
curve_plot_widget
)
else
:
# PyQtGraph 不可用,显示提示信息
placeholder
=
QtWidgets
.
QLabel
(
"曲线显示功能需要 PyQtGraph 库
\n\n
"
"请安装: pip install pyqtgraph"
)
placeholder
.
setAlignment
(
Qt
.
AlignCenter
)
placeholder
.
setStyleSheet
(
"color: #999; font-size: 11pt; padding: 50px;"
)
FontManager
.
applyToWidget
(
placeholder
)
curve_layout
.
addWidget
(
placeholder
)
def
_clearCurve
(
self
):
"""清空曲线数据并隐藏曲线面板"""
if
PYQTGRAPH_AVAILABLE
and
hasattr
(
self
,
'curve_plot_widget'
):
# 清空数据
self
.
curve_data_x
=
[]
self
.
curve_data_y
=
[]
# 清空曲线
if
self
.
curve_line
:
self
.
curve_plot_widget
.
removeItem
(
self
.
curve_line
)
self
.
curve_line
=
None
print
(
"[曲线] 已清空曲线数据"
)
# 自动隐藏曲线面板,返回到初始显示状态
self
.
hideCurvePanel
()
print
(
"[曲线] 曲线面板已隐藏,返回初始显示状态"
)
def
addCurvePoint
(
self
,
frame_index
,
height_mm
):
"""添加曲线数据点
Args:
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
if
not
PYQTGRAPH_AVAILABLE
or
not
hasattr
(
self
,
'curve_plot_widget'
):
return
try
:
# 添加数据点
self
.
curve_data_x
.
append
(
frame_index
)
self
.
curve_data_y
.
append
(
height_mm
)
# 如果曲线不存在,创建曲线
if
self
.
curve_line
is
None
:
self
.
curve_line
=
self
.
curve_plot_widget
.
plot
(
self
.
curve_data_x
,
self
.
curve_data_y
,
pen
=
pg
.
mkPen
(
color
=
'#1f77b4'
,
width
=
2
),
name
=
'液位高度'
)
else
:
# 更新曲线数据
self
.
curve_line
.
setData
(
self
.
curve_data_x
,
self
.
curve_data_y
)
except
Exception
as
e
:
print
(
f
"[曲线] 添加数据点失败: {e}"
)
def
showCurvePanel
(
self
):
"""显示曲线面板"""
if
hasattr
(
self
,
'display_layout'
):
# 切换到曲线面板(索引3:hint, display_panel, video_panel, curve_panel)
self
.
display_layout
.
setCurrentIndex
(
3
)
def
hideCurvePanel
(
self
):
"""隐藏曲线面板,返回到显示面板"""
if
hasattr
(
self
,
'display_layout'
):
self
.
display_layout
.
setCurrentIndex
(
1
)
# 显示 display_panel
def
saveCurveData
(
self
,
csv_path
):
"""保存曲线数据为CSV文件
Args:
csv_path: CSV文件保存路径
Returns:
bool: 是否成功保存
"""
if
not
PYQTGRAPH_AVAILABLE
or
not
hasattr
(
self
,
'curve_data_x'
):
return
False
try
:
import
csv
# 检查是否有数据
if
len
(
self
.
curve_data_x
)
==
0
:
print
(
"[曲线保存] 没有曲线数据可保存"
)
return
False
# 写入CSV文件
with
open
(
csv_path
,
'w'
,
newline
=
''
,
encoding
=
'utf-8'
)
as
f
:
writer
=
csv
.
writer
(
f
)
# 写入表头
writer
.
writerow
([
'帧序号'
,
'液位高度(mm)'
])
# 写入数据
for
x
,
y
in
zip
(
self
.
curve_data_x
,
self
.
curve_data_y
):
writer
.
writerow
([
x
,
y
])
print
(
f
"[曲线保存] CSV数据已保存: {csv_path}"
)
print
(
f
"[曲线保存] 共保存 {len(self.curve_data_x)} 个数据点"
)
return
True
except
Exception
as
e
:
print
(
f
"[曲线保存] 保存CSV失败: {e}"
)
import
traceback
traceback
.
print_exc
()
return
False
def
saveCurveImage
(
self
,
image_path
):
"""保存曲线为图片文件
Args:
image_path: 图片文件保存路径(支持 .png, .jpg, .svg等格式)
Returns:
bool: 是否成功保存
"""
if
not
PYQTGRAPH_AVAILABLE
or
not
hasattr
(
self
,
'curve_plot_widget'
):
return
False
try
:
# 检查是否有数据
if
len
(
self
.
curve_data_x
)
==
0
:
print
(
"[曲线保存] 没有曲线数据可保存"
)
return
False
# 使用PyQtGraph的导出功能
exporter
=
pg
.
exporters
.
ImageExporter
(
self
.
curve_plot_widget
.
plotItem
)
# 设置导出参数
exporter
.
parameters
()[
'width'
]
=
1200
# 设置宽度
exporter
.
parameters
()[
'height'
]
=
600
# 设置高度
# 导出图片
exporter
.
export
(
image_path
)
print
(
f
"[曲线保存] 曲线图片已保存: {image_path}"
)
return
True
except
Exception
as
e
:
print
(
f
"[曲线保存] 保存图片失败: {e}"
)
import
traceback
traceback
.
print_exc
()
return
False
def
_createParametersGroup
(
self
):
def
_createParametersGroup
(
self
):
"""创建参数配置组"""
"""创建参数配置组"""
group
=
QtWidgets
.
QGroupBox
(
""
)
group
=
QtWidgets
.
QGroupBox
(
""
)
...
@@ -502,38 +748,53 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -502,38 +748,53 @@ class TrainingPage(QtWidgets.QWidget):
# 创建表单布局
# 创建表单布局
layout
=
QtWidgets
.
QFormLayout
()
layout
=
QtWidgets
.
QFormLayout
()
from
..responsive_layout
import
scale_spacing
from
..responsive_layout
import
scale_spacing
layout
.
setSpacing
(
scale_spacing
(
8
)
)
layout
.
setSpacing
(
8
)
layout
.
setContentsMargins
(
0
,
0
,
0
,
0
)
layout
.
setContentsMargins
(
0
,
0
,
0
,
0
)
layout
.
setLabelAlignment
(
Qt
.
AlignRight
|
Qt
.
AlignVCenter
)
layout
.
setLabelAlignment
(
Qt
.
AlignRight
|
Qt
.
AlignVCenter
)
# 🔥 删除CSS样式表,改为纯控件方式,由字体管理器统一管理
# 🔥 删除CSS样式表,改为纯控件方式,由字体管理器统一管理
# 基础模型路径
# 基础模型选择(下拉菜单)
model_layout
=
QtWidgets
.
QHBoxLayout
()
self
.
base_model_combo
=
QtWidgets
.
QComboBox
()
from
..responsive_layout
import
scale_spacing
self
.
base_model_combo
.
setPlaceholderText
(
"请选择基础模型"
)
model_layout
.
setSpacing
(
scale_spacing
(
6
))
self
.
base_model_combo
.
setSizePolicy
(
QtWidgets
.
QSizePolicy
.
Expanding
,
QtWidgets
.
QSizePolicy
.
Fixed
)
self
.
base_model_edit
=
QtWidgets
.
QLineEdit
()
layout
.
addRow
(
"基础模型:"
,
self
.
base_model_combo
)
self
.
base_model_edit
.
setPlaceholderText
(
"选择基础模型文件 (.pt)"
)
self
.
browse_model_btn
=
QtWidgets
.
QPushButton
(
"浏览..."
)
# 数据集文件夹选择(统一格式,支持多个文件夹)
self
.
browse_model_btn
.
setFixedWidth
(
scale_w
(
80
))
dataset_layout
=
QtWidgets
.
QHBoxLayout
()
self
.
browse_model_btn
.
clicked
.
connect
(
self
.
_browseModel
)
dataset_layout
.
setSpacing
(
8
)
model_layout
.
addWidget
(
self
.
base_model_edit
,
1
)
dataset_layout
.
setContentsMargins
(
0
,
0
,
0
,
0
)
model_layout
.
addWidget
(
self
.
browse_model_btn
)
layout
.
addRow
(
"基础模型:"
,
model_layout
)
# 数据集路径显示文本框(可编辑,支持多个路径用分号分隔)
self
.
dataset_paths_edit
=
QtWidgets
.
QLineEdit
()
# 数据集路径 (字段名必须是 save_liquid_data_path_edit 以匹配训练处理器)
self
.
dataset_paths_edit
.
setPlaceholderText
(
""
)
data_layout
=
QtWidgets
.
QHBoxLayout
()
self
.
dataset_paths_edit
.
setSizePolicy
(
QtWidgets
.
QSizePolicy
.
Expanding
,
QtWidgets
.
QSizePolicy
.
Fixed
)
from
..responsive_layout
import
scale_spacing
# 移除自定义样式,使用全局字体管理器统一管理
data_layout
.
setSpacing
(
scale_spacing
(
6
))
self
.
dataset_paths_edit
.
setStyleSheet
(
""
)
# 连接文本变化信号,实时更新内部数据集列表
self
.
dataset_paths_edit
.
textChanged
.
connect
(
self
.
_onDatasetPathsChanged
)
dataset_layout
.
addWidget
(
self
.
dataset_paths_edit
)
# 浏览按钮(使用全局样式管理器)
self
.
btn_browse_datasets
=
TextButtonStyleManager
.
createStandardButton
(
"浏览..."
,
parent
=
self
,
slot
=
self
.
_onBrowseDatasets
)
self
.
btn_browse_datasets
.
setFixedWidth
(
100
)
self
.
btn_browse_datasets
.
setSizePolicy
(
QtWidgets
.
QSizePolicy
.
Fixed
,
QtWidgets
.
QSizePolicy
.
Fixed
)
dataset_layout
.
addWidget
(
self
.
btn_browse_datasets
)
layout
.
addRow
(
"数据集:"
,
dataset_layout
)
# 保持内部数据集文件夹列表(隐藏,用于兼容现有逻辑)
self
.
dataset_folders_list
=
QtWidgets
.
QListWidget
()
self
.
dataset_folders_list
.
setVisible
(
False
)
# 隐藏,仅用于内部数据管理
# 保留旧的字段名以保持向后兼容(用于获取数据集路径)
# 现在它将存储用分号分隔的多个文件夹路径
self
.
save_liquid_data_path_edit
=
QtWidgets
.
QLineEdit
()
self
.
save_liquid_data_path_edit
=
QtWidgets
.
QLineEdit
()
self
.
save_liquid_data_path_edit
.
setPlaceholderText
(
"选择数据集配置文件 (data.yaml)"
)
self
.
save_liquid_data_path_edit
.
setVisible
(
False
)
# 隐藏,仅用于数据传递
self
.
save_liquid_data_path_edit
.
setText
(
"database/dataset/data.yaml"
)
# 默认路径
self
.
browse_data_btn
=
QtWidgets
.
QPushButton
(
"浏览..."
)
self
.
browse_data_btn
.
setFixedWidth
(
scale_w
(
80
))
self
.
browse_data_btn
.
clicked
.
connect
(
self
.
_browseDataset
)
data_layout
.
addWidget
(
self
.
save_liquid_data_path_edit
,
1
)
data_layout
.
addWidget
(
self
.
browse_data_btn
)
layout
.
addRow
(
"数据集:"
,
data_layout
)
# 实验名称
# 实验名称
self
.
exp_name_edit
=
QtWidgets
.
QLineEdit
()
self
.
exp_name_edit
=
QtWidgets
.
QLineEdit
()
...
@@ -586,6 +847,18 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -586,6 +847,18 @@ class TrainingPage(QtWidgets.QWidget):
self
.
optimizer_combo
.
addItems
([
"SGD"
,
"Adam"
,
"AdamW"
])
self
.
optimizer_combo
.
addItems
([
"SGD"
,
"Adam"
,
"AdamW"
])
layout
.
addRow
(
"优化器:"
,
self
.
optimizer_combo
)
layout
.
addRow
(
"优化器:"
,
self
.
optimizer_combo
)
# 训练笔记按钮(使用全局样式管理器)
self
.
training_notes_btn
=
TextButtonStyleManager
.
createStandardButton
(
"训练笔记"
,
parent
=
self
,
slot
=
self
.
_openNotesDialog
)
self
.
training_notes_btn
.
setToolTip
(
"点击打开训练笔记编辑窗口"
)
layout
.
addRow
(
"训练笔记:"
,
self
.
training_notes_btn
)
# 内部存储笔记内容的变量
self
.
_training_notes_content
=
""
# 高级选项分隔
# 高级选项分隔
separator
=
QtWidgets
.
QLabel
()
separator
=
QtWidgets
.
QLabel
()
separator
.
setStyleSheet
(
"border-top: 1px solid #dee2e6; margin: 5px 0;"
)
separator
.
setStyleSheet
(
"border-top: 1px solid #dee2e6; margin: 5px 0;"
)
...
@@ -646,12 +919,18 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -646,12 +919,18 @@ class TrainingPage(QtWidgets.QWidget):
# 创建别名以匹配训练处理器期望的名称
# 创建别名以匹配训练处理器期望的名称
self
.
train_status_label
=
self
.
status_label
self
.
train_status_label
=
self
.
status_label
# 控制按钮(使用Qt默认样式 + 响应式布局)
# 控制按钮(使用全局样式管理器)
self
.
start_train_btn
=
QtWidgets
.
QPushButton
(
"开始升级"
)
self
.
start_train_btn
=
TextButtonStyleManager
.
createStandardButton
(
"开始升级"
,
parent
=
self
)
self
.
start_train_btn
.
setMinimumWidth
(
scale_w
(
80
))
self
.
start_train_btn
.
setMinimumWidth
(
scale_w
(
80
))
control_layout
.
addWidget
(
self
.
start_train_btn
)
control_layout
.
addWidget
(
self
.
start_train_btn
)
self
.
stop_train_btn
=
QtWidgets
.
QPushButton
(
"停止升级"
)
self
.
stop_train_btn
=
TextButtonStyleManager
.
createStandardButton
(
"停止升级"
,
parent
=
self
)
self
.
stop_train_btn
.
setMinimumWidth
(
scale_w
(
80
))
self
.
stop_train_btn
.
setMinimumWidth
(
scale_w
(
80
))
self
.
stop_train_btn
.
setEnabled
(
False
)
self
.
stop_train_btn
.
setEnabled
(
False
)
control_layout
.
addWidget
(
self
.
stop_train_btn
)
control_layout
.
addWidget
(
self
.
stop_train_btn
)
...
@@ -663,7 +942,7 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -663,7 +942,7 @@ class TrainingPage(QtWidgets.QWidget):
# 🔥 应用全局字体管理器到所有文本框和控件
# 🔥 应用全局字体管理器到所有文本框和控件
if
FontManager
:
if
FontManager
:
# 应用到所有QLineEdit
# 应用到所有QLineEdit
FontManager
.
applyToWidget
(
self
.
base_model
_edit
)
FontManager
.
applyToWidget
(
self
.
dataset_paths
_edit
)
FontManager
.
applyToWidget
(
self
.
save_liquid_data_path_edit
)
FontManager
.
applyToWidget
(
self
.
save_liquid_data_path_edit
)
FontManager
.
applyToWidget
(
self
.
exp_name_edit
)
FontManager
.
applyToWidget
(
self
.
exp_name_edit
)
...
@@ -674,6 +953,7 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -674,6 +953,7 @@ class TrainingPage(QtWidgets.QWidget):
FontManager
.
applyToWidget
(
self
.
workers_spin
)
FontManager
.
applyToWidget
(
self
.
workers_spin
)
# 应用到所有QComboBox
# 应用到所有QComboBox
FontManager
.
applyToWidget
(
self
.
base_model_combo
)
FontManager
.
applyToWidget
(
self
.
device_combo
)
FontManager
.
applyToWidget
(
self
.
device_combo
)
FontManager
.
applyToWidget
(
self
.
optimizer_combo
)
FontManager
.
applyToWidget
(
self
.
optimizer_combo
)
...
@@ -686,39 +966,181 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -686,39 +966,181 @@ class TrainingPage(QtWidgets.QWidget):
FontManager
.
applyToWidget
(
self
.
checkbox_template_2
)
FontManager
.
applyToWidget
(
self
.
checkbox_template_2
)
FontManager
.
applyToWidget
(
self
.
checkbox_template_3
)
FontManager
.
applyToWidget
(
self
.
checkbox_template_3
)
# 应用到所有QPushButton
# 应用到状态标签(非按钮样式管理器创建的控件)
FontManager
.
applyToWidget
(
self
.
browse_model_btn
)
FontManager
.
applyToWidget
(
self
.
browse_data_btn
)
FontManager
.
applyToWidget
(
self
.
status_label
)
FontManager
.
applyToWidget
(
self
.
status_label
)
FontManager
.
applyToWidget
(
self
.
start_train_btn
)
FontManager
.
applyToWidget
(
self
.
stop_train_btn
)
# 应用到数据集列表
FontManager
.
applyToWidget
(
self
.
dataset_folders_list
)
# 应用到笔记按钮
FontManager
.
applyToWidget
(
self
.
training_notes_btn
)
# 应用到整个group(包括标签)
# 应用到整个group(包括标签)
FontManager
.
applyToWidgetRecursive
(
group
)
FontManager
.
applyToWidgetRecursive
(
group
)
return
group
return
group
def
_browseModel
(
self
):
def
_onBrowseDatasets
(
self
):
"""浏览模型文件"""
"""浏览数据集文件夹(支持多选)"""
file_path
,
_
=
QtWidgets
.
QFileDialog
.
getOpenFileName
(
# 使用简单的单选文件夹对话框,多次选择来实现多选效果
folder_path
=
QtWidgets
.
QFileDialog
.
getExistingDirectory
(
self
,
self
,
"选择基础模型"
,
"选择数据集文件夹"
,
"database/model"
,
"database/dataset"
if
os
.
path
.
exists
(
"database/dataset"
)
else
""
"模型文件 (*.pt *.dat);;所有文件 (*.*)"
)
)
if
file_path
:
self
.
base_model_edit
.
setText
(
file_path
)
if
folder_path
:
# 获取当前已有的文件夹路径
current_paths
=
self
.
dataset_paths_edit
.
text
()
.
strip
()
existing_folders
=
[
p
.
strip
()
for
p
in
current_paths
.
split
(
';'
)
if
p
.
strip
()]
if
current_paths
else
[]
# 检查是否已经添加过这个文件夹
if
folder_path
not
in
existing_folders
:
# 添加新文件夹
if
existing_folders
:
# 如果已有文件夹,用分号连接
new_paths
=
current_paths
+
';'
+
folder_path
else
:
# 如果是第一个文件夹
new_paths
=
folder_path
self
.
dataset_paths_edit
.
setText
(
new_paths
)
print
(
f
"[TrainingPage] 添加数据集文件夹: {folder_path}"
)
else
:
# 使用style_manager中的对话框管理器显示提示
try
:
from
..style_manager
import
DialogManager
DialogManager
.
show_information
(
self
,
"提示"
,
f
"文件夹已存在:
\n
{folder_path}"
)
except
ImportError
:
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
f
"文件夹已存在:
\n
{folder_path}"
)
def
_browseDataset
(
self
):
def
_onDatasetPathsChanged
(
self
):
"""浏览数据集文件"""
"""数据集路径文本变化时的处理"""
file_path
,
_
=
QtWidgets
.
QFileDialog
.
getOpenFileName
(
# 更新内部数据集列表和隐藏的路径字段
self
,
paths_text
=
self
.
dataset_paths_edit
.
text
()
.
strip
()
"选择数据集配置"
,
folders
=
[
p
.
strip
()
for
p
in
paths_text
.
split
(
';'
)
if
p
.
strip
()]
"database/dataset"
,
"YAML文件 (*.yaml *.yml);;所有文件 (*.*)"
# 更新内部列表(用于兼容现有逻辑)
self
.
dataset_folders_list
.
clear
()
for
folder
in
folders
:
self
.
dataset_folders_list
.
addItem
(
folder
)
# 更新隐藏的路径字段
self
.
save_liquid_data_path_edit
.
setText
(
paths_text
)
# 调试信息
print
(
f
"[TrainingPage] 数据集路径更新: {len(folders)} 个文件夹"
)
for
i
,
folder
in
enumerate
(
folders
):
print
(
f
" [{i+1}] {folder}"
)
def
_addDatasetFolder
(
self
):
"""添加数据集文件夹(保留兼容性)"""
self
.
_onBrowseDatasets
()
def
_removeSelectedDatasets
(
self
):
"""删除选中的数据集文件夹"""
selected_items
=
self
.
dataset_folders_list
.
selectedItems
()
if
not
selected_items
:
QtWidgets
.
QMessageBox
.
information
(
self
,
"提示"
,
"请先选择要删除的文件夹"
)
return
for
item
in
selected_items
:
row
=
self
.
dataset_folders_list
.
row
(
item
)
self
.
dataset_folders_list
.
takeItem
(
row
)
self
.
_updateDatasetPath
()
def
_clearAllDatasets
(
self
):
"""清空所有数据集文件夹"""
if
self
.
dataset_folders_list
.
count
()
==
0
:
return
reply
=
QtWidgets
.
QMessageBox
.
question
(
self
,
"确认清空"
,
f
"确定要清空所有 {self.dataset_folders_list.count()} 个数据集文件夹吗?"
,
QtWidgets
.
QMessageBox
.
Yes
|
QtWidgets
.
QMessageBox
.
No
,
QtWidgets
.
QMessageBox
.
No
)
)
if
file_path
:
self
.
save_liquid_data_path_edit
.
setText
(
file_path
)
if
reply
==
QtWidgets
.
QMessageBox
.
Yes
:
self
.
dataset_folders_list
.
clear
()
self
.
_updateDatasetPath
()
def
_updateDatasetPath
(
self
):
"""更新隐藏的数据集路径字段(用分号分隔多个文件夹)"""
folders
=
[
self
.
dataset_folders_list
.
item
(
i
)
.
text
()
for
i
in
range
(
self
.
dataset_folders_list
.
count
())]
# 使用分号分隔多个文件夹路径
self
.
save_liquid_data_path_edit
.
setText
(
';'
.
join
(
folders
))
def
getDatasetFolders
(
self
):
"""获取所有数据集文件夹路径列表"""
paths_text
=
self
.
dataset_paths_edit
.
text
()
.
strip
()
return
[
p
.
strip
()
for
p
in
paths_text
.
split
(
';'
)
if
p
.
strip
()]
def
getTrainingNotes
(
self
):
"""获取训练笔记内容"""
return
self
.
_training_notes_content
.
strip
()
def
setTrainingNotes
(
self
,
notes
):
"""设置训练笔记内容"""
self
.
_training_notes_content
=
notes
if
notes
else
""
self
.
_updateNotesButtonText
()
def
clearTrainingNotes
(
self
):
"""清空训练笔记(不弹确认框)"""
self
.
_training_notes_content
=
""
self
.
_updateNotesButtonText
()
def
_openNotesDialog
(
self
):
"""打开训练笔记编辑对话框"""
dialog
=
TrainingNotesDialog
(
self
.
_training_notes_content
,
self
)
if
dialog
.
exec_
()
==
QtWidgets
.
QDialog
.
Accepted
:
self
.
_training_notes_content
=
dialog
.
getNotesContent
()
self
.
_updateNotesButtonText
()
def
_updateNotesButtonText
(
self
):
"""更新笔记按钮的显示文本"""
if
self
.
_training_notes_content
.
strip
():
# 显示笔记的前几个字符
preview
=
self
.
_training_notes_content
.
strip
()[:
8
]
# 减少字符数以适应按钮宽度
if
len
(
self
.
_training_notes_content
.
strip
())
>
8
:
preview
+=
"..."
new_text
=
f
"训练笔记 ({preview})"
# 使用全局样式管理器更新按钮文本和大小
TextButtonStyleManager
.
updateButtonText
(
self
.
training_notes_btn
,
new_text
)
# 添加有内容的视觉提示(保持全局样式基础上的微调)
current_style
=
self
.
training_notes_btn
.
styleSheet
()
self
.
training_notes_btn
.
setStyleSheet
(
current_style
+
"""
QPushButton {
background-color: #e3f2fd;
border: 1px solid #2196f3;
}
"""
)
else
:
# 使用全局样式管理器重置按钮
TextButtonStyleManager
.
updateButtonText
(
self
.
training_notes_btn
,
"训练笔记"
)
def
enableNotesButtons
(
self
):
"""启用笔记按钮(训练完成后调用)"""
# 训练笔记按钮始终可用,无需特殊处理
pass
def
disableNotesButtons
(
self
):
"""禁用笔记按钮(训练开始前调用)"""
# 训练笔记按钮始终可用,无需特殊处理
pass
@QtCore.Slot
(
str
)
@QtCore.Slot
(
str
)
def
appendLog
(
self
,
text
):
def
appendLog
(
self
,
text
):
...
@@ -898,10 +1320,11 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -898,10 +1320,11 @@ class TrainingPage(QtWidgets.QWidget):
self
.
stop_train_btn
.
setEnabled
(
is_training
)
self
.
stop_train_btn
.
setEnabled
(
is_training
)
# 禁用参数输入
# 禁用参数输入
self
.
base_model_edit
.
setEnabled
(
not
is_training
)
self
.
base_model_combo
.
setEnabled
(
not
is_training
)
self
.
browse_model_btn
.
setEnabled
(
not
is_training
)
self
.
dataset_folders_list
.
setEnabled
(
not
is_training
)
self
.
save_liquid_data_path_edit
.
setEnabled
(
not
is_training
)
self
.
add_dataset_btn
.
setEnabled
(
not
is_training
)
self
.
browse_data_btn
.
setEnabled
(
not
is_training
)
self
.
remove_dataset_btn
.
setEnabled
(
not
is_training
)
self
.
clear_dataset_btn
.
setEnabled
(
not
is_training
)
self
.
exp_name_edit
.
setEnabled
(
not
is_training
)
self
.
exp_name_edit
.
setEnabled
(
not
is_training
)
self
.
epochs_spin
.
setEnabled
(
not
is_training
)
self
.
epochs_spin
.
setEnabled
(
not
is_training
)
self
.
batch_spin
.
setEnabled
(
not
is_training
)
self
.
batch_spin
.
setEnabled
(
not
is_training
)
...
@@ -915,8 +1338,59 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -915,8 +1338,59 @@ class TrainingPage(QtWidgets.QWidget):
def
showEvent
(
self
,
event
):
def
showEvent
(
self
,
event
):
"""页面显示时刷新模型列表和测试文件列表(确保与模型集管理页面同步)"""
"""页面显示时刷新模型列表和测试文件列表(确保与模型集管理页面同步)"""
super
(
TrainingPage
,
self
)
.
showEvent
(
event
)
super
(
TrainingPage
,
self
)
.
showEvent
(
event
)
self
.
_loadBaseModelOptions
()
# 🔥 加载基础模型列表
self
.
_loadTestModelOptions
()
self
.
_loadTestModelOptions
()
self
.
_loadTestFileList
()
# 🔥 刷新测试文件列表
# self._loadTestFileList() # 🔥 不再需要刷新测试文件列表(改用浏览方式)
def
_loadBaseModelOptions
(
self
):
"""从模型集管理页面加载基础模型选项"""
# 清空现有选项
self
.
base_model_combo
.
clear
()
try
:
# 尝试从父窗口获取模型集页面
if
hasattr
(
self
.
_parent
,
'modelSetPage'
):
model_set_page
=
self
.
_parent
.
modelSetPage
# 获取模型参数字典
if
hasattr
(
model_set_page
,
'_model_params'
):
model_params
=
model_set_page
.
_model_params
if
not
model_params
:
self
.
base_model_combo
.
addItem
(
"未找到模型"
,
None
)
return
# 获取默认模型
default_model
=
None
if
hasattr
(
model_set_page
,
'_current_default_model'
):
default_model
=
model_set_page
.
_current_default_model
# 添加所有模型到下拉框
default_index
=
0
for
idx
,
(
model_name
,
params
)
in
enumerate
(
model_params
.
items
()):
model_path
=
params
.
get
(
'path'
,
''
)
# 构建显示名称
display_name
=
model_name
if
model_name
==
default_model
:
display_name
=
f
"{model_name} (默认)"
default_index
=
idx
# 添加到下拉框,使用模型路径作为数据
self
.
base_model_combo
.
addItem
(
display_name
,
model_path
)
# 设置默认选择
self
.
base_model_combo
.
setCurrentIndex
(
default_index
)
else
:
self
.
base_model_combo
.
addItem
(
"模型集页面未初始化"
,
None
)
else
:
self
.
base_model_combo
.
addItem
(
"未找到模型集页面"
,
None
)
except
Exception
as
e
:
print
(
f
"[基础模型] 加载失败: {e}"
)
import
traceback
traceback
.
print_exc
()
self
.
base_model_combo
.
addItem
(
"加载失败"
,
None
)
def
_loadTestModelOptions
(
self
):
def
_loadTestModelOptions
(
self
):
"""加载测试模型选项(从 train_model 文件夹读取)"""
"""加载测试模型选项(从 train_model 文件夹读取)"""
...
@@ -1118,91 +1592,10 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -1118,91 +1592,10 @@ class TrainingPage(QtWidgets.QWidget):
return
models
return
models
def
_loadTestFileList
(
self
):
def
_loadTestFileList
(
self
):
"""加载测试文件列表到下拉框"""
"""加载测试文件列表(保留方法以保持向后兼容,但现在使用浏览方式)"""
try
:
# 该方法现已改为使用浏览按钮选择文件,不再从固定目录加载
import
os
# 保留此方法以防其他代码调用
from
pathlib
import
Path
pass
# 🔥 改进:使用多种方式获取project_root
project_root
=
None
try
:
from
...database.config
import
get_project_root
project_root
=
get_project_root
()
# print(f"[测试文件] 使用相对导入获取project_root: {project_root}")
except
(
ImportError
,
ValueError
):
# 相对导入失败是预期行为(作为主程序运行时),静默处理
try
:
from
database.config
import
get_project_root
project_root
=
get_project_root
()
except
(
ImportError
,
ValueError
):
# 备选方案:使用当前文件的父目录
current_file
=
Path
(
__file__
)
.
resolve
()
project_root
=
current_file
.
parent
.
parent
.
parent
if
not
project_root
:
raise
RuntimeError
(
"无法获取project_root"
)
test_file_dir
=
os
.
path
.
join
(
str
(
project_root
),
'database'
,
'model_test_file'
)
# 清空下拉框
self
.
test_file_input
.
clear
()
if
not
os
.
path
.
exists
(
test_file_dir
):
self
.
test_file_input
.
addItem
(
"(目录不存在)"
)
return
# 扫描目录
items
=
[]
try
:
dir_items
=
os
.
listdir
(
test_file_dir
)
video_extensions
=
[
'.mp4'
,
'.avi'
,
'.mov'
,
'.mkv'
,
'.flv'
,
'.wmv'
]
image_extensions
=
[
'.jpg'
,
'.jpeg'
,
'.png'
,
'.bmp'
,
'.tiff'
,
'.webp'
]
for
item
in
dir_items
:
item_path
=
os
.
path
.
join
(
test_file_dir
,
item
)
# 检查是否为视频文件
if
os
.
path
.
isfile
(
item_path
):
if
any
(
item
.
lower
()
.
endswith
(
ext
)
for
ext
in
video_extensions
):
items
.
append
((
item
,
item_path
))
# 检查是否为文件夹
elif
os
.
path
.
isdir
(
item_path
):
# 检查文件夹内是否有图片或视频
has_images
=
False
has_videos
=
False
try
:
folder_items
=
os
.
listdir
(
item_path
)
for
file
in
folder_items
:
file_lower
=
file
.
lower
()
if
any
(
file_lower
.
endswith
(
ext
)
for
ext
in
image_extensions
):
has_images
=
True
break
elif
any
(
file_lower
.
endswith
(
ext
)
for
ext
in
video_extensions
):
has_videos
=
True
except
Exception
as
folder_e
:
pass
# 🔥 改进:文件夹包含图片或视频都添加
if
has_images
or
has_videos
:
items
.
append
((
f
"📁 {item}"
,
item_path
))
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
# 添加到下拉框
if
items
:
for
display_name
,
full_path
in
sorted
(
items
):
self
.
test_file_input
.
addItem
(
display_name
,
full_path
)
else
:
self
.
test_file_input
.
addItem
(
"(未找到测试文件)"
)
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
self
.
test_file_input
.
addItem
(
"(加载失败)"
)
def
_getRememberedTestModel
(
self
,
project_root
):
def
_getRememberedTestModel
(
self
,
project_root
):
"""从配置文件获取记忆的测试模型路径"""
"""从配置文件获取记忆的测试模型路径"""
...
@@ -1254,8 +1647,8 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -1254,8 +1647,8 @@ class TrainingPage(QtWidgets.QWidget):
def
getTestFilePath
(
self
):
def
getTestFilePath
(
self
):
"""获取选中的测试文件路径"""
"""获取选中的测试文件路径"""
#
🔥 改为从QComboBox获取数据
#
从 QLineEdit 获取文件路径
return
self
.
test_file_input
.
currentData
()
or
""
return
self
.
test_file_input
.
text
()
.
strip
()
def
isTestingInProgress
(
self
):
def
isTestingInProgress
(
self
):
"""检查是否正在测试中"""
"""检查是否正在测试中"""
...
@@ -1267,27 +1660,14 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -1267,27 +1660,14 @@ class TrainingPage(QtWidgets.QWidget):
if
is_testing
:
if
is_testing
:
# 切换为"停止测试"状态
# 切换为"停止测试"状态
self
.
start_test_btn
.
setText
(
"停止测试"
)
TextButtonStyleManager
.
updateButtonText
(
self
.
start_test_btn
,
"停止测试"
)
# 使用红色样式表示危险操作
# 使用全局样式管理器的危险按钮样式
self
.
start_test_btn
.
setStyleSheet
(
"""
TextButtonStyleManager
.
applyDangerStyle
(
self
.
start_test_btn
)
QPushButton {
background-color: #dc3545;
color: white;
border: none;
padding: 8px 16px;
border-radius: 4px;
font-weight: bold;
min-width: 100px;
}
QPushButton:hover {
background-color: #c82333;
}
"""
)
else
:
else
:
# 切换为"开始测试"状态
# 切换为"开始测试"状态
self
.
start_test_btn
.
setText
(
"开始测试"
)
TextButtonStyleManager
.
updateButtonText
(
self
.
start_test_btn
,
"开始测试"
)
# 恢复
默认
样式
# 恢复
标准
样式
self
.
start_test_btn
.
setStyleSheet
(
""
)
TextButtonStyleManager
.
applyStandardStyle
(
self
.
start_test_btn
)
def
_onTemplateChecked
(
self
,
button
):
def
_onTemplateChecked
(
self
,
button
):
"""处理模板复选框选中事件"""
"""处理模板复选框选中事件"""
...
@@ -1326,8 +1706,14 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -1326,8 +1706,14 @@ class TrainingPage(QtWidgets.QWidget):
try
:
try
:
# 基础模型
# 基础模型
if
'model'
in
config
:
if
'model'
in
config
:
self
.
base_model_edit
.
setText
(
str
(
config
[
'model'
]))
model_path
=
str
(
config
[
'model'
])
print
(
f
"[模板] 设置基础模型: {config['model']}"
)
# 在下拉菜单中查找匹配的模型
for
i
in
range
(
self
.
base_model_combo
.
count
()):
item_data
=
self
.
base_model_combo
.
itemData
(
i
)
if
item_data
and
item_data
==
model_path
:
self
.
base_model_combo
.
setCurrentIndex
(
i
)
print
(
f
"[模板] 设置基础模型: {model_path}"
)
break
# 数据集配置
# 数据集配置
if
'data'
in
config
:
if
'data'
in
config
:
...
@@ -1392,6 +1778,134 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -1392,6 +1778,134 @@ class TrainingPage(QtWidgets.QWidget):
import
traceback
import
traceback
traceback
.
print_exc
()
traceback
.
print_exc
()
def
_browseTestFile
(
self
):
"""浏览选择测试文件(支持图片和视频)"""
try
:
# 定义支持的文件类型
image_formats
=
"图片文件 (*.jpg *.jpeg *.png *.bmp *.tiff *.webp)"
video_formats
=
"视频文件 (*.mp4 *.avi *.mov *.mkv *.flv *.wmv)"
all_formats
=
"所有支持的文件 (*.jpg *.jpeg *.png *.bmp *.tiff *.webp *.mp4 *.avi *.mov *.mkv *.flv *.wmv)"
# 构建文件过滤器
file_filter
=
f
"{all_formats};;{image_formats};;{video_formats};;所有文件 (*.*)"
# 打开文件选择对话框
file_path
,
_
=
QtWidgets
.
QFileDialog
.
getOpenFileName
(
self
,
"选择测试图片或视频文件"
,
""
,
# 默认目录为空,使用系统默认
file_filter
)
# 如果用户选择了文件,则设置到输入框
if
file_path
:
self
.
test_file_input
.
setText
(
file_path
)
print
(
f
"[测试文件] 已选择: {file_path}"
)
except
Exception
as
e
:
import
traceback
traceback
.
print_exc
()
QtWidgets
.
QMessageBox
.
warning
(
self
,
"文件选择失败"
,
f
"选择测试文件时发生错误:
\n
{str(e)}"
)
def
_onViewCurveClicked
(
self
):
"""查看曲线按钮点击处理"""
try
:
# 检查是否有曲线数据
if
not
hasattr
(
self
,
'curve_data_x'
)
or
not
hasattr
(
self
,
'curve_data_y'
):
self
.
_showNoCurveMessage
()
return
if
len
(
self
.
curve_data_x
)
==
0
:
self
.
_showNoCurveMessage
()
return
# 切换到曲线面板显示
self
.
showCurvePanel
()
# 显示曲线信息提示
data_count
=
len
(
self
.
curve_data_x
)
if
data_count
==
1
:
# 图片测试
liquid_level
=
self
.
curve_data_y
[
0
]
QtWidgets
.
QMessageBox
.
information
(
self
,
"曲线信息"
,
f
"图片测试结果:
\n
液位高度: {liquid_level:.1f} mm
\n\n
"
f
"曲线已显示在左侧面板中。"
)
else
:
# 视频测试
min_level
=
min
(
self
.
curve_data_y
)
max_level
=
max
(
self
.
curve_data_y
)
avg_level
=
sum
(
self
.
curve_data_y
)
/
len
(
self
.
curve_data_y
)
QtWidgets
.
QMessageBox
.
information
(
self
,
"曲线信息"
,
f
"视频测试结果:
\n
"
f
"数据点数: {data_count} 个
\n
"
f
"液位范围: {min_level:.1f} - {max_level:.1f} mm
\n
"
f
"平均液位: {avg_level:.1f} mm
\n\n
"
f
"曲线已显示在左侧面板中。"
)
except
Exception
as
e
:
print
(
f
"[查看曲线] 显示曲线失败: {e}"
)
QtWidgets
.
QMessageBox
.
warning
(
self
,
"显示失败"
,
f
"显示曲线时发生错误:
\n
{str(e)}"
)
def
_showNoCurveMessage
(
self
):
"""显示无曲线数据的提示"""
QtWidgets
.
QMessageBox
.
information
(
self
,
"无曲线数据"
,
"当前没有可显示的曲线数据。
\n\n
"
"请先进行模型测试:
\n
"
"1. 选择测试模型
\n
"
"2. 选择测试文件
\n
"
"3. 点击
\"
开始标注
\"\n
"
"4. 点击
\"
开始测试
\"\n\n
"
"测试完成后即可查看曲线结果。"
)
def
enableViewCurveButton
(
self
):
"""启用查看曲线按钮(测试完成后调用)"""
try
:
self
.
view_curve_btn
.
setEnabled
(
True
)
self
.
view_curve_btn
.
setToolTip
(
"点击查看测试结果曲线"
)
# 检查曲线数据类型并更新按钮文本
if
hasattr
(
self
,
'curve_data_x'
)
and
len
(
self
.
curve_data_x
)
>
0
:
data_count
=
len
(
self
.
curve_data_x
)
if
data_count
==
1
:
self
.
view_curve_btn
.
setText
(
"查看曲线(图)"
)
else
:
self
.
view_curve_btn
.
setText
(
"查看曲线(视)"
)
else
:
self
.
view_curve_btn
.
setText
(
"查看曲线"
)
print
(
f
"[查看曲线] 按钮已启用,数据点数: {len(self.curve_data_x) if hasattr(self, 'curve_data_x') else 0}"
)
except
Exception
as
e
:
print
(
f
"[查看曲线] 启用按钮失败: {e}"
)
def
disableViewCurveButton
(
self
):
"""禁用查看曲线按钮(测试开始前调用)"""
try
:
self
.
view_curve_btn
.
setEnabled
(
False
)
self
.
view_curve_btn
.
setText
(
"查看曲线"
)
self
.
view_curve_btn
.
setToolTip
(
"测试完成后可查看曲线结果"
)
print
(
f
"[查看曲线] 按钮已禁用"
)
except
Exception
as
e
:
print
(
f
"[查看曲线] 禁用按钮失败: {e}"
)
def
getTemplateConfig
(
self
):
def
getTemplateConfig
(
self
):
"""获取当前选中的模板配置名称"""
"""获取当前选中的模板配置名称"""
if
self
.
checkbox_template_1
.
isChecked
():
if
self
.
checkbox_template_1
.
isChecked
():
...
@@ -1414,3 +1928,148 @@ if __name__ == '__main__':
...
@@ -1414,3 +1928,148 @@ if __name__ == '__main__':
page
.
show
()
page
.
show
()
sys
.
exit
(
app
.
exec_
())
sys
.
exit
(
app
.
exec_
())
class
TrainingNotesDialog
(
QtWidgets
.
QDialog
):
"""训练笔记编辑对话框"""
def
__init__
(
self
,
initial_content
=
""
,
parent
=
None
):
super
()
.
__init__
(
parent
)
self
.
setWindowTitle
(
"训练笔记编辑"
)
self
.
setMinimumSize
(
500
,
400
)
self
.
resize
(
600
,
500
)
# 设置窗口图标(如果有的话)
self
.
setWindowFlags
(
self
.
windowFlags
()
&
~
QtCore
.
Qt
.
WindowContextHelpButtonHint
)
# 应用全局样式管理器
from
..style_manager
import
FontManager
,
BackgroundStyleManager
BackgroundStyleManager
.
applyToWidget
(
self
)
self
.
_setupUI
()
self
.
text_edit
.
setPlainText
(
initial_content
)
# 应用全局字体管理器到整个对话框
FontManager
.
applyToDialog
(
self
)
def
_setupUI
(
self
):
"""设置UI界面"""
layout
=
QtWidgets
.
QVBoxLayout
(
self
)
layout
.
setSpacing
(
10
)
layout
.
setContentsMargins
(
15
,
15
,
15
,
15
)
# 标题标签(使用全局字体管理器)
from
..style_manager
import
FontManager
title_label
=
QtWidgets
.
QLabel
(
"训练笔记"
)
title_label
.
setFont
(
FontManager
.
getTitleFont
())
title_label
.
setStyleSheet
(
"""
QLabel {
color: #333;
margin-bottom: 5px;
}
"""
)
layout
.
addWidget
(
title_label
)
# 说明文字(使用全局字体管理器)
info_label
=
QtWidgets
.
QLabel
(
"在此记录本次训练的相关信息,如数据集变化、参数调整原因、预期效果等..."
)
info_label
.
setFont
(
FontManager
.
getSmallFont
())
info_label
.
setStyleSheet
(
"""
QLabel {
color: #666;
margin-bottom: 10px;
}
"""
)
info_label
.
setWordWrap
(
True
)
layout
.
addWidget
(
info_label
)
# 文本编辑区域(使用全局字体管理器)
self
.
text_edit
=
QtWidgets
.
QTextEdit
()
self
.
text_edit
.
setFont
(
FontManager
.
getMediumFont
())
self
.
text_edit
.
setStyleSheet
(
"""
QTextEdit {
border: 1px solid #ccc;
background-color: white;
padding: 10px;
line-height: 1.4;
}
"""
)
self
.
text_edit
.
setPlaceholderText
(
"示例内容:
\n
"
"• 数据集:新增了100张液位图片
\n
"
"• 参数调整:学习率从0.01调整为0.005
\n
"
"• 预期效果:提高小液位目标的检测精度
\n
"
"• 其他备注:..."
)
layout
.
addWidget
(
self
.
text_edit
)
# 字符计数标签(使用全局字体管理器)
self
.
char_count_label
=
QtWidgets
.
QLabel
(
"字符数: 0"
)
self
.
char_count_label
.
setFont
(
FontManager
.
getSmallFont
())
self
.
char_count_label
.
setStyleSheet
(
"color: #666;"
)
self
.
char_count_label
.
setAlignment
(
QtCore
.
Qt
.
AlignRight
)
layout
.
addWidget
(
self
.
char_count_label
)
# 按钮区域(使用全局样式管理器)
button_layout
=
QtWidgets
.
QHBoxLayout
()
button_layout
.
addStretch
()
# 清空按钮
clear_btn
=
TextButtonStyleManager
.
createStandardButton
(
"清空"
,
self
,
self
.
_clearText
)
button_layout
.
addWidget
(
clear_btn
)
# 取消按钮
cancel_btn
=
TextButtonStyleManager
.
createStandardButton
(
"取消"
,
self
,
self
.
reject
)
button_layout
.
addWidget
(
cancel_btn
)
# 保存按钮(使用全局样式管理器创建,然后添加特殊样式)
save_btn
=
TextButtonStyleManager
.
createStandardButton
(
"保存"
,
self
,
self
.
accept
)
save_btn
.
setDefault
(
True
)
save_btn
.
setStyleSheet
(
"""
QPushButton {
background-color: #2196f3;
color: white;
border: none;
padding: 8px;
font-weight: bold;
}
QPushButton:hover {
background-color: #1976d2;
}
"""
)
button_layout
.
addWidget
(
save_btn
)
layout
.
addLayout
(
button_layout
)
# 连接信号
self
.
text_edit
.
textChanged
.
connect
(
self
.
_updateCharCount
)
self
.
_updateCharCount
()
def
_clearText
(
self
):
"""清空文本"""
if
self
.
text_edit
.
toPlainText
()
.
strip
():
from
..style_manager
import
DialogManager
if
DialogManager
.
show_question_warning
(
self
,
"确认清空"
,
"确定要清空所有笔记内容吗?"
,
"是"
,
"否"
):
self
.
text_edit
.
clear
()
def
_updateCharCount
(
self
):
"""更新字符计数"""
text
=
self
.
text_edit
.
toPlainText
()
char_count
=
len
(
text
)
self
.
char_count_label
.
setText
(
f
"字符数: {char_count}"
)
# 字符数过多时显示警告颜色(保持全局字体设置)
if
char_count
>
1000
:
self
.
char_count_label
.
setStyleSheet
(
"color: #f44336;"
)
elif
char_count
>
500
:
self
.
char_count_label
.
setStyleSheet
(
"color: #ff9800;"
)
else
:
self
.
char_count_label
.
setStyleSheet
(
"color: #666;"
)
def
getNotesContent
(
self
):
"""获取笔记内容"""
return
self
.
text_edit
.
toPlainText
()
.
strip
()
widgets/responsive_layout.py
View file @
da148b05
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
根据屏幕分辨率自动调整UI尺寸
根据屏幕分辨率自动调整UI尺寸
"""
"""
from
PyQt5
import
QtWidgets
,
QtCore
from
qtpy
import
QtWidgets
,
QtCore
class
ResponsiveLayout
:
class
ResponsiveLayout
:
...
...
widgets/style_manager.py
View file @
da148b05
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
"""
"""
统一样式管理器
统一样式管理器
...
@@ -40,7 +40,7 @@ import os.path as osp
...
@@ -40,7 +40,7 @@ import os.path as osp
import
sys
import
sys
try
:
try
:
from
PyQt5
import
QtGui
,
QtWidgets
,
QtCore
from
qtpy
import
QtGui
,
QtWidgets
,
QtCore
except
ImportError
as
e
:
except
ImportError
as
e
:
raise
raise
...
@@ -78,12 +78,11 @@ class FontManager:
...
@@ -78,12 +78,11 @@ class FontManager:
@staticmethod
@staticmethod
def
getDefaultFont
():
def
getDefaultFont
():
"""获取默认字体"""
"""获取默认字体"""
font
=
QtGui
.
QFont
(
return
FontManager
.
getFont
(
FontManager
.
DEFAULT_FONT_FAMILY
,
size
=
FontManager
.
DEFAULT_FONT_SIZE
,
FontManager
.
DEFAULT_FONT_SIZE
weight
=
FontManager
.
DEFAULT_FONT_WEIGHT
,
family
=
FontManager
.
DEFAULT_FONT_FAMILY
)
)
font
.
setWeight
(
FontManager
.
DEFAULT_FONT_WEIGHT
)
return
font
@staticmethod
@staticmethod
def
getFont
(
size
=
None
,
weight
=
None
,
italic
=
False
,
underline
=
False
,
family
=
None
):
def
getFont
(
size
=
None
,
weight
=
None
,
italic
=
False
,
underline
=
False
,
family
=
None
):
...
@@ -96,7 +95,31 @@ class FontManager:
...
@@ -96,7 +95,31 @@ class FontManager:
weight
=
FontManager
.
DEFAULT_FONT_WEIGHT
weight
=
FontManager
.
DEFAULT_FONT_WEIGHT
font
=
QtGui
.
QFont
(
family
,
size
)
font
=
QtGui
.
QFont
(
family
,
size
)
font
.
setWeight
(
weight
)
# PySide6兼容性:将整数权重转换为QFont.Weight枚举
try
:
# 尝试使用PySide6的Weight枚举
if
hasattr
(
QtGui
.
QFont
,
'Weight'
):
if
weight
<=
25
:
font
.
setWeight
(
QtGui
.
QFont
.
Weight
.
Light
)
elif
weight
<=
50
:
font
.
setWeight
(
QtGui
.
QFont
.
Weight
.
Normal
)
elif
weight
<=
63
:
font
.
setWeight
(
QtGui
.
QFont
.
Weight
.
DemiBold
)
elif
weight
<=
75
:
font
.
setWeight
(
QtGui
.
QFont
.
Weight
.
Bold
)
else
:
font
.
setWeight
(
QtGui
.
QFont
.
Weight
.
Black
)
else
:
# 回退到旧的整数方式(PyQt5)
font
.
setWeight
(
weight
)
except
(
AttributeError
,
TypeError
):
# 如果出现任何错误,使用默认权重
try
:
font
.
setWeight
(
QtGui
.
QFont
.
Weight
.
Normal
)
except
:
pass
font
.
setItalic
(
italic
)
font
.
setItalic
(
italic
)
font
.
setUnderline
(
underline
)
font
.
setUnderline
(
underline
)
return
font
return
font
...
@@ -157,8 +180,17 @@ class FontManager:
...
@@ -157,8 +180,17 @@ class FontManager:
@staticmethod
@staticmethod
def
applyToApplication
(
app
):
def
applyToApplication
(
app
):
"""应用默认字体到整个应用程序"""
"""应用默认字体到整个应用程序"""
font
=
FontManager
.
getDefaultFont
()
try
:
app
.
setFont
(
font
)
# 确保使用当前 Qt 后端创建字体对象
from
qtpy
import
QtGui
font
=
QtGui
.
QFont
(
FontManager
.
DEFAULT_FONT_FAMILY
,
FontManager
.
DEFAULT_FONT_SIZE
)
font
.
setWeight
(
FontManager
.
DEFAULT_FONT_WEIGHT
)
app
.
setFont
(
font
)
except
Exception
as
e
:
print
(
f
"字体设置失败,使用默认字体: {e}"
)
@staticmethod
@staticmethod
def
applyToWidgetRecursive
(
widget
,
size
=
None
,
weight
=
None
):
def
applyToWidgetRecursive
(
widget
,
size
=
None
,
weight
=
None
):
...
@@ -827,7 +859,7 @@ class TextButtonStyleManager:
...
@@ -827,7 +859,7 @@ class TextButtonStyleManager:
# 计算文本宽度:字符数 * 每字符宽度 + 内边距
# 计算文本宽度:字符数 * 每字符宽度 + 内边距
text_width
=
len
(
text
)
*
cls
.
CHAR_WIDTH
+
cls
.
MIN_PADDING
text_width
=
len
(
text
)
*
cls
.
CHAR_WIDTH
+
cls
.
MIN_PADDING
# 返回基础宽度和计算宽度的较大值
return
max
(
cls
.
BASE_WIDTH
,
text_width
)
return
max
(
cls
.
BASE_WIDTH
,
text_width
)
@classmethod
@classmethod
...
@@ -850,7 +882,14 @@ class TextButtonStyleManager:
...
@@ -850,7 +882,14 @@ class TextButtonStyleManager:
@classmethod
@classmethod
def
createStandardButton
(
cls
,
text
,
parent
=
None
,
slot
=
None
):
def
createStandardButton
(
cls
,
text
,
parent
=
None
,
slot
=
None
):
"""创建标准样式的文本按钮"""
"""创建标准样式的文本按钮"""
button
=
QtWidgets
.
QPushButton
(
text
,
parent
)
# 修复 PySide6 兼容性问题 - 分步创建
try
:
button
=
QtWidgets
.
QPushButton
(
text
)
if
parent
is
not
None
:
button
.
setParent
(
parent
)
except
Exception
as
e
:
print
(
f
"按钮创建失败: {e}"
)
button
=
QtWidgets
.
QPushButton
(
"按钮"
)
# 应用标准样式
# 应用标准样式
cls
.
applyToButton
(
button
,
text
)
cls
.
applyToButton
(
button
,
text
)
...
@@ -866,6 +905,37 @@ class TextButtonStyleManager:
...
@@ -866,6 +905,37 @@ class TextButtonStyleManager:
"""更新按钮文本并重新调整大小"""
"""更新按钮文本并重新调整大小"""
button
.
setText
(
new_text
)
button
.
setText
(
new_text
)
cls
.
applyToButton
(
button
,
new_text
)
cls
.
applyToButton
(
button
,
new_text
)
@classmethod
def
applyDangerStyle
(
cls
,
button
):
"""应用危险按钮样式(红色)"""
button
.
setStyleSheet
(
"""
QPushButton {
background-color: #dc3545;
color: white;
border: none;
padding: 8px 16px;
border-radius: 4px;
font-weight: bold;
min-width: 100px;
}
QPushButton:hover {
background-color: #c82333;
}
QPushButton:pressed {
background-color: #bd2130;
}
QPushButton:disabled {
background-color: #6c757d;
color: #ffffff;
}
"""
)
@classmethod
def
applyStandardStyle
(
cls
,
button
):
"""应用标准按钮样式"""
# 重新应用标准样式
cls
.
applyToButton
(
button
)
class
BackgroundStyleManager
:
class
BackgroundStyleManager
:
"""全局背景颜色管理器"""
"""全局背景颜色管理器"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment