Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
O
Oil_Level_Recognition_System
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Administrator
Oil_Level_Recognition_System
Commits
7e155e88
Commit
7e155e88
authored
Nov 29, 2025
by
wangbing2
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
123
parent
1e35f3f8
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
1518 additions
and
299 deletions
+1518
-299
model_set_handler.py
handlers/modelpage/model_set_handler.py
+1
-1
model_test_handler.py
handlers/modelpage/model_test_handler.py
+69
-5
model_training_handler.py
handlers/modelpage/model_training_handler.py
+228
-74
model_trainingworker_handler.py
handlers/modelpage/model_trainingworker_handler.py
+140
-63
modelset_page.py
widgets/modelpage/modelset_page.py
+509
-47
training_page.py
widgets/modelpage/training_page.py
+463
-107
style_manager.py
widgets/style_manager.py
+108
-2
No files found.
handlers/modelpage/model_set_handler.py
View file @
7e155e88
...
...
@@ -705,7 +705,7 @@ class ModelSetHandler:
if
model_path
.
endswith
(
'.pt'
):
model_type
=
"PyTorch (YOLOv5/v8)"
elif
model_path
.
endswith
(
'.dat'
):
model_type
=
"
加密模型 (.dat)
"
model_type
=
"
.dat
"
else
:
model_type
=
"未知格式"
...
...
handlers/modelpage/model_test_handler.py
View file @
7e155e88
...
...
@@ -136,6 +136,7 @@ class ModelTestThread(QThread):
# 读取标注数据
self
.
progress_updated
.
emit
(
45
,
"正在读取标注数据..."
)
import
yaml
with
open
(
annotation_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
try
:
annotation_data
=
yaml
.
safe_load
(
f
)
...
...
@@ -157,20 +158,31 @@ class ModelTestThread(QThread):
# 配置标注数据
self
.
progress_updated
.
emit
(
70
,
"正在配置检测参数..."
)
detection_engine
.
set_annotation_data
(
annotation_data
)
# 从annotation_data中提取配置参数
boxes
=
annotation_data
.
get
(
'boxes'
,
[])
fixed_bottoms
=
annotation_data
.
get
(
'fixed_bottoms'
,
[])
fixed_tops
=
annotation_data
.
get
(
'fixed_tops'
,
[])
actual_heights
=
annotation_data
.
get
(
'actual_heights'
,
[
20.0
]
*
len
(
boxes
))
detection_engine
.
configure
(
boxes
,
fixed_bottoms
,
fixed_tops
,
actual_heights
)
# 执行检测
self
.
progress_updated
.
emit
(
80
,
"正在执行液位检测..."
)
detection_result
=
detection_engine
.
detect
_liquid_level
(
test_frame
)
detection_result
=
detection_engine
.
detect
(
test_frame
)
if
detection_result
is
None
:
raise
RuntimeError
(
"检测结果为空"
)
self
.
_detection_result
=
detection_result
# 转换检测结果格式以兼容现有代码
converted_result
=
self
.
_convertDetectionResult
(
detection_result
)
self
.
_detection_result
=
converted_result
self
.
progress_updated
.
emit
(
90
,
"正在保存测试结果..."
)
self
.
_saveImageTestResults
(
model_path
,
test_frame
,
detection_result
,
annotation_file
)
# 通过handler调用保存方法
handler
=
self
.
test_params
[
'handler'
]
handler
.
_saveImageTestResults
(
model_path
,
test_frame
,
converted_result
,
annotation_file
)
except
Exception
as
e
:
raise
...
...
@@ -198,6 +210,59 @@ class ModelTestThread(QThread):
"""获取检测结果"""
return
self
.
_detection_result
def
_convertDetectionResult
(
self
,
detection_result
):
"""
转换检测结果格式以兼容现有代码
Args:
detection_result: detect方法返回的结果格式
{
'liquid_line_positions': {
0: {'y': y坐标, 'height_mm': 高度毫米, 'height_px': 高度像素},
1: {...},
...
},
'success': bool
}
Returns:
dict: 转换后的结果格式,包含 liquid_level_mm 字段
"""
try
:
if
not
detection_result
or
not
detection_result
.
get
(
'success'
,
False
):
return
{
'liquid_level_mm'
:
0.0
,
'success'
:
False
}
liquid_positions
=
detection_result
.
get
(
'liquid_line_positions'
,
{})
if
not
liquid_positions
:
return
{
'liquid_level_mm'
:
0.0
,
'success'
:
False
}
# 取第一个检测区域的液位高度
first_position
=
next
(
iter
(
liquid_positions
.
values
()))
liquid_level_mm
=
first_position
.
get
(
'height_mm'
,
0.0
)
# 构建兼容格式的结果
converted_result
=
{
'liquid_level_mm'
:
liquid_level_mm
,
'success'
:
True
,
'areas'
:
{}
}
# 转换所有区域的数据
for
idx
,
position_data
in
liquid_positions
.
items
():
area_name
=
f
'区域{idx + 1}'
converted_result
[
'areas'
][
area_name
]
=
{
'liquid_height'
:
position_data
.
get
(
'height_mm'
,
0.0
),
'y_position'
:
position_data
.
get
(
'y'
,
0
),
'height_px'
:
position_data
.
get
(
'height_px'
,
0
)
}
return
converted_result
except
Exception
as
e
:
print
(
f
"[结果转换] 转换检测结果失败: {e}"
)
return
{
'liquid_level_mm'
:
0.0
,
'success'
:
False
}
class
ModelTestHandler
:
"""
...
...
@@ -1777,7 +1842,6 @@ class ModelTestHandler:
<h4 style="margin: 0 0 10px 0; color: #333333; font-size: 16px; font-weight: 500;">检测结果视频</h4>
<video width="100
%
" height="auto" controls style="border: none; border-radius: 6px; max-height: 400px; background: #f8f9fa;">
<source src="file:///{video_path_formatted}" type="video/mp4">
您的浏览器不支持视频播放
</video>
</div>
...
...
handlers/modelpage/model_training_handler.py
View file @
7e155e88
...
...
@@ -563,18 +563,17 @@ class ModelTrainingHandler(ModelTestHandler):
except
Exception
as
add_error
:
self
.
_appendLog
(
f
"[WARNING] 添加到模型集失败: {str(add_error)}
\n
"
)
#
刷新模型集管理
页面
#
同步到所有相关
页面
try
:
self
.
_
refreshModelSetPage
()
except
Exception
as
refresh
_error
:
pass
self
.
_
syncAllModelPages
()
except
Exception
as
sync
_error
:
self
.
_appendLog
(
f
"[WARNING] 同步页面失败: {str(sync_error)}
\n
"
)
try
:
self
.
_refreshModelTestPage
()
except
Exception
as
test_refresh_error
:
pass
# 修复:使用非阻塞式通知,避免卡住UI
# 获取 last.dat 路径(转换后应该是 dat 文件)
last_checkpoint_path
=
None
if
self
.
current_exp_name
and
weights_dir
:
last_dat_path
=
os
.
path
.
join
(
weights_dir
,
"last.dat"
)
last_pt_path
=
os
.
path
.
join
(
weights_dir
,
"last.pt"
)
self
.
_appendLog
(
"
\n
"
+
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
" 训练完成通知
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
...
...
@@ -2388,12 +2387,12 @@ class ModelTrainingHandler(ModelTestHandler):
# 获取训练笔记
training_notes
=
self
.
_getTrainingNotes
()
#
将模型移动到detection_model目录
detection_model_path
=
self
.
_
moveModelTo
DetectionDir
(
best_model_path
,
model_name
,
weights_dir
,
training_notes
)
#
模型已直接保存在detection_model目录中,进行后处理
detection_model_path
=
self
.
_
processModelIn
DetectionDir
(
best_model_path
,
model_name
,
weights_dir
,
training_notes
)
if
detection_model_path
:
# 更新模型参数中的路径
model_params
[
'path'
]
=
detection_model_path
self
.
_appendLog
(
f
"模型已
移动到
detection_model目录: {os.path.basename(detection_model_path)}
\n
"
)
self
.
_appendLog
(
f
"模型已
保存在
detection_model目录: {os.path.basename(detection_model_path)}
\n
"
)
# 保存到模型集配置
self
.
_saveModelToConfig
(
model_name
,
model_params
)
...
...
@@ -2524,62 +2523,32 @@ class ModelTrainingHandler(ModelTestHandler):
self
.
_appendLog
(
f
"[ERROR] 获取最新模型目录失败: {str(e)}
\n
"
)
return
None
def
_
moveModelTo
DetectionDir
(
self
,
model_path
,
model_name
,
weights_dir
,
training_notes
=
""
):
"""
将训练完成的模型移动到detection_model目录
"""
def
_
processModelIn
DetectionDir
(
self
,
model_path
,
model_name
,
weights_dir
,
training_notes
=
""
):
"""
处理已保存在detection_model目录中的模型
"""
try
:
import
shutil
from
pathlib
import
Path
self
.
_appendLog
(
f
"
\n
开始移动模型到detection_model目录...
\n
"
)
# 获取项目根目录
project_root
=
get_project_root
()
detection_model_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'detection_model'
)
# 确保detection_model目录存在
os
.
makedirs
(
detection_model_dir
,
exist_ok
=
True
)
# 获取下一个可用的数字ID
existing_dirs
=
[]
for
item
in
os
.
listdir
(
detection_model_dir
):
item_path
=
os
.
path
.
join
(
detection_model_dir
,
item
)
if
os
.
path
.
isdir
(
item_path
)
and
item
.
isdigit
():
existing_dirs
.
append
(
int
(
item
))
next_id
=
max
(
existing_dirs
)
+
1
if
existing_dirs
else
1
target_model_dir
=
os
.
path
.
join
(
detection_model_dir
,
str
(
next_id
))
self
.
_appendLog
(
f
"
\n
开始处理detection_model目录中的模型...
\n
"
)
self
.
_appendLog
(
f
" 目标目录: {target_model_dir}
\n
"
)
# 创建目标目录结构
os
.
makedirs
(
target_model_dir
,
exist_ok
=
True
)
target_weights_dir
=
os
.
path
.
join
(
target_model_dir
,
'weights'
)
os
.
makedirs
(
target_weights_dir
,
exist_ok
=
True
)
# 移动整个weights目录的内容
# 模型已经直接保存在detection_model目录中,获取模型目录
if
weights_dir
and
os
.
path
.
exists
(
weights_dir
):
self
.
_appendLog
(
f
" 复制weights目录内容...
\n
"
)
# 复制所有文件到目标weights目录
for
filename
in
os
.
listdir
(
weights_dir
):
source_file
=
os
.
path
.
join
(
weights_dir
,
filename
)
target_file
=
os
.
path
.
join
(
target_weights_dir
,
filename
)
# weights_dir应该是 detection_model/{ID}/weights
target_model_dir
=
os
.
path
.
dirname
(
weights_dir
)
model_id
=
os
.
path
.
basename
(
target_model_dir
)
if
os
.
path
.
isfile
(
source_file
):
shutil
.
copy2
(
source_file
,
target_file
)
self
.
_appendLog
(
f
" 复制: {filename}
\n
"
)
self
.
_appendLog
(
f
" 模型目录: {target_model_dir}
\n
"
)
self
.
_appendLog
(
f
" 模型ID: {model_id}
\n
"
)
# 复制训练目录的其他文件(如config.yaml, results.csv等)
train_exp_dir
=
os
.
path
.
dirname
(
weights_dir
)
if
os
.
path
.
exists
(
train_exp_dir
):
for
filename
in
os
.
listdir
(
train_exp_dir
):
if
filename
!=
'weights'
:
# 跳过weights目录
source_file
=
os
.
path
.
join
(
train_exp_dir
,
filename
)
target_file
=
os
.
path
.
join
(
target_model_dir
,
filename
)
# 清理weights目录中的所有pt文件
self
.
_appendLog
(
f
" 清理pt文件...
\n
"
)
self
.
_forceCleanupPtFiles
(
weights_dir
)
if
os
.
path
.
isfile
(
source_file
):
shutil
.
copy2
(
source_file
,
target_file
)
self
.
_appendLog
(
f
" 复制配置: {filename}
\n
"
)
# 验证pt文件已被清理
remaining_pt_files
=
[
f
for
f
in
os
.
listdir
(
weights_dir
)
if
f
.
endswith
(
'.pt'
)]
if
remaining_pt_files
:
self
.
_appendLog
(
f
" 警告: 仍有pt文件未清理: {remaining_pt_files}
\n
"
)
else
:
self
.
_appendLog
(
f
" ✅ 所有pt文件已清理完成
\n
"
)
# 保存训练笔记(如果有)
if
training_notes
:
...
...
@@ -2589,28 +2558,46 @@ class ModelTrainingHandler(ModelTestHandler):
# 添加时间戳和模型信息
f
.
write
(
f
"训练笔记 - {model_name}
\n
"
)
f
.
write
(
f
"训练时间: {self._getCurrentTimestamp()}
\n
"
)
f
.
write
(
f
"模型ID: {next
_id}
\n
"
)
f
.
write
(
f
"模型ID: {model
_id}
\n
"
)
f
.
write
(
"="
*
50
+
"
\n\n
"
)
f
.
write
(
training_notes
)
self
.
_appendLog
(
f
" 保存训练笔记: training_notes.txt
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
" 保存训练笔记失败: {str(e)}
\n
"
)
# 确定最终的模型文件路径
model_filename
=
os
.
path
.
basename
(
model_path
)
final_model_path
=
os
.
path
.
join
(
target_weights_dir
,
model_filename
)
# 查找实际的模型文件(应该是.dat格式)
final_model_path
=
None
if
os
.
path
.
exists
(
weights_dir
):
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
startswith
(
'best.'
)
and
filename
.
endswith
(
'.dat'
):
final_model_path
=
os
.
path
.
join
(
weights_dir
,
filename
)
break
if
os
.
path
.
exists
(
final_model_path
):
self
.
_appendLog
(
f
"✅ 模型已成功移动到detection_model/{next_id}/weights/{model_filename}
\n
"
)
# 如果没找到best.dat,查找任何.dat文件
if
not
final_model_path
:
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
endswith
(
'.dat'
):
final_model_path
=
os
.
path
.
join
(
weights_dir
,
filename
)
break
if
final_model_path
and
os
.
path
.
exists
(
final_model_path
):
self
.
_appendLog
(
f
"✅ 模型已保存在detection_model/{model_id}/weights/{os.path.basename(final_model_path)}
\n
"
)
if
training_notes
:
self
.
_appendLog
(
f
"✅ 训练笔记已保存到detection_model/{next_id}/training_notes.txt
\n
"
)
self
.
_appendLog
(
f
"✅ 训练笔记已保存到detection_model/{model_id}/training_notes.txt
\n
"
)
# 自动生成模型描述文件
self
.
_generateModelDescription
(
target_model_dir
,
model_id
,
final_model_path
,
model_name
)
return
final_model_path
else
:
self
.
_appendLog
(
f
"❌ 模型移动失败,目标文件不存在: {final_model_path}
\n
"
)
self
.
_appendLog
(
f
"❌ 未找到模型文件在: {weights_dir}
\n
"
)
return
None
else
:
self
.
_appendLog
(
f
"❌ weights目录不存在: {weights_dir}
\n
"
)
return
None
except
Exception
as
e
:
self
.
_appendLog
(
f
"❌ [ERROR]
移动模型到detection_model
失败: {str(e)}
\n
"
)
self
.
_appendLog
(
f
"❌ [ERROR]
处理detection_model目录中的模型
失败: {str(e)}
\n
"
)
import
traceback
traceback
.
print_exc
()
return
None
...
...
@@ -3037,7 +3024,7 @@ class ModelTrainingHandler(ModelTestHandler):
traceback
.
print_exc
()
def
_refreshModelTestPage
(
self
):
"""刷新模型测试页面
的模型列表
"""
"""刷新模型测试页面"""
try
:
# 获取主窗口的模型测试页面
if
hasattr
(
self
,
'main_window'
)
and
hasattr
(
self
.
main_window
,
'testModelPage'
):
...
...
@@ -3046,9 +3033,8 @@ class ModelTrainingHandler(ModelTestHandler):
if
hasattr
(
self
.
main_window
,
'modelSetPage'
):
self
.
main_window
.
testModelPage
.
loadModelsFromModelSetPage
(
self
.
main_window
.
modelSetPage
)
else
:
# 否则直接刷新模型测试页面
if
hasattr
(
self
.
main_window
.
testModelPage
,
'loadModelsFromConfig'
):
self
.
main_window
.
testModelPage
.
loadModelsFromConfig
()
# 备用方案:直接加载模型
self
.
main_window
.
testModelPage
.
loadModels
()
print
(
"[信息] 模型测试页面刷新完成"
)
else
:
print
(
"[警告] 无法访问模型测试页面"
)
...
...
@@ -3057,6 +3043,60 @@ class ModelTrainingHandler(ModelTestHandler):
import
traceback
traceback
.
print_exc
()
def
_syncAllModelPages
(
self
):
"""同步所有模型相关页面"""
try
:
self
.
_appendLog
(
"
\n
正在同步所有模型页面...
\n
"
)
# 1. 刷新模型集管理页面
try
:
self
.
_refreshModelSetPage
()
self
.
_appendLog
(
" ✅ 模型集管理页面已同步
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
" ❌ 模型集管理页面同步失败: {str(e)}
\n
"
)
# 2. 刷新模型测试页面
try
:
self
.
_refreshModelTestPage
()
self
.
_appendLog
(
" ✅ 模型测试页面已同步
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
" ❌ 模型测试页面同步失败: {str(e)}
\n
"
)
# 3. 刷新训练页面的模型下拉菜单
try
:
self
.
_refreshTrainingPageModels
()
self
.
_appendLog
(
" ✅ 训练页面模型列表已同步
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
" ❌ 训练页面模型列表同步失败: {str(e)}
\n
"
)
# 4. 通知主窗口更新相关UI
try
:
if
hasattr
(
self
,
'main_window'
)
and
hasattr
(
self
.
main_window
,
'updateModelRelatedUI'
):
self
.
main_window
.
updateModelRelatedUI
()
self
.
_appendLog
(
" ✅ 主窗口UI已同步
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
" ❌ 主窗口UI同步失败: {str(e)}
\n
"
)
self
.
_appendLog
(
"页面同步完成
\n
"
)
except
Exception
as
e
:
self
.
_appendLog
(
f
"❌ 页面同步失败: {str(e)}
\n
"
)
import
traceback
traceback
.
print_exc
()
def
_refreshTrainingPageModels
(
self
):
"""刷新训练页面的模型下拉菜单"""
try
:
if
hasattr
(
self
,
'training_panel'
)
and
hasattr
(
self
.
training_panel
,
'refreshBaseModelList'
):
self
.
training_panel
.
refreshBaseModelList
()
elif
hasattr
(
self
,
'main_window'
)
and
hasattr
(
self
.
main_window
,
'trainingPage'
):
if
hasattr
(
self
.
main_window
.
trainingPage
,
'refreshBaseModelList'
):
self
.
main_window
.
trainingPage
.
refreshBaseModelList
()
except
Exception
as
e
:
print
(
f
"[错误] 刷新训练页面模型列表失败: {e}"
)
import
traceback
traceback
.
print_exc
()
def
_getFileSize
(
self
,
file_path
):
"""获取文件大小(格式化字符串)"""
try
:
...
...
@@ -3249,3 +3289,117 @@ class ModelTrainingHandler(ModelTestHandler):
import
traceback
self
.
_appendLog
(
traceback
.
format_exc
())
return
None
def
_generateModelDescription
(
self
,
target_model_dir
,
model_id
,
model_path
,
model_name
):
"""自动生成模型描述文件"""
try
:
import
time
from
pathlib
import
Path
self
.
_appendLog
(
f
"
\n
开始生成模型描述文件...
\n
"
)
# 新结构:查找根目录下的dat模型文件
actual_model_path
=
None
# 优先查找best.dat文件(在根目录)
for
filename
in
os
.
listdir
(
target_model_dir
):
if
filename
.
startswith
(
'best.'
)
and
filename
.
endswith
(
'.dat'
):
actual_model_path
=
os
.
path
.
join
(
target_model_dir
,
filename
)
break
# 如果没找到best.dat,查找任何.dat文件(在根目录)
if
not
actual_model_path
:
for
filename
in
os
.
listdir
(
target_model_dir
):
if
filename
.
endswith
(
'.dat'
):
actual_model_path
=
os
.
path
.
join
(
target_model_dir
,
filename
)
break
# 使用找到的dat文件路径,如果没找到则使用原路径
if
actual_model_path
and
os
.
path
.
exists
(
actual_model_path
):
model_path
=
actual_model_path
self
.
_appendLog
(
f
" 使用模型文件: {os.path.basename(model_path)}
\n
"
)
# 获取模型文件信息
model_file
=
Path
(
model_path
)
if
os
.
path
.
exists
(
model_path
):
file_size_bytes
=
model_file
.
stat
()
.
st_size
file_size_mb
=
round
(
file_size_bytes
/
(
1024
*
1024
),
1
)
mod_time
=
time
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
,
time
.
localtime
(
model_file
.
stat
()
.
st_mtime
))
else
:
file_size_mb
=
"未知"
mod_time
=
self
.
_getCurrentTimestamp
()
# 确定模型类型(现在应该都是.dat格式)
model_type
=
".dat"
# 生成模型描述内容
description_content
=
f
"""【模型信息】{model_id}/{model_file.stem}
基础信息
------------------------------
模型名称: {model_id}/{model_file.stem}
模型类型: {model_type}
模型路径: {model_path}
文件大小: {file_size_mb} MB
最后修改: {mod_time}
模型说明
------------------------------
这是一个液位检测专用模型,用于识别和检测液位线位置。
模型采用深度学习技术,能够准确识别各种容器中的液位状态。
本模型通过训练升级生成,具有更好的检测精度和稳定性。
技术特性
------------------------------
- 支持多种液体类型检测
- 高精度液位线定位
- 实时检测能力
- 适用于工业自动化场景
- 经过训练优化的检测算法
通用模型特性
------------------------------
- 基于深度学习技术
- 专门针对液位检测优化
- 支持实时处理
- 适用于工业环境
使用说明
------------------------------
1. 模型已经过训练升级,可直接用于液位检测任务
2. 支持图像输入,输出液位线坐标和置信度
3. 建议在良好光照条件下使用以获得最佳效果
4. 如需进一步提升特定场景的检测效果,可继续进行模型微调训练
注意事项
------------------------------
- 确保输入图像清晰度足够
- 避免强反光或阴影干扰
- 定期验证模型检测精度
- 如发现检测效果下降,建议重新训练或更新模型
- 新训练的模型建议先在测试环境中验证效果
训练信息
------------------------------
训练时间: {self._getCurrentTimestamp()}
模型版本: {model_id}
训练状态: 已完成
备注: 通过系统自动训练升级生成"""
# 保存模型描述文件到training_results目录
training_results_dir
=
os
.
path
.
join
(
target_model_dir
,
'training_results'
)
os
.
makedirs
(
training_results_dir
,
exist_ok
=
True
)
description_file
=
os
.
path
.
join
(
training_results_dir
,
'模型描述.txt'
)
with
open
(
description_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
f
.
write
(
description_content
)
self
.
_appendLog
(
f
"✅ 模型描述文件已生成: {description_file}
\n
"
)
self
.
_appendLog
(
f
" 文件包含完整的模型信息和使用说明
\n
"
)
return
description_file
except
Exception
as
e
:
self
.
_appendLog
(
f
"❌ 生成模型描述文件失败: {str(e)}
\n
"
)
import
traceback
traceback
.
print_exc
()
return
None
handlers/modelpage/model_trainingworker_handler.py
View file @
7e155e88
...
...
@@ -709,26 +709,38 @@ class TrainingWorker(QThread):
if
device_str
.
lower
()
in
[
'cpu'
,
'-1'
]:
workers
=
0
# CPU模式下禁用多线程数据加载
# 获取下一个可用的模型ID并创建目录
# 获取下一个可用的模型ID并创建目录
(直接保存到detection_model)
project_root
=
get_project_root
()
train_model_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'trai
n_model'
)
os
.
makedirs
(
trai
n_model_dir
,
exist_ok
=
True
)
detection_model_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'detectio
n_model'
)
os
.
makedirs
(
detectio
n_model_dir
,
exist_ok
=
True
)
# 查找下一个可用的数字ID
existing_dirs
=
[]
for
item
in
os
.
listdir
(
trai
n_model_dir
):
item_path
=
os
.
path
.
join
(
trai
n_model_dir
,
item
)
for
item
in
os
.
listdir
(
detectio
n_model_dir
):
item_path
=
os
.
path
.
join
(
detectio
n_model_dir
,
item
)
if
os
.
path
.
isdir
(
item_path
)
and
item
.
isdigit
():
existing_dirs
.
append
(
int
(
item
))
next_id
=
max
(
existing_dirs
)
+
1
if
existing_dirs
else
1
model_output_dir
=
os
.
path
.
join
(
train_model_dir
,
str
(
next_id
))
# 预先设置weights_dir(YOLO会在project目录下创建weights子目录)
weights_dir
=
os
.
path
.
join
(
model_output_dir
,
"weights"
)
model_output_dir
=
os
.
path
.
join
(
detection_model_dir
,
str
(
next_id
))
# 创建新的目录结构
os
.
makedirs
(
model_output_dir
,
exist_ok
=
True
)
training_results_dir
=
os
.
path
.
join
(
model_output_dir
,
"training_results"
)
test_results_dir
=
os
.
path
.
join
(
model_output_dir
,
"test_results"
)
os
.
makedirs
(
training_results_dir
,
exist_ok
=
True
)
os
.
makedirs
(
test_results_dir
,
exist_ok
=
True
)
# YOLO训练时使用临时目录,训练完成后移动文件
temp_training_dir
=
os
.
path
.
join
(
model_output_dir
,
"temp_training"
)
weights_dir
=
os
.
path
.
join
(
temp_training_dir
,
"weights"
)
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
training_report
[
"model_output_dir"
]
=
model_output_dir
self
.
training_report
[
"training_results_dir"
]
=
training_results_dir
self
.
log_output
.
emit
(
f
"模型将直接保存到: {model_output_dir}
\n
"
)
self
.
log_output
.
emit
(
f
"模型将保存到: {model_output_dir}
\n
"
)
self
.
log_output
.
emit
(
f
"训练结果将保存到: {training_results_dir}
\n
"
)
self
.
log_output
.
emit
(
f
"测试结果将保存到: {test_results_dir}
\n
"
)
# 开始训练
try
:
...
...
@@ -742,7 +754,7 @@ class TrainingWorker(QThread):
optimizer
=
self
.
training_params
[
'optimizer'
],
close_mosaic
=
self
.
training_params
[
'close_mosaic'
],
resume
=
self
.
training_params
[
'resume'
],
project
=
model_output
_dir
,
project
=
temp_training
_dir
,
name
=
''
,
# 空名称,直接使用project目录
single_cls
=
self
.
training_params
[
'single_cls'
],
cache
=
False
,
...
...
@@ -772,36 +784,26 @@ class TrainingWorker(QThread):
if
save_dir
:
save_dir_abs
=
os
.
path
.
abspath
(
str
(
save_dir
))
weights_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
save_dir_abs
,
"weights"
))
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 实际保存目录: {save_dir_abs}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] 实际权重目录: {weights_dir}
\n
"
)
# 立即转换PT文件为DAT格式并删除PT文件
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
# 整理训练结果到新的目录结构
model_output_dir
=
self
.
training_report
.
get
(
"model_output_dir"
)
training_results_dir
=
self
.
training_report
.
get
(
"training_results_dir"
)
if
model_output_dir
and
training_results_dir
:
self
.
_organizeTrainingResults
(
save_dir_abs
,
model_output_dir
,
training_results_dir
)
else
:
# 备用方案:使用预设的
model_output
_dir
# 备用方案:使用预设的
temp_training
_dir
self
.
log_output
.
emit
(
"
\n
[WARNING] 无法从trainer获取保存目录,使用预设目录
\n
"
)
if
'model_output_dir'
in
locals
():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs
=
[
os
.
path
.
join
(
model_output_dir
,
"train"
,
"weights"
),
os
.
path
.
join
(
model_output_dir
,
"weights"
)
]
for
possible_dir
in
possible_weights_dirs
:
if
os
.
path
.
exists
(
possible_dir
):
weights_dir
=
possible_dir
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 找到权重目录: {weights_dir}
\n
"
)
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
break
model_output_dir
=
self
.
training_report
.
get
(
"model_output_dir"
)
training_results_dir
=
self
.
training_report
.
get
(
"training_results_dir"
)
if
model_output_dir
and
training_results_dir
and
'temp_training_dir'
in
locals
():
if
os
.
path
.
exists
(
temp_training_dir
):
self
.
_organizeTrainingResults
(
temp_training_dir
,
model_output_dir
,
training_results_dir
)
else
:
self
.
log_output
.
emit
(
f
"[ERROR]
未找到权重目录,跳过转换
\n
"
)
self
.
log_output
.
emit
(
f
"[ERROR]
临时训练目录不存在: {temp_training_dir}
\n
"
)
else
:
self
.
log_output
.
emit
(
f
"[ERROR]
model_output_dir 未定义,跳过转换
\n
"
)
self
.
log_output
.
emit
(
f
"[ERROR]
必要的目录信息未定义,跳过整理
\n
"
)
except
Exception
as
convert_err
:
self
.
log_output
.
emit
(
f
"
\n
[ERROR] 转换过程出错: {convert_err}
\n
"
)
import
traceback
...
...
@@ -840,36 +842,26 @@ class TrainingWorker(QThread):
if
save_dir
:
save_dir_abs
=
os
.
path
.
abspath
(
str
(
save_dir
))
weights_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
save_dir_abs
,
"weights"
))
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 实际保存目录: {save_dir_abs}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] 实际权重目录: {weights_dir}
\n
"
)
# 立即转换PT文件为DAT格式并删除PT文件
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
# 整理训练结果到新的目录结构
model_output_dir
=
self
.
training_report
.
get
(
"model_output_dir"
)
training_results_dir
=
self
.
training_report
.
get
(
"training_results_dir"
)
if
model_output_dir
and
training_results_dir
:
self
.
_organizeTrainingResults
(
save_dir_abs
,
model_output_dir
,
training_results_dir
)
else
:
# 备用方案:使用预设的
model_output
_dir
# 备用方案:使用预设的
temp_training
_dir
self
.
log_output
.
emit
(
"
\n
[WARNING] 无法从trainer获取保存目录,使用预设目录
\n
"
)
if
'model_output_dir'
in
locals
():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs
=
[
os
.
path
.
join
(
model_output_dir
,
"train"
,
"weights"
),
os
.
path
.
join
(
model_output_dir
,
"weights"
)
]
for
possible_dir
in
possible_weights_dirs
:
if
os
.
path
.
exists
(
possible_dir
):
weights_dir
=
possible_dir
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 找到权重目录: {weights_dir}
\n
"
)
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
break
model_output_dir
=
self
.
training_report
.
get
(
"model_output_dir"
)
training_results_dir
=
self
.
training_report
.
get
(
"training_results_dir"
)
if
model_output_dir
and
training_results_dir
and
'temp_training_dir'
in
locals
():
if
os
.
path
.
exists
(
temp_training_dir
):
self
.
_organizeTrainingResults
(
temp_training_dir
,
model_output_dir
,
training_results_dir
)
else
:
self
.
log_output
.
emit
(
f
"[ERROR]
未找到权重目录,跳过转换
\n
"
)
self
.
log_output
.
emit
(
f
"[ERROR]
临时训练目录不存在: {temp_training_dir}
\n
"
)
else
:
self
.
log_output
.
emit
(
f
"[ERROR]
model_output_dir 未定义,跳过转换
\n
"
)
self
.
log_output
.
emit
(
f
"[ERROR]
必要的目录信息未定义,跳过整理
\n
"
)
except
Exception
as
convert_err
:
self
.
log_output
.
emit
(
f
"
\n
[ERROR] 转换过程出错: {convert_err}
\n
"
)
import
traceback
...
...
@@ -885,15 +877,14 @@ class TrainingWorker(QThread):
save_dir
=
getattr
(
getattr
(
model
,
"trainer"
,
None
),
"save_dir"
,
None
)
if
save_dir
:
save_dir_abs
=
os
.
path
.
abspath
(
str
(
save_dir
))
weights_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
save_dir_abs
,
"weights"
))
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"
\n
[调试] 实际保存目录: {save_dir_abs}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] 实际权重目录: {weights_dir}
\n
"
)
# 立即转换PT文件为DAT格式并删除PT文件
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
# 整理训练结果到新的目录结构
model_output_dir
=
self
.
training_report
.
get
(
"model_output_dir"
)
training_results_dir
=
self
.
training_report
.
get
(
"training_results_dir"
)
if
model_output_dir
and
training_results_dir
:
self
.
_organizeTrainingResults
(
save_dir_abs
,
model_output_dir
,
training_results_dir
)
else
:
self
.
log_output
.
emit
(
"
\n
[WARNING] 无法获取模型保存目录,跳过转换
\n
"
)
except
Exception
as
convert_err
:
...
...
@@ -1271,3 +1262,89 @@ class TrainingWorker(QThread):
def
has_training_started
(
self
):
"""检查训练是否已经真正开始"""
return
self
.
training_actually_started
def
_organizeTrainingResults
(
self
,
temp_training_dir
,
model_output_dir
,
training_results_dir
):
"""
整理训练结果到新的目录结构
Args:
temp_training_dir: 临时训练目录
model_output_dir: 模型输出根目录
training_results_dir: 训练结果目录
"""
try
:
import
shutil
self
.
log_output
.
emit
(
"
\n
正在整理训练结果文件...
\n
"
)
# 1. 处理权重文件 - 转换并移动到根目录
weights_dir
=
os
.
path
.
join
(
temp_training_dir
,
"weights"
)
if
os
.
path
.
exists
(
weights_dir
):
self
.
log_output
.
emit
(
"正在转换并移动模型文件...
\n
"
)
# 转换PT文件为DAT格式
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
# 移动DAT文件到根目录
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
endswith
(
'.dat'
):
src_path
=
os
.
path
.
join
(
weights_dir
,
filename
)
dst_path
=
os
.
path
.
join
(
model_output_dir
,
filename
)
shutil
.
move
(
src_path
,
dst_path
)
self
.
log_output
.
emit
(
f
" 移动模型文件: {filename}
\n
"
)
# 2. 移动训练结果文件到training_results目录
training_files
=
[
'results.csv'
,
# 训练结果
'args.yaml'
,
# 训练参数
'config.yaml'
,
# 配置文件
'labels.jpg'
,
# 标签分布图
'labels_correlogram.jpg'
,
# 标签相关图
'train_batch*.jpg'
,
# 训练批次图片
'val_batch*.jpg'
,
# 验证批次图片
'confusion_matrix.png'
,
# 混淆矩阵
'F1_curve.png'
,
# F1曲线
'P_curve.png'
,
# 精确率曲线
'R_curve.png'
,
# 召回率曲线
'PR_curve.png'
,
# PR曲线
'results.png'
# 结果图表
]
# 移动单个文件
for
pattern
in
training_files
:
if
'*'
in
pattern
:
# 处理通配符模式
import
glob
base_pattern
=
pattern
.
replace
(
'*'
,
'*'
)
matches
=
glob
.
glob
(
os
.
path
.
join
(
temp_training_dir
,
base_pattern
))
for
match
in
matches
:
if
os
.
path
.
isfile
(
match
):
filename
=
os
.
path
.
basename
(
match
)
dst_path
=
os
.
path
.
join
(
training_results_dir
,
filename
)
shutil
.
move
(
match
,
dst_path
)
self
.
log_output
.
emit
(
f
" 移动训练文件: {filename}
\n
"
)
else
:
src_path
=
os
.
path
.
join
(
temp_training_dir
,
pattern
)
if
os
.
path
.
exists
(
src_path
):
dst_path
=
os
.
path
.
join
(
training_results_dir
,
pattern
)
shutil
.
move
(
src_path
,
dst_path
)
self
.
log_output
.
emit
(
f
" 移动训练文件: {pattern}
\n
"
)
# 3. 移动plots目录(如果存在)
plots_src
=
os
.
path
.
join
(
temp_training_dir
,
"plots"
)
if
os
.
path
.
exists
(
plots_src
):
plots_dst
=
os
.
path
.
join
(
training_results_dir
,
"plots"
)
shutil
.
move
(
plots_src
,
plots_dst
)
self
.
log_output
.
emit
(
" 移动训练图表目录: plots/
\n
"
)
# 4. 清理临时目录
if
os
.
path
.
exists
(
temp_training_dir
):
shutil
.
rmtree
(
temp_training_dir
)
self
.
log_output
.
emit
(
" 清理临时目录
\n
"
)
self
.
log_output
.
emit
(
"训练结果整理完成!
\n
"
)
except
Exception
as
e
:
self
.
log_output
.
emit
(
f
"[错误] 整理训练结果失败: {e}
\n
"
)
import
traceback
traceback
.
print_exc
()
widgets/modelpage/modelset_page.py
View file @
7e155e88
...
...
@@ -120,7 +120,7 @@ class ModelSetPage(QtWidgets.QWidget):
# 标题
title
=
QtWidgets
.
QLabel
(
"模型集管理"
)
title
.
setStyleSheet
(
"font-size: 13pt; font-weight: bold;"
)
FontManager
.
applyToWidget
(
title
,
weight
=
FontManager
.
WEIGHT_BOLD
)
# 使用全局默认字体,加粗
header_layout
.
addWidget
(
title
)
header_layout
.
addStretch
()
...
...
@@ -246,10 +246,26 @@ class ModelSetPage(QtWidgets.QWidget):
content_layout
.
addWidget
(
left_widget
)
# ========== 中间主内容区(左右分栏:模型列表 + 文本显示) ==========
center_widget
=
QtWidgets
.
QWidget
()
center_main_layout
=
QtWidgets
.
QHBoxLayout
(
center_widget
)
center_main_layout
.
setContentsMargins
(
0
,
0
,
0
,
0
)
center_main_layout
.
setSpacing
(
8
)
# 使用QSplitter实现可缩放的分隔符
center_splitter
=
QtWidgets
.
QSplitter
(
Qt
.
Horizontal
)
center_splitter
.
setChildrenCollapsible
(
False
)
# 防止面板被完全折叠
# 设置splitter样式
center_splitter
.
setStyleSheet
(
"""
QSplitter::handle {
background-color: #e0e0e0;
border: 1px solid #c0c0c0;
width: 6px;
margin: 2px;
border-radius: 3px;
}
QSplitter::handle:hover {
background-color: #d0d0d0;
}
QSplitter::handle:pressed {
background-color: #b0b0b0;
}
"""
)
# ===== 左侧:模型列表区域 =====
left_list_widget
=
QtWidgets
.
QWidget
()
...
...
@@ -338,17 +354,22 @@ class ModelSetPage(QtWidgets.QWidget):
"""
)
self
.
model_info_text
.
setPlaceholderText
(
"选择模型后将显示模型文件夹内的txt文件内容..."
)
# 应用全局字体管理
FontManager
.
applyToWidget
(
self
.
model_info_text
,
size
=
10
)
FontManager
.
applyToWidget
(
self
.
model_info_text
)
# 使用全局默认字体
right_text_layout
.
addWidget
(
self
.
model_info_text
)
# 将左右两部分添加到中间主布局(1:1比例)
center_main_layout
.
addWidget
(
left_list_widget
,
1
)
center_main_layout
.
addWidget
(
right_text_widget
,
1
)
# 将左右两部分添加到splitter
center_splitter
.
addWidget
(
left_list_widget
)
center_splitter
.
addWidget
(
right_text_widget
)
# 设置初始比例(1:1)
center_splitter
.
setSizes
([
400
,
400
])
center_splitter
.
setStretchFactor
(
0
,
1
)
# 左侧可伸缩
center_splitter
.
setStretchFactor
(
1
,
1
)
# 右侧可伸缩
# 🔥 设置中间区域的宽度(主要内容区)- 响应式布局
center_
widget
.
setMinimumWidth
(
scale_w
(
800
))
center_
splitter
.
setMinimumWidth
(
scale_w
(
800
))
content_layout
.
addWidget
(
center_
widget
,
3
)
content_layout
.
addWidget
(
center_
splitter
,
3
)
main_layout
.
addLayout
(
content_layout
)
...
...
@@ -558,7 +579,7 @@ class ModelSetPage(QtWidgets.QWidget):
if
'预训练模型'
not
in
model_type
and
'YOLO'
not
in
model_type
:
model_type
=
"PyTorch (YOLOv8)"
elif
model_file
.
endswith
(
'.dat'
):
model_type
=
"
加密模型 (.dat)
"
model_type
=
"
.dat
"
elif
model_file
.
endswith
(
'.onnx'
):
model_type
=
"ONNX"
...
...
@@ -1216,7 +1237,6 @@ class ModelSetPage(QtWidgets.QWidget):
border-radius: 4px;
}
QLabel {
font-size: 10pt;
padding: 3px;
}
"""
)
...
...
@@ -1294,6 +1314,9 @@ class ModelSetPage(QtWidgets.QWidget):
layout
.
addLayout
(
button_layout
)
# 应用全局字体管理器到对话框及其所有子控件
FontManager
.
applyToWidgetRecursive
(
dialog
)
# 显示对话框
dialog
.
exec_
()
...
...
@@ -1381,8 +1404,8 @@ class ModelSetPage(QtWidgets.QWidget):
# 2. 从配置文件提取通道模型
channel_models
=
self
.
_extractChannelModels
(
config
)
# 3. 扫描模型目录
scanned_models
=
self
.
_scanModelDirectory
()
# 3. 扫描模型目录
(使用统一的getDetectionModels方法确保名称一致)
scanned_models
=
self
.
getDetectionModels
()
# 4. 合并所有模型信息
all_models
=
self
.
_mergeModelInfo
(
channel_models
,
scanned_models
)
...
...
@@ -1469,8 +1492,11 @@ class ModelSetPage(QtWidgets.QWidget):
# 3. 检查模型路径是否存在
if
model_path
and
os
.
path
.
exists
(
model_path
):
# 使用统一的命名逻辑:从模型路径推导出标准名称
model_name
=
self
.
_getModelNameFromPath
(
model_path
)
models
.
append
({
'name'
:
f
"{channel_name}模型"
,
'name'
:
model_name
,
# 使用统一的模型名称
'path'
:
model_path
,
'channel'
:
channel_key
,
'channel_name'
:
channel_name
,
...
...
@@ -1480,19 +1506,64 @@ class ModelSetPage(QtWidgets.QWidget):
return
models
def
_getModelNameFromPath
(
self
,
model_path
):
"""从模型路径推导出统一的模型名称"""
try
:
from
pathlib
import
Path
path
=
Path
(
model_path
)
# 检查是否在detection_model目录下
if
'detection_model'
in
path
.
parts
:
# 找到detection_model目录的索引
parts
=
path
.
parts
detection_index
=
-
1
for
i
,
part
in
enumerate
(
parts
):
if
part
==
'detection_model'
:
detection_index
=
i
break
if
detection_index
>=
0
and
detection_index
+
1
<
len
(
parts
):
# 获取模型ID目录名
model_id
=
parts
[
detection_index
+
1
]
# 尝试读取config.yaml获取模型名称
model_dir
=
path
.
parent
config_locations
=
[
model_dir
/
"training_results"
/
"config.yaml"
,
model_dir
/
"config.yaml"
]
for
config_file
in
config_locations
:
if
config_file
.
exists
():
try
:
import
yaml
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
config_data
=
yaml
.
safe_load
(
f
)
return
config_data
.
get
(
'model_name'
,
f
"模型_{model_id}"
)
except
:
continue
# 如果没有配置文件,使用默认格式
return
f
"模型_{model_id}"
# 如果不在detection_model目录下,使用文件名
return
path
.
stem
except
Exception
:
# 出错时返回文件名
return
Path
(
model_path
)
.
stem
def
_scanModelDirectory
(
self
):
"""扫描模型目录获取所有模型文件(
优先detection_model
)"""
"""扫描模型目录获取所有模型文件(
只从detection_model加载
)"""
models
=
[]
try
:
# 获取模型目录路径
current_dir
=
Path
(
__file__
)
.
parent
.
parent
.
parent
#
扫描多个模型目录(优先detection_model)
#
只扫描detection_model目录,确保数据一致性
model_dirs
=
[
(
current_dir
/
"database"
/
"model"
/
"detection_model"
,
"检测模型"
,
True
),
# 优先级最高
(
current_dir
/
"database"
/
"model"
/
"train_model"
,
"训练模型"
,
False
),
(
current_dir
/
"database"
/
"model"
/
"test_model"
,
"测试模型"
,
False
)
(
current_dir
/
"database"
/
"model"
/
"detection_model"
,
"检测模型"
,
True
),
# 唯一数据源
]
for
model_dir
,
dir_type
,
is_primary
in
model_dirs
:
...
...
@@ -1604,6 +1675,104 @@ class ModelSetPage(QtWidgets.QWidget):
return
models
@staticmethod
def
getDetectionModels
():
"""静态方法:获取detection_model目录下的所有模型,供其他页面使用"""
models
=
[]
try
:
# 获取模型目录路径
current_dir
=
Path
(
__file__
)
.
parent
.
parent
.
parent
detection_model_dir
=
current_dir
/
"database"
/
"model"
/
"detection_model"
if
not
detection_model_dir
.
exists
():
return
models
# 遍历所有子目录,包含数字和非数字目录
all_subdirs
=
[
d
for
d
in
detection_model_dir
.
iterdir
()
if
d
.
is_dir
()]
# 分离数字目录和非数字目录
digit_subdirs
=
[
d
for
d
in
all_subdirs
if
d
.
name
.
isdigit
()]
non_digit_subdirs
=
[
d
for
d
in
all_subdirs
if
not
d
.
name
.
isdigit
()]
# 数字目录按数字降序排列,非数字目录按字母排序
sorted_digit_subdirs
=
sorted
(
digit_subdirs
,
key
=
lambda
x
:
int
(
x
.
name
),
reverse
=
True
)
sorted_non_digit_subdirs
=
sorted
(
non_digit_subdirs
,
key
=
lambda
x
:
x
.
name
)
# 合并:数字目录在前(最新的在前),非数字目录在后
sorted_subdirs
=
sorted_digit_subdirs
+
sorted_non_digit_subdirs
for
subdir
in
sorted_subdirs
:
# 新结构:.dat文件直接在根目录下
# 查找.dat文件(优先在根目录,兼容旧的weights子目录)
dat_files
=
list
(
subdir
.
glob
(
"*.dat"
))
# 如果根目录没有.dat文件,检查weights子目录(兼容旧结构)
if
not
dat_files
:
weights_dir
=
subdir
/
"weights"
if
weights_dir
.
exists
():
dat_files
=
list
(
weights_dir
.
glob
(
"*.dat"
))
if
not
dat_files
:
continue
# 优先选择best.dat,然后是其他.dat文件
model_file
=
None
for
dat_file
in
dat_files
:
if
dat_file
.
name
.
startswith
(
'best.'
):
model_file
=
dat_file
break
if
not
model_file
:
model_file
=
dat_files
[
0
]
# 尝试读取config.yaml获取模型名称
# 新结构:config.yaml在training_results子目录中
model_name
=
None
config_locations
=
[
subdir
/
"training_results"
/
"config.yaml"
,
# 新结构
subdir
/
"config.yaml"
# 兼容旧结构
]
for
config_file
in
config_locations
:
if
config_file
.
exists
():
try
:
import
yaml
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
config_data
=
yaml
.
safe_load
(
f
)
model_name
=
config_data
.
get
(
'model_name'
,
f
"模型_{subdir.name}"
)
break
except
:
continue
if
not
model_name
:
model_name
=
f
"模型_{subdir.name}"
# 获取文件大小
file_size
=
"未知"
try
:
size_bytes
=
model_file
.
stat
()
.
st_size
if
size_bytes
<
1024
*
1024
:
file_size
=
f
"{size_bytes / 1024:.1f} KB"
else
:
file_size
=
f
"{size_bytes / (1024 * 1024):.1f} MB"
except
:
pass
models
.
append
({
'name'
:
model_name
,
'path'
:
str
(
model_file
),
'size'
:
file_size
,
'type'
:
'.dat'
,
'source'
:
'检测模型'
,
'model_id'
:
subdir
.
name
,
'is_primary'
:
True
})
except
Exception
as
e
:
print
(
f
"[错误] 获取detection_model模型失败: {e}"
)
return
models
def
_mergeModelInfo
(
self
,
channel_models
,
scanned_models
):
"""合并模型信息,避免重复(优先detection_model)"""
all_models
=
[]
...
...
@@ -1669,7 +1838,7 @@ class ModelSetPage(QtWidgets.QWidget):
if
model_path
.
endswith
(
'.pt'
):
model_type
=
"PyTorch (YOLOv5/v8)"
elif
model_path
.
endswith
(
'.dat'
):
model_type
=
"
加密模型 (.dat)
"
model_type
=
"
.dat
"
else
:
model_type
=
"未知格式"
...
...
@@ -1793,7 +1962,7 @@ class ModelSetPage(QtWidgets.QWidget):
return
False
def
_loadModelTxtFiles
(
self
,
model_name
):
"""
读取并显示模型文件夹内的txt文件
"""
"""
显示模型最近一次升级的参数和训练效果信息
"""
try
:
# 清空文本显示区域
self
.
model_info_text
.
clear
()
...
...
@@ -1817,38 +1986,139 @@ class ModelSetPage(QtWidgets.QWidget):
self
.
model_info_text
.
setPlainText
(
f
"模型目录不存在:
\n
{model_dir}"
)
return
#
查找目录中的所有txt文件
txt_files
=
list
(
model_dir
.
glob
(
"*.txt"
))
#
构建模型训练信息
content_parts
=
[]
if
not
txt_files
:
self
.
model_info_text
.
setPlainText
(
f
"模型目录中没有找到txt文件:
\n
{model_dir}"
)
return
# 首先检查是否有txt文件,如果有则优先显示txt文件内容
txt_files
=
[]
# 读取并显示所有txt文件的内容
content_parts
=
[]
content_parts
.
append
(
f
"模型目录: {model_dir}
\n
"
)
content_parts
.
append
(
"="
*
60
+
"
\n\n
"
)
# 新结构:搜索training_results目录的txt文件
training_results_dir
=
model_dir
/
"training_results"
if
training_results_dir
.
exists
():
training_results_files
=
list
(
training_results_dir
.
glob
(
"*.txt"
))
txt_files
.
extend
(
training_results_files
)
for
txt_file
in
sorted
(
txt_files
):
content_parts
.
append
(
f
"【文件: {txt_file.name}】
\n
"
)
content_parts
.
append
(
"-"
*
60
+
"
\n
"
)
# 兼容旧结构:搜索模型同级目录的txt文件
current_dir_files
=
list
(
model_dir
.
glob
(
"*.txt"
)
)
txt_files
.
extend
(
current_dir_files
)
# 兼容旧结构:搜索上一级目录的txt文件
parent_dir
=
model_dir
.
parent
if
parent_dir
.
exists
():
parent_dir_files
=
list
(
parent_dir
.
glob
(
"*.txt"
))
txt_files
.
extend
(
parent_dir_files
)
# 如果有txt文件,直接显示txt文件内容(优先显示模型描述文件)
if
txt_files
:
# 优先显示模型描述文件
description_files
=
[
f
for
f
in
txt_files
if
'模型描述'
in
f
.
name
]
other_files
=
[
f
for
f
in
txt_files
if
'模型描述'
not
in
f
.
name
]
# 按优先级排序:模型描述文件在前
sorted_files
=
description_files
+
sorted
(
other_files
)
for
txt_file
in
sorted_files
:
try
:
# 尝试使用UTF-8编码读取
with
open
(
txt_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
file_content
=
f
.
read
()
except
UnicodeDecodeError
:
# 如果UTF-8失败,尝试GBK编码
file_content
=
f
.
read
()
.
strip
()
if
file_content
:
if
txt_file
.
parent
==
model_dir
:
file_label
=
txt_file
.
name
else
:
file_label
=
f
"{txt_file.parent.name}/{txt_file.name}"
# 如果是模型描述文件,直接显示内容,不加标题
if
'模型描述'
in
txt_file
.
name
:
content_parts
.
append
(
f
"{file_content}
\n\n
"
)
else
:
content_parts
.
append
(
f
"=== {file_label} ===
\n
"
)
content_parts
.
append
(
f
"{file_content}
\n\n
"
)
except
:
try
:
with
open
(
txt_file
,
'r'
,
encoding
=
'gbk'
)
as
f
:
file_content
=
f
.
read
()
except
Exception
as
e
:
file_content
=
f
"无法读取文件(编码错误): {str(e)}"
except
Exception
as
e
:
file_content
=
f
"读取文件时出错: {str(e)}"
file_content
=
f
.
read
()
.
strip
()
if
file_content
:
if
txt_file
.
parent
==
model_dir
:
file_label
=
txt_file
.
name
else
:
file_label
=
f
"{txt_file.parent.name}/{txt_file.name}"
content_parts
.
append
(
file_content
)
content_parts
.
append
(
"
\n\n
"
+
"="
*
60
+
"
\n\n
"
)
# 如果是模型描述文件,直接显示内容,不加标题
if
'模型描述'
in
txt_file
.
name
:
content_parts
.
append
(
f
"{file_content}
\n\n
"
)
else
:
content_parts
.
append
(
f
"=== {file_label} ===
\n
"
)
content_parts
.
append
(
f
"{file_content}
\n\n
"
)
except
:
pass
# 如果找到了txt文件,直接显示,不再添加其他自动生成的信息
full_content
=
""
.
join
(
content_parts
)
self
.
model_info_text
.
setPlainText
(
full_content
)
# 滚动到顶部
cursor
=
self
.
model_info_text
.
textCursor
()
cursor
.
movePosition
(
QtGui
.
QTextCursor
.
Start
)
self
.
model_info_text
.
setTextCursor
(
cursor
)
return
else
:
# 如果没有txt文件,则显示基础信息
content_parts
.
append
(
f
"【模型升级信息】 - {model_name}
\n
"
)
content_parts
.
append
(
"="
*
60
+
"
\n\n
"
)
content_parts
.
append
(
"基础信息
\n
"
)
content_parts
.
append
(
"-"
*
30
+
"
\n
"
)
content_parts
.
append
(
f
"模型名称: {model_name}
\n
"
)
content_parts
.
append
(
f
"模型类型: {model_info.get('type', '未知')}
\n
"
)
content_parts
.
append
(
f
"模型路径: {model_path}
\n
"
)
content_parts
.
append
(
f
"文件大小: {model_info.get('size', '未知')}
\n
"
)
# 获取文件修改时间
try
:
if
os
.
path
.
exists
(
model_path
):
import
time
mtime
=
os
.
path
.
getmtime
(
model_path
)
mod_time
=
time
.
strftime
(
'
%
Y-
%
m-
%
d
%
H:
%
M:
%
S'
,
time
.
localtime
(
mtime
))
content_parts
.
append
(
f
"最后修改: {mod_time}
\n
"
)
except
:
pass
content_parts
.
append
(
"
\n
"
)
# 2. 训练参数信息
training_info
=
self
.
_getTrainingParameters
(
model_dir
)
if
training_info
:
content_parts
.
append
(
"训练参数
\n
"
)
content_parts
.
append
(
"-"
*
30
+
"
\n
"
)
for
key
,
value
in
training_info
.
items
():
content_parts
.
append
(
f
"{key}: {value}
\n
"
)
content_parts
.
append
(
"
\n
"
)
# 3. 训练效果信息
results_info
=
self
.
_getTrainingResults
(
model_dir
)
if
results_info
:
content_parts
.
append
(
"训练效果
\n
"
)
content_parts
.
append
(
"-"
*
30
+
"
\n
"
)
for
key
,
value
in
results_info
.
items
():
content_parts
.
append
(
f
"{key}: {value}
\n
"
)
content_parts
.
append
(
"
\n
"
)
# 4. 模型评估
evaluation
=
self
.
_evaluateModelPerformance
(
results_info
)
if
evaluation
:
content_parts
.
append
(
"效果评估
\n
"
)
content_parts
.
append
(
"-"
*
30
+
"
\n
"
)
content_parts
.
append
(
f
"{evaluation}
\n\n
"
)
# 如果没有找到训练信息,显示基础信息
if
not
training_info
and
not
results_info
:
content_parts
.
append
(
"说明
\n
"
)
content_parts
.
append
(
"-"
*
30
+
"
\n
"
)
content_parts
.
append
(
"未找到详细的训练记录文件。
\n
"
)
content_parts
.
append
(
"这可能是预训练模型或外部导入的模型。
\n\n
"
)
content_parts
.
append
(
"模型仍可正常使用,如需查看训练效果,
\n
"
)
content_parts
.
append
(
"建议重新进行模型升级训练。
\n
"
)
# 显示所有内容
full_content
=
""
.
join
(
content_parts
)
...
...
@@ -1860,7 +2130,199 @@ class ModelSetPage(QtWidgets.QWidget):
self
.
model_info_text
.
setTextCursor
(
cursor
)
except
Exception
as
e
:
self
.
model_info_text
.
setPlainText
(
f
"读取txt文件时出错:
\n
{str(e)}"
)
self
.
model_info_text
.
setPlainText
(
f
"读取模型信息时出错:
\n
{str(e)}"
)
def
_getTrainingParameters
(
self
,
model_dir
):
"""获取训练参数信息"""
training_params
=
{}
try
:
# 查找 args.yaml 文件(YOLO训练参数)
args_file
=
model_dir
/
"args.yaml"
if
args_file
.
exists
():
with
open
(
args_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
args_data
=
yaml
.
safe_load
(
f
)
if
args_data
:
training_params
[
"训练轮数"
]
=
args_data
.
get
(
'epochs'
,
'未知'
)
training_params
[
"批次大小"
]
=
args_data
.
get
(
'batch'
,
'未知'
)
training_params
[
"图像尺寸"
]
=
args_data
.
get
(
'imgsz'
,
'未知'
)
training_params
[
"学习率"
]
=
args_data
.
get
(
'lr0'
,
'未知'
)
training_params
[
"优化器"
]
=
args_data
.
get
(
'optimizer'
,
'未知'
)
training_params
[
"数据集"
]
=
args_data
.
get
(
'data'
,
'未知'
)
training_params
[
"设备"
]
=
args_data
.
get
(
'device'
,
'未知'
)
training_params
[
"工作进程"
]
=
args_data
.
get
(
'workers'
,
'未知'
)
# 查找其他可能的配置文件
config_files
=
list
(
model_dir
.
glob
(
"*config*.yaml"
))
+
list
(
model_dir
.
glob
(
"*config*.yml"
))
for
config_file
in
config_files
:
if
config_file
.
name
!=
"args.yaml"
:
try
:
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
config_data
=
yaml
.
safe_load
(
f
)
if
config_data
and
isinstance
(
config_data
,
dict
):
# 提取一些关键参数
if
'epochs'
in
config_data
:
training_params
[
"训练轮数"
]
=
config_data
[
'epochs'
]
if
'batch_size'
in
config_data
:
training_params
[
"批次大小"
]
=
config_data
[
'batch_size'
]
except
:
continue
except
Exception
as
e
:
pass
return
training_params
def
_getTrainingResults
(
self
,
model_dir
):
"""获取训练结果信息"""
results
=
{}
try
:
# 查找 results.csv 文件(YOLO训练结果)
results_file
=
model_dir
/
"results.csv"
if
results_file
.
exists
():
import
pandas
as
pd
try
:
df
=
pd
.
read_csv
(
results_file
)
if
not
df
.
empty
:
# 获取最后一行的结果(最终训练结果)
last_row
=
df
.
iloc
[
-
1
]
# 提取关键指标
if
'metrics/precision(B)'
in
df
.
columns
:
results
[
"精确率"
]
=
f
"{last_row['metrics/precision(B)']:.4f}"
elif
'precision'
in
df
.
columns
:
results
[
"精确率"
]
=
f
"{last_row['precision']:.4f}"
if
'metrics/recall(B)'
in
df
.
columns
:
results
[
"召回率"
]
=
f
"{last_row['metrics/recall(B)']:.4f}"
elif
'recall'
in
df
.
columns
:
results
[
"召回率"
]
=
f
"{last_row['recall']:.4f}"
if
'metrics/mAP50(B)'
in
df
.
columns
:
results
[
"mAP@0.5"
]
=
f
"{last_row['metrics/mAP50(B)']:.4f}"
elif
'mAP_0.5'
in
df
.
columns
:
results
[
"mAP@0.5"
]
=
f
"{last_row['mAP_0.5']:.4f}"
if
'metrics/mAP50-95(B)'
in
df
.
columns
:
results
[
"mAP@0.5:0.95"
]
=
f
"{last_row['metrics/mAP50-95(B)']:.4f}"
elif
'mAP_0.5:0.95'
in
df
.
columns
:
results
[
"mAP@0.5:0.95"
]
=
f
"{last_row['mAP_0.5:0.95']:.4f}"
if
'train/box_loss'
in
df
.
columns
:
results
[
"训练损失"
]
=
f
"{last_row['train/box_loss']:.4f}"
elif
'train_loss'
in
df
.
columns
:
results
[
"训练损失"
]
=
f
"{last_row['train_loss']:.4f}"
if
'val/box_loss'
in
df
.
columns
:
results
[
"验证损失"
]
=
f
"{last_row['val/box_loss']:.4f}"
elif
'val_loss'
in
df
.
columns
:
results
[
"验证损失"
]
=
f
"{last_row['val_loss']:.4f}"
# 训练轮数
if
'epoch'
in
df
.
columns
:
results
[
"完成轮数"
]
=
f
"{int(last_row['epoch']) + 1}"
except
ImportError
:
# 如果没有pandas,尝试手动解析CSV
with
open
(
results_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
lines
=
f
.
readlines
()
if
len
(
lines
)
>
1
:
# 有标题行和数据行
headers
=
lines
[
0
]
.
strip
()
.
split
(
','
)
last_data
=
lines
[
-
1
]
.
strip
()
.
split
(
','
)
for
i
,
header
in
enumerate
(
headers
):
if
i
<
len
(
last_data
):
if
'precision'
in
header
.
lower
():
results
[
"精确率"
]
=
last_data
[
i
]
elif
'recall'
in
header
.
lower
():
results
[
"召回率"
]
=
last_data
[
i
]
elif
'map50'
in
header
.
lower
():
results
[
"mAP@0.5"
]
=
last_data
[
i
]
except
Exception
as
e
:
pass
# 查找其他结果文件
log_files
=
list
(
model_dir
.
glob
(
"*.log"
))
+
list
(
model_dir
.
glob
(
"*train*.txt"
))
for
log_file
in
log_files
:
try
:
with
open
(
log_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
content
=
f
.
read
()
# 简单提取一些关键信息
if
'Best mAP'
in
content
or
'best map'
in
content
.
lower
():
lines
=
content
.
split
(
'
\n
'
)
for
line
in
lines
:
if
'map'
in
line
.
lower
()
and
(
'best'
in
line
.
lower
()
or
'final'
in
line
.
lower
()):
results
[
"最佳mAP"
]
=
line
.
strip
()
break
except
:
continue
except
Exception
as
e
:
pass
return
results
def
_evaluateModelPerformance
(
self
,
results_info
):
"""评估模型性能"""
if
not
results_info
:
return
"暂无训练结果数据,无法评估模型性能。"
evaluation_parts
=
[]
try
:
# 评估mAP指标
if
"mAP@0.5"
in
results_info
:
map50
=
float
(
results_info
[
"mAP@0.5"
])
if
map50
>=
0.9
:
evaluation_parts
.
append
(
"检测精度: 优秀 (mAP@0.5 ≥ 0.9)"
)
elif
map50
>=
0.8
:
evaluation_parts
.
append
(
"检测精度: 良好 (mAP@0.5 ≥ 0.8)"
)
elif
map50
>=
0.7
:
evaluation_parts
.
append
(
"检测精度: 一般 (mAP@0.5 ≥ 0.7)"
)
else
:
evaluation_parts
.
append
(
"检测精度: 较差 (mAP@0.5 < 0.7)"
)
# 评估精确率和召回率
if
"精确率"
in
results_info
and
"召回率"
in
results_info
:
precision
=
float
(
results_info
[
"精确率"
])
recall
=
float
(
results_info
[
"召回率"
])
if
precision
>=
0.85
and
recall
>=
0.85
:
evaluation_parts
.
append
(
"平衡性能: 优秀 (精确率和召回率均 ≥ 0.85)"
)
elif
precision
>=
0.75
and
recall
>=
0.75
:
evaluation_parts
.
append
(
"平衡性能: 良好 (精确率和召回率均 ≥ 0.75)"
)
else
:
evaluation_parts
.
append
(
"平衡性能: 需要改进"
)
# 评估损失值
if
"训练损失"
in
results_info
and
"验证损失"
in
results_info
:
train_loss
=
float
(
results_info
[
"训练损失"
])
val_loss
=
float
(
results_info
[
"验证损失"
])
if
abs
(
train_loss
-
val_loss
)
<
0.1
:
evaluation_parts
.
append
(
"模型稳定性: 良好 (训练和验证损失接近)"
)
elif
abs
(
train_loss
-
val_loss
)
<
0.2
:
evaluation_parts
.
append
(
"模型稳定性: 一般"
)
else
:
evaluation_parts
.
append
(
"模型稳定性: 可能存在过拟合"
)
# 综合评估
if
len
(
evaluation_parts
)
==
0
:
return
"数据不足,无法进行详细评估。"
# 添加使用建议
if
any
(
"优秀"
in
part
for
part
in
evaluation_parts
):
evaluation_parts
.
append
(
"
\n
建议: 模型表现良好,可以投入使用。"
)
elif
any
(
"良好"
in
part
for
part
in
evaluation_parts
):
evaluation_parts
.
append
(
"
\n
建议: 模型表现尚可,建议在更多数据上测试。"
)
else
:
evaluation_parts
.
append
(
"
\n
建议: 模型需要进一步训练优化。"
)
except
Exception
as
e
:
return
"评估过程中出现错误,请检查训练结果数据。"
return
"
\n
"
.
join
(
evaluation_parts
)
if
__name__
==
"__main__"
:
...
...
widgets/modelpage/training_page.py
View file @
7e155e88
...
...
@@ -580,22 +580,29 @@ class TrainingPage(QtWidgets.QWidget):
def
_clearCurve
(
self
):
"""清空曲线数据并隐藏曲线面板"""
if
PYQTGRAPH_AVAILABLE
and
hasattr
(
self
,
'curve_plot_widget'
)
:
#
清空数据
try
:
#
总是清空数据(无论PyQtGraph是否可用)
self
.
curve_data_x
=
[]
self
.
curve_data_y
=
[]
print
(
"[曲线] 已清空曲线数据"
)
# 如果PyQtGraph可用,清空图表
if
PYQTGRAPH_AVAILABLE
and
hasattr
(
self
,
'curve_plot_widget'
)
and
self
.
curve_plot_widget
:
# 清空曲线
if
self
.
curve_line
:
if
hasattr
(
self
,
'curve_line'
)
and
self
.
curve_line
:
self
.
curve_plot_widget
.
removeItem
(
self
.
curve_line
)
self
.
curve_line
=
None
print
(
"[曲线] 已清空曲线数据"
)
print
(
"[曲线] 已清空PyQtGraph曲线"
)
# 自动隐藏曲线面板,返回到初始显示状态
self
.
hideCurvePanel
()
print
(
"[曲线] 曲线面板已隐藏,返回初始显示状态"
)
except
Exception
as
e
:
print
(
f
"[曲线] 清空曲线失败: {e}"
)
import
traceback
traceback
.
print_exc
()
def
addCurvePoint
(
self
,
frame_index
,
height_mm
):
"""添加曲线数据点
...
...
@@ -603,16 +610,22 @@ class TrainingPage(QtWidgets.QWidget):
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
if
not
PYQTGRAPH_AVAILABLE
or
not
hasattr
(
self
,
'curve_plot_widget'
):
return
try
:
# 添加数据点
# 确保曲线数据列表存在
if
not
hasattr
(
self
,
'curve_data_x'
):
self
.
curve_data_x
=
[]
if
not
hasattr
(
self
,
'curve_data_y'
):
self
.
curve_data_y
=
[]
# 总是添加数据点到列表中(无论PyQtGraph是否可用)
self
.
curve_data_x
.
append
(
frame_index
)
self
.
curve_data_y
.
append
(
height_mm
)
print
(
f
"[曲线] 添加数据点: 帧{frame_index}, 液位{height_mm:.1f}mm,当前数据点数: {len(self.curve_data_x)}"
)
# 如果PyQtGraph可用且有curve_plot_widget,更新图表
if
PYQTGRAPH_AVAILABLE
and
hasattr
(
self
,
'curve_plot_widget'
)
and
self
.
curve_plot_widget
:
# 如果曲线不存在,创建曲线
if
self
.
curve_line
is
None
:
if
not
hasattr
(
self
,
'curve_line'
)
or
self
.
curve_line
is
None
:
self
.
curve_line
=
self
.
curve_plot_widget
.
plot
(
self
.
curve_data_x
,
self
.
curve_data_y
,
...
...
@@ -625,17 +638,20 @@ class TrainingPage(QtWidgets.QWidget):
except
Exception
as
e
:
print
(
f
"[曲线] 添加数据点失败: {e}"
)
import
traceback
traceback
.
print_exc
()
def
showCurvePanel
(
self
):
"""显示曲线面板"""
if
hasattr
(
self
,
'
display_layou
t'
):
# 切换到曲线面板(索引3:hint, display_panel, video_panel, curve_panel)
self
.
display_layout
.
setCurrentIndex
(
3
)
if
hasattr
(
self
,
'
curve_panel'
)
and
hasattr
(
self
,
'stacked_widge
t'
):
self
.
stacked_widget
.
setCurrentWidget
(
self
.
curve_panel
)
print
(
"[曲线] 显示曲线面板"
)
def
hideCurvePanel
(
self
):
"""隐藏曲线面板,返回到显示面板"""
if
hasattr
(
self
,
'display_layout'
):
self
.
display_layout
.
setCurrentIndex
(
1
)
# 显示 display_panel
"""隐藏曲线面板,返回到初始显示"""
if
hasattr
(
self
,
'display_panel'
)
and
hasattr
(
self
,
'stacked_widget'
):
self
.
stacked_widget
.
setCurrentWidget
(
self
.
display_panel
)
print
(
"[曲线] 隐藏曲线面板"
)
def
saveCurveData
(
self
,
csv_path
):
"""保存曲线数据为CSV文件
...
...
@@ -1343,95 +1359,15 @@ class TrainingPage(QtWidgets.QWidget):
# self._loadTestFileList() # 🔥 不再需要刷新测试文件列表(改用浏览方式)
def
_loadBaseModelOptions
(
self
):
"""从模型集管理页面加载基础模型选项"""
# 清空现有选项
self
.
base_model_combo
.
clear
()
try
:
# 尝试从父窗口获取模型集页面
if
hasattr
(
self
.
_parent
,
'modelSetPage'
):
model_set_page
=
self
.
_parent
.
modelSetPage
# 获取模型参数字典
if
hasattr
(
model_set_page
,
'_model_params'
):
model_params
=
model_set_page
.
_model_params
if
not
model_params
:
self
.
base_model_combo
.
addItem
(
"未找到模型"
,
None
)
return
# 获取默认模型
default_model
=
None
if
hasattr
(
model_set_page
,
'_current_default_model'
):
default_model
=
model_set_page
.
_current_default_model
# 添加所有模型到下拉框
default_index
=
0
for
idx
,
(
model_name
,
params
)
in
enumerate
(
model_params
.
items
()):
model_path
=
params
.
get
(
'path'
,
''
)
# 构建显示名称
display_name
=
model_name
if
model_name
==
default_model
:
display_name
=
f
"{model_name} (默认)"
default_index
=
idx
# 添加到下拉框,使用模型路径作为数据
self
.
base_model_combo
.
addItem
(
display_name
,
model_path
)
# 设置默认选择
self
.
base_model_combo
.
setCurrentIndex
(
default_index
)
else
:
self
.
base_model_combo
.
addItem
(
"模型集页面未初始化"
,
None
)
else
:
self
.
base_model_combo
.
addItem
(
"未找到模型集页面"
,
None
)
except
Exception
as
e
:
print
(
f
"[基础模型] 加载失败: {e}"
)
import
traceback
traceback
.
print_exc
()
self
.
base_model_combo
.
addItem
(
"加载失败"
,
None
)
"""从detection_model目录加载基础模型选项"""
# 使用统一的模型刷新方法
self
.
refreshModelLists
()
def
_loadTestModelOptions
(
self
):
"""加载测试模型选项(从 train_model 文件夹读取)"""
# 清空现有选项
self
.
test_model_combo
.
clear
()
try
:
from
...database.config
import
get_project_root
project_root
=
get_project_root
()
except
ImportError
as
e
:
# 如果导入失败,使用相对路径
project_root
=
Path
(
__file__
)
.
parent
.
parent
.
parent
# 🔥 修改:只从 train_model 目录扫描模型
all_models
=
self
.
_scanDetectionModelDirectory
(
project_root
)
# 添加到下拉框
if
not
all_models
:
self
.
test_model_combo
.
addItem
(
"未找到测试模型"
)
return
# 获取记忆的测试模型路径
remembered_model
=
self
.
_getRememberedTestModel
(
project_root
)
default_index
=
0
# 添加所有模型到下拉框
for
idx
,
model
in
enumerate
(
all_models
):
display_name
=
model
[
'name'
]
model_path
=
model
[
'path'
]
# 如果找到记忆的模型,设置为默认选择
if
remembered_model
and
model_path
==
remembered_model
:
default_index
=
idx
display_name
=
f
"{display_name} (上次使用)"
self
.
test_model_combo
.
addItem
(
display_name
,
model_path
)
# 设置默认选择
self
.
test_model_combo
.
setCurrentIndex
(
default_index
)
"""从detection_model目录加载测试模型选项"""
# 测试模型和基础模型使用相同的数据源,无需单独加载
# refreshModelLists() 已经处理了测试模型下拉菜单
pass
def
_loadModelsFromConfig
(
self
,
project_root
):
"""从配置文件加载通道模型"""
...
...
@@ -1823,8 +1759,8 @@ class TrainingPage(QtWidgets.QWidget):
self
.
_showNoCurveMessage
()
return
#
切换到曲线面板显示
self
.
showCurvePanel
()
#
在当前测试页面中显示曲线,而不是切换面板
self
.
_showCurveInTestPage
()
# 显示曲线信息提示
data_count
=
len
(
self
.
curve_data_x
)
...
...
@@ -1835,7 +1771,7 @@ class TrainingPage(QtWidgets.QWidget):
self
,
"曲线信息"
,
f
"图片测试结果:
\n
液位高度: {liquid_level:.1f} mm
\n\n
"
f
"曲线已显示在左侧
面板
中。"
f
"曲线已显示在左侧
测试页面
中。"
)
else
:
# 视频测试
...
...
@@ -1849,7 +1785,7 @@ class TrainingPage(QtWidgets.QWidget):
f
"数据点数: {data_count} 个
\n
"
f
"液位范围: {min_level:.1f} - {max_level:.1f} mm
\n
"
f
"平均液位: {avg_level:.1f} mm
\n\n
"
f
"曲线已显示在左侧
面板
中。"
f
"曲线已显示在左侧
测试页面
中。"
)
except
Exception
as
e
:
...
...
@@ -1860,6 +1796,356 @@ class TrainingPage(QtWidgets.QWidget):
f
"显示曲线时发生错误:
\n
{str(e)}"
)
def
_showCurveInTestPage
(
self
):
"""在测试页面中显示曲线"""
try
:
print
(
f
"[曲线显示] 开始显示曲线,PyQtGraph可用: {PYQTGRAPH_AVAILABLE}"
)
# 检查曲线数据
if
not
hasattr
(
self
,
'curve_data_x'
)
or
not
hasattr
(
self
,
'curve_data_y'
):
print
(
"[曲线显示] 错误: 缺少曲线数据属性"
)
self
.
_showCurveAsText
()
return
if
len
(
self
.
curve_data_x
)
==
0
or
len
(
self
.
curve_data_y
)
==
0
:
print
(
f
"[曲线显示] 错误: 曲线数据为空,X数据点: {len(self.curve_data_x)}, Y数据点: {len(self.curve_data_y)}"
)
self
.
_showCurveAsText
()
return
print
(
f
"[曲线显示] 曲线数据检查通过,数据点数: {len(self.curve_data_x)}"
)
if
not
PYQTGRAPH_AVAILABLE
:
print
(
"[曲线显示] PyQtGraph不可用,使用文本显示"
)
self
.
_showCurveAsText
()
return
# 生成曲线图表的HTML内容
print
(
"[曲线显示] 开始生成HTML内容"
)
curve_html
=
self
.
_generateCurveHTML
()
if
not
curve_html
:
print
(
"[曲线显示] 错误: HTML内容生成失败"
)
self
.
_showCurveAsText
()
return
# 在显示面板中显示曲线HTML
if
hasattr
(
self
,
'display_panel'
)
and
hasattr
(
self
,
'display_layout'
):
self
.
display_panel
.
setHtml
(
curve_html
)
self
.
display_layout
.
setCurrentWidget
(
self
.
display_panel
)
print
(
"[曲线显示] 曲线已显示在测试页面中"
)
else
:
print
(
"[曲线显示] 错误: 缺少display_panel或display_layout属性"
)
self
.
_showCurveAsText
()
except
Exception
as
e
:
print
(
f
"[曲线显示] 在测试页面显示曲线失败: {e}"
)
import
traceback
traceback
.
print_exc
()
# 降级到文本显示
self
.
_showCurveAsText
()
def
_generateCurveHTML
(
self
):
"""生成曲线的HTML内容"""
try
:
print
(
"[曲线HTML] 开始生成HTML内容"
)
# 保存曲线图片到临时文件
import
tempfile
import
os
temp_dir
=
tempfile
.
gettempdir
()
curve_image_path
=
os
.
path
.
join
(
temp_dir
,
"test_curve_display.png"
)
print
(
f
"[曲线HTML] 临时图片路径: {curve_image_path}"
)
# 优先尝试使用matplotlib生成曲线(更可靠)
if
self
.
_createMatplotlibCurve
(
curve_image_path
):
print
(
"[曲线HTML] matplotlib曲线生成成功"
)
# 使用PyQtGraph导出曲线图片
elif
hasattr
(
self
,
'curve_plot_widget'
)
and
self
.
curve_plot_widget
:
print
(
"[曲线HTML] 使用现有的curve_plot_widget导出图片"
)
try
:
# 使用样式管理器的配置
from
widgets.style_manager
import
CurveDisplayStyleManager
chart_width
,
chart_height
=
CurveDisplayStyleManager
.
getChartSize
()
exporter
=
pg
.
exporters
.
ImageExporter
(
self
.
curve_plot_widget
.
plotItem
)
exporter
.
parameters
()[
'width'
]
=
chart_width
exporter
.
parameters
()[
'height'
]
=
chart_height
exporter
.
export
(
curve_image_path
)
print
(
"[曲线HTML] 现有widget图片导出成功"
)
except
Exception
as
e
:
print
(
f
"[曲线HTML] 现有widget图片导出失败: {e}"
)
self
.
_createTempCurvePlot
(
curve_image_path
)
else
:
print
(
"[曲线HTML] 没有现有的curve_plot_widget,创建临时plot"
)
# 如果没有现有的plot widget,创建一个临时的
self
.
_createTempCurvePlot
(
curve_image_path
)
# 检查图片是否成功生成
if
not
os
.
path
.
exists
(
curve_image_path
):
print
(
f
"[曲线HTML] 错误: 图片文件未生成: {curve_image_path}"
)
return
self
.
_getFallbackCurveHTML
()
print
(
f
"[曲线HTML] 图片生成成功,文件大小: {os.path.getsize(curve_image_path)} bytes"
)
# 生成统计信息
data_count
=
len
(
self
.
curve_data_x
)
stats_html
=
""
if
data_count
==
1
:
# 图片测试
liquid_level
=
self
.
curve_data_y
[
0
]
stats_html
=
f
"""
<div style="margin-bottom: 15px; padding: 10px; background: #e8f4fd; border: 1px solid #bee5eb; border-radius: 5px;">
<h4 style="margin: 0 0 8px 0; color: #0c5460;">图片测试结果</h4>
<p style="margin: 0; color: #0c5460;"><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
</div>
"""
else
:
# 视频测试
min_level
=
min
(
self
.
curve_data_y
)
max_level
=
max
(
self
.
curve_data_y
)
avg_level
=
sum
(
self
.
curve_data_y
)
/
len
(
self
.
curve_data_y
)
stats_html
=
f
"""
<div style="margin-bottom: 15px; padding: 10px; background: #e8f4fd; border: 1px solid #bee5eb; border-radius: 5px;">
<h4 style="margin: 0 0 8px 0; color: #0c5460;">视频测试结果统计</h4>
<p style="margin: 2px 0; color: #0c5460;"><strong>数据点数:</strong> {data_count} 个</p>
<p style="margin: 2px 0; color: #0c5460;"><strong>液位范围:</strong> {min_level:.1f} - {max_level:.1f} mm</p>
<p style="margin: 2px 0; color: #0c5460;"><strong>平均液位:</strong> {avg_level:.1f} mm</p>
</div>
"""
# 使用统一的样式管理器生成HTML内容
from
widgets.style_manager
import
CurveDisplayStyleManager
html_content
=
CurveDisplayStyleManager
.
generateCurveHTML
(
curve_image_path
,
stats_html
)
return
html_content
except
Exception
as
e
:
print
(
f
"[曲线HTML] 生成HTML失败: {e}"
)
return
self
.
_getFallbackCurveHTML
()
def
_createMatplotlibCurve
(
self
,
output_path
):
"""使用matplotlib创建曲线图"""
try
:
print
(
"[matplotlib曲线] 开始使用matplotlib生成曲线"
)
import
matplotlib.pyplot
as
plt
import
matplotlib
matplotlib
.
use
(
'Agg'
)
# 使用非交互式后端
# 验证数据
if
not
hasattr
(
self
,
'curve_data_x'
)
or
not
hasattr
(
self
,
'curve_data_y'
):
print
(
"[matplotlib曲线] 错误: 缺少曲线数据"
)
return
False
if
len
(
self
.
curve_data_x
)
==
0
or
len
(
self
.
curve_data_y
)
==
0
:
print
(
f
"[matplotlib曲线] 错误: 曲线数据为空"
)
return
False
print
(
f
"[matplotlib曲线] 数据验证通过,X点数: {len(self.curve_data_x)}, Y点数: {len(self.curve_data_y)}"
)
# 设置中文字体
plt
.
rcParams
[
'font.sans-serif'
]
=
[
'Microsoft YaHei'
,
'SimHei'
,
'Arial'
]
plt
.
rcParams
[
'axes.unicode_minus'
]
=
False
# 使用样式管理器的配置
from
widgets.style_manager
import
CurveDisplayStyleManager
chart_width
,
chart_height
=
CurveDisplayStyleManager
.
getChartSize
()
chart_dpi
=
CurveDisplayStyleManager
.
getChartDPI
()
bg_color
=
CurveDisplayStyleManager
.
getPlotBackgroundColor
()
# 创建图形,使用统一的尺寸和DPI
fig
,
ax
=
plt
.
subplots
(
figsize
=
(
chart_width
/
100
,
chart_height
/
100
),
dpi
=
chart_dpi
)
# 绘制曲线
ax
.
plot
(
self
.
curve_data_x
,
self
.
curve_data_y
,
'b-'
,
linewidth
=
2.5
,
marker
=
'o'
,
markersize
=
5
,
markerfacecolor
=
'white'
,
markeredgecolor
=
'blue'
,
markeredgewidth
=
1.5
)
# 设置标签和标题
ax
.
set_xlabel
(
'帧序号'
,
fontsize
=
12
)
ax
.
set_ylabel
(
'液位高度 (mm)'
,
fontsize
=
12
)
ax
.
set_title
(
'液位检测曲线'
,
fontsize
=
14
,
fontweight
=
'bold'
,
pad
=
20
)
# 设置网格
ax
.
grid
(
True
,
alpha
=
0.3
,
linestyle
=
'--'
)
# 设置背景色,使用样式管理器的颜色
ax
.
set_facecolor
(
bg_color
)
fig
.
patch
.
set_facecolor
(
bg_color
)
# 优化布局,减少边距
plt
.
tight_layout
(
pad
=
1.0
)
# 保存图片,优化参数
plt
.
savefig
(
output_path
,
bbox_inches
=
'tight'
,
pad_inches
=
0.2
,
facecolor
=
bg_color
,
edgecolor
=
'none'
,
dpi
=
100
)
plt
.
close
(
fig
)
print
(
f
"[matplotlib曲线] 曲线图片已保存: {output_path}"
)
return
True
except
Exception
as
e
:
print
(
f
"[matplotlib曲线] 创建失败: {e}"
)
import
traceback
traceback
.
print_exc
()
return
False
def
_createTempCurvePlot
(
self
,
output_path
):
"""创建临时曲线图并保存"""
try
:
print
(
"[临时曲线] 开始创建临时曲线图"
)
import
pyqtgraph
as
pg
# 验证数据
if
not
hasattr
(
self
,
'curve_data_x'
)
or
not
hasattr
(
self
,
'curve_data_y'
):
print
(
"[临时曲线] 错误: 缺少曲线数据"
)
self
.
_createPlaceholderImage
(
output_path
)
return
if
len
(
self
.
curve_data_x
)
==
0
or
len
(
self
.
curve_data_y
)
==
0
:
print
(
f
"[临时曲线] 错误: 曲线数据为空"
)
self
.
_createPlaceholderImage
(
output_path
)
return
print
(
f
"[临时曲线] 数据验证通过,X点数: {len(self.curve_data_x)}, Y点数: {len(self.curve_data_y)}"
)
# 创建临时的plot widget
temp_plot
=
pg
.
PlotWidget
()
temp_plot
.
setBackground
(
'#f8f9fa'
)
temp_plot
.
showGrid
(
x
=
True
,
y
=
True
,
alpha
=
0.3
)
temp_plot
.
setLabel
(
'left'
,
'液位高度'
,
units
=
'mm'
)
temp_plot
.
setLabel
(
'bottom'
,
'帧序号'
)
temp_plot
.
setTitle
(
'液位检测曲线'
,
color
=
'#495057'
,
size
=
'12pt'
)
print
(
"[临时曲线] PlotWidget创建成功,开始绘制曲线"
)
# 绘制曲线
temp_plot
.
plot
(
self
.
curve_data_x
,
self
.
curve_data_y
,
pen
=
pg
.
mkPen
(
color
=
'#1f77b4'
,
width
=
2
),
name
=
'液位高度'
)
print
(
"[临时曲线] 曲线绘制完成,开始导出图片"
)
# 使用样式管理器的配置导出图片
from
widgets.style_manager
import
CurveDisplayStyleManager
chart_width
,
chart_height
=
CurveDisplayStyleManager
.
getChartSize
()
bg_color
=
CurveDisplayStyleManager
.
getPlotBackgroundColor
()
# 导出图片(使用统一尺寸)
exporter
=
pg
.
exporters
.
ImageExporter
(
temp_plot
.
plotItem
)
exporter
.
parameters
()[
'width'
]
=
chart_width
exporter
.
parameters
()[
'height'
]
=
chart_height
# 设置背景色
temp_plot
.
setBackground
(
bg_color
)
exporter
.
export
(
output_path
)
print
(
f
"[临时曲线] 曲线图片已保存: {output_path}"
)
except
Exception
as
e
:
print
(
f
"[临时曲线] 创建失败: {e}"
)
import
traceback
traceback
.
print_exc
()
# 创建一个简单的占位图片
self
.
_createPlaceholderImage
(
output_path
)
def
_createPlaceholderImage
(
self
,
output_path
):
"""创建占位图片"""
try
:
from
PIL
import
Image
,
ImageDraw
,
ImageFont
import
matplotlib.pyplot
as
plt
import
numpy
as
np
# 如果有曲线数据,尝试用matplotlib绘制
if
hasattr
(
self
,
'curve_data_x'
)
and
hasattr
(
self
,
'curve_data_y'
)
and
len
(
self
.
curve_data_x
)
>
0
:
print
(
"[占位图片] 尝试使用matplotlib绘制曲线"
)
try
:
plt
.
figure
(
figsize
=
(
10
,
5
),
dpi
=
80
)
plt
.
plot
(
self
.
curve_data_x
,
self
.
curve_data_y
,
'b-'
,
linewidth
=
2
,
marker
=
'o'
,
markersize
=
4
)
plt
.
xlabel
(
'帧序号'
)
plt
.
ylabel
(
'液位高度 (mm)'
)
plt
.
title
(
'液位检测曲线'
)
plt
.
grid
(
True
,
alpha
=
0.3
)
plt
.
tight_layout
()
plt
.
savefig
(
output_path
,
bbox_inches
=
'tight'
,
pad_inches
=
0.1
,
facecolor
=
'#f8f9fa'
)
plt
.
close
()
print
(
f
"[占位图片] matplotlib曲线已保存: {output_path}"
)
return
except
Exception
as
e
:
print
(
f
"[占位图片] matplotlib绘制失败: {e}"
)
# 创建简化的占位图片(更小的尺寸,减少白色背景)
img
=
Image
.
new
(
'RGB'
,
(
600
,
300
),
'#f8f9fa'
)
draw
=
ImageDraw
.
Draw
(
img
)
# 绘制边框
draw
.
rectangle
([
0
,
0
,
599
,
299
],
outline
=
'#dee2e6'
,
width
=
1
)
# 绘制文本
try
:
font
=
ImageFont
.
truetype
(
"C:/Windows/Fonts/msyh.ttc"
,
18
)
except
:
font
=
ImageFont
.
load_default
()
text
=
"曲线生成失败,请检查数据"
text_bbox
=
draw
.
textbbox
((
0
,
0
),
text
,
font
=
font
)
text_width
=
text_bbox
[
2
]
-
text_bbox
[
0
]
text_height
=
text_bbox
[
3
]
-
text_bbox
[
1
]
x
=
(
600
-
text_width
)
//
2
y
=
(
300
-
text_height
)
//
2
draw
.
text
((
x
,
y
),
text
,
fill
=
'#666666'
,
font
=
font
)
img
.
save
(
output_path
)
print
(
f
"[占位图片] 已创建: {output_path}"
)
except
Exception
as
e
:
print
(
f
"[占位图片] 创建失败: {e}"
)
def
_getFallbackCurveHTML
(
self
):
"""获取降级的曲线HTML内容"""
data_count
=
len
(
self
.
curve_data_x
)
if
data_count
==
1
:
liquid_level
=
self
.
curve_data_y
[
0
]
return
f
"""
<div style="font-family: Arial, sans-serif; padding: 20px; background: #ffffff; color: #333333;">
<h3>图片测试结果</h3>
<p><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
<p style="color: #666; font-size: 12px;">注: 曲线图表生成失败,显示文本结果。</p>
</div>
"""
else
:
min_level
=
min
(
self
.
curve_data_y
)
max_level
=
max
(
self
.
curve_data_y
)
avg_level
=
sum
(
self
.
curve_data_y
)
/
len
(
self
.
curve_data_y
)
return
f
"""
<div style="font-family: Arial, sans-serif; padding: 20px; background: #ffffff; color: #333333;">
<h3>视频测试结果统计</h3>
<p><strong>数据点数:</strong> {data_count} 个</p>
<p><strong>液位范围:</strong> {min_level:.1f} - {max_level:.1f} mm</p>
<p><strong>平均液位:</strong> {avg_level:.1f} mm</p>
<p style="color: #666; font-size: 12px;">注: 曲线图表生成失败,显示文本结果。</p>
</div>
"""
def
_showCurveAsText
(
self
):
"""以文本形式显示曲线结果"""
try
:
fallback_html
=
self
.
_getFallbackCurveHTML
()
if
hasattr
(
self
,
'display_panel'
)
and
hasattr
(
self
,
'display_layout'
):
self
.
display_panel
.
setHtml
(
fallback_html
)
self
.
display_layout
.
setCurrentWidget
(
self
.
display_panel
)
print
(
"[曲线显示] 以文本形式显示曲线结果"
)
except
Exception
as
e
:
print
(
f
"[曲线显示] 文本显示也失败: {e}"
)
def
_showNoCurveMessage
(
self
):
"""显示无曲线数据的提示"""
QtWidgets
.
QMessageBox
.
information
(
...
...
@@ -1874,6 +2160,25 @@ class TrainingPage(QtWidgets.QWidget):
"测试完成后即可查看曲线结果。"
)
def
_testCurveGeneration
(
self
):
"""测试曲线生成功能(调试用)"""
try
:
print
(
"[曲线测试] 开始测试曲线生成功能"
)
# 创建测试数据
self
.
curve_data_x
=
[
0
,
1
,
2
,
3
,
4
]
self
.
curve_data_y
=
[
25.0
,
26.5
,
24.8
,
27.2
,
25.9
]
print
(
f
"[曲线测试] 测试数据创建完成,X: {self.curve_data_x}, Y: {self.curve_data_y}"
)
# 测试曲线显示
self
.
_showCurveInTestPage
()
except
Exception
as
e
:
print
(
f
"[曲线测试] 测试失败: {e}"
)
import
traceback
traceback
.
print_exc
()
def
enableViewCurveButton
(
self
):
"""启用查看曲线按钮(测试完成后调用)"""
try
:
...
...
@@ -1916,6 +2221,56 @@ class TrainingPage(QtWidgets.QWidget):
return
"template_3"
return
None
def
refreshModelLists
(
self
):
"""刷新模型下拉菜单列表"""
try
:
# 导入模型集页面以获取模型列表
from
.modelset_page
import
ModelSetPage
# 获取detection_model目录下的所有模型
models
=
ModelSetPage
.
getDetectionModels
()
# 刷新基础模型下拉菜单
if
hasattr
(
self
,
'base_model_combo'
):
current_base
=
self
.
base_model_combo
.
currentText
()
self
.
base_model_combo
.
clear
()
for
model
in
models
:
display_name
=
model
[
'name'
]
# 只显示模型名称,不显示文件大小
self
.
base_model_combo
.
addItem
(
display_name
,
model
[
'path'
])
# 尝试恢复之前的选择
if
current_base
:
index
=
self
.
base_model_combo
.
findText
(
current_base
)
if
index
>=
0
:
self
.
base_model_combo
.
setCurrentIndex
(
index
)
# 刷新测试模型下拉菜单
if
hasattr
(
self
,
'test_model_combo'
):
current_test
=
self
.
test_model_combo
.
currentText
()
self
.
test_model_combo
.
clear
()
for
model
in
models
:
display_name
=
model
[
'name'
]
# 只显示模型名称,不显示文件大小
self
.
test_model_combo
.
addItem
(
display_name
,
model
[
'path'
])
# 尝试恢复之前的选择
if
current_test
:
index
=
self
.
test_model_combo
.
findText
(
current_test
)
if
index
>=
0
:
self
.
test_model_combo
.
setCurrentIndex
(
index
)
print
(
f
"[模型同步] 已刷新模型列表,共 {len(models)} 个模型"
)
except
Exception
as
e
:
print
(
f
"[错误] 刷新模型列表失败: {e}"
)
import
traceback
traceback
.
print_exc
()
def
refreshBaseModelList
(
self
):
"""刷新基础模型下拉菜单(兼容旧接口)"""
self
.
refreshModelLists
()
# 测试代码
if
__name__
==
'__main__'
:
...
...
@@ -2073,3 +2428,4 @@ class TrainingNotesDialog(QtWidgets.QDialog):
def
getNotesContent
(
self
):
"""获取笔记内容"""
return
self
.
text_edit
.
toPlainText
()
.
strip
()
widgets/style_manager.py
View file @
7e155e88
...
...
@@ -941,8 +941,8 @@ class BackgroundStyleManager:
"""全局背景颜色管理器"""
# 全局背景颜色配置
GLOBAL_BACKGROUND_COLOR
=
"
white"
# 统一白色背景
GLOBAL_BACKGROUND_STYLE
=
"background-color:
white
;"
GLOBAL_BACKGROUND_COLOR
=
"
#f8f9fa"
# 统一浅灰色背景,消除纯白色
GLOBAL_BACKGROUND_STYLE
=
"background-color:
#f8f9fa
;"
@classmethod
def
applyToWidget
(
cls
,
widget
):
...
...
@@ -1208,6 +1208,112 @@ def setMenuBarHoverColor(color):
# ============================================================================
# 曲线显示样式管理器 (Curve Display Style Manager)
# ============================================================================
class
CurveDisplayStyleManager
:
"""曲线显示样式管理器"""
# 统一的颜色配置
BACKGROUND_COLOR
=
"#f8f9fa"
# 统一背景色
PLOT_BACKGROUND_COLOR
=
"#f8f9fa"
# 图表背景色
CONTAINER_BACKGROUND_COLOR
=
"#ffffff"
# 容器背景色
BORDER_COLOR
=
"#dee2e6"
# 边框颜色
TEXT_COLOR
=
"#333333"
# 文本颜色
# 图表尺寸配置
CHART_WIDTH
=
600
CHART_HEIGHT
=
350
CHART_DPI
=
100
@classmethod
def
getBackgroundColor
(
cls
):
"""获取统一背景色"""
return
cls
.
BACKGROUND_COLOR
@classmethod
def
getPlotBackgroundColor
(
cls
):
"""获取图表背景色"""
return
cls
.
PLOT_BACKGROUND_COLOR
@classmethod
def
getContainerBackgroundColor
(
cls
):
"""获取容器背景色"""
return
cls
.
CONTAINER_BACKGROUND_COLOR
@classmethod
def
getBorderColor
(
cls
):
"""获取边框颜色"""
return
cls
.
BORDER_COLOR
@classmethod
def
getTextColor
(
cls
):
"""获取文本颜色"""
return
cls
.
TEXT_COLOR
@classmethod
def
getChartSize
(
cls
):
"""获取图表尺寸"""
return
cls
.
CHART_WIDTH
,
cls
.
CHART_HEIGHT
@classmethod
def
getChartDPI
(
cls
):
"""获取图表DPI"""
return
cls
.
CHART_DPI
@classmethod
def
generateCurveHTML
(
cls
,
curve_image_path
,
stats_html
=
""
):
"""生成统一样式的曲线HTML"""
return
f
"""
<div style="font-family: 'Microsoft YaHei', 'SimHei', Arial, sans-serif; padding: 20px; background: {cls.BACKGROUND_COLOR}; color: {cls.TEXT_COLOR};">
<div style="margin-bottom: 20px;">
<h3 style="margin: 0 0 15px 0; color: {cls.TEXT_COLOR}; font-size: 18px; font-weight: 600;">液位检测曲线结果</h3>
</div>
{stats_html}
<div style="text-align: center; margin-bottom: 15px; background: {cls.CONTAINER_BACKGROUND_COLOR}; padding: 8px; border-radius: 6px; border: 1px solid {cls.BORDER_COLOR};">
<img src="file:///{curve_image_path.replace(chr(92), '/')}" style="max-width: 100
%
; height: auto; border-radius: 4px;">
</div>
<div style="padding: 10px; background: {cls.CONTAINER_BACKGROUND_COLOR}; border: 1px solid {cls.BORDER_COLOR}; border-radius: 5px; font-size: 12px; color: #666;">
<p style="margin: 0;"><strong>说明:</strong> 曲线显示了液位检测的结果变化趋势。X轴表示帧序号,Y轴表示液位高度(毫米)。</p>
</div>
</div>
"""
@classmethod
def
setBackgroundColor
(
cls
,
color
):
"""设置背景颜色"""
cls
.
BACKGROUND_COLOR
=
color
cls
.
PLOT_BACKGROUND_COLOR
=
color
@classmethod
def
setChartSize
(
cls
,
width
,
height
):
"""设置图表尺寸"""
cls
.
CHART_WIDTH
=
width
cls
.
CHART_HEIGHT
=
height
# 曲线显示样式管理器便捷函数
def
getCurveBackgroundColor
():
"""获取曲线背景色的便捷函数"""
return
CurveDisplayStyleManager
.
getBackgroundColor
()
def
getCurvePlotBackgroundColor
():
"""获取曲线图表背景色的便捷函数"""
return
CurveDisplayStyleManager
.
getPlotBackgroundColor
()
def
getCurveChartSize
():
"""获取曲线图表尺寸的便捷函数"""
return
CurveDisplayStyleManager
.
getChartSize
()
def
generateCurveHTML
(
curve_image_path
,
stats_html
=
""
):
"""生成曲线HTML的便捷函数"""
return
CurveDisplayStyleManager
.
generateCurveHTML
(
curve_image_path
,
stats_html
)
# ============================================================================
# 测试代码
# ============================================================================
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment