Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
O
Oil_Level_Recognition_System
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
柳青君
Oil_Level_Recognition_System
Commits
21a67d1a
Commit
21a67d1a
authored
Nov 27, 2025
by
Yuhaibo
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
1
parent
e491a327
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
574 additions
and
179 deletions
+574
-179
model_set_handler.py
handlers/modelpage/model_set_handler.py
+179
-28
model_training_handler.py
handlers/modelpage/model_training_handler.py
+72
-120
model_trainingworker_handler.py
handlers/modelpage/model_trainingworker_handler.py
+215
-4
modelset_page.py
widgets/modelpage/modelset_page.py
+17
-9
training_page.py
widgets/modelpage/training_page.py
+91
-18
No files found.
handlers/modelpage/model_set_handler.py
View file @
21a67d1a
...
@@ -301,11 +301,9 @@ class ModelSetHandler:
...
@@ -301,11 +301,9 @@ class ModelSetHandler:
if
reply
==
QtWidgets
.
QMessageBox
.
Yes
:
if
reply
==
QtWidgets
.
QMessageBox
.
Yes
:
# 调用 ModelSetPage 的删除方法
# 调用 ModelSetPage 的删除方法
# deleteModel会触发deleteModelDataRequested信号,由deleteModelData处理
# deleteModelData会自动刷新训练页面
self
.
modelSetPage
.
deleteModel
(
model_name
)
self
.
modelSetPage
.
deleteModel
(
model_name
)
# 更新测试页面的模型列表
if
hasattr
(
self
,
'testModelPage'
):
self
.
testModelPage
.
loadModelsFromModelSetPage
(
self
.
modelSetPage
)
else
:
else
:
QtWidgets
.
QMessageBox
.
warning
(
self
,
"错误"
,
"模型集页面未初始化"
)
QtWidgets
.
QMessageBox
.
warning
(
self
,
"错误"
,
"模型集页面未初始化"
)
...
@@ -513,7 +511,7 @@ class ModelSetHandler:
...
@@ -513,7 +511,7 @@ class ModelSetHandler:
return
models
return
models
def
_scanModelDirectory
(
self
):
def
_scanModelDirectory
(
self
):
"""扫描模型目录获取所有模型文件"""
"""扫描模型目录获取所有模型文件
(增强版:支持config.yaml和按优先级选择模型)
"""
models
=
[]
models
=
[]
try
:
try
:
...
@@ -521,45 +519,140 @@ class ModelSetHandler:
...
@@ -521,45 +519,140 @@ class ModelSetHandler:
current_dir
=
Path
(
__file__
)
.
parent
.
parent
.
parent
current_dir
=
Path
(
__file__
)
.
parent
.
parent
.
parent
model_dir
=
current_dir
/
"database"
/
"model"
/
"train_model"
model_dir
=
current_dir
/
"database"
/
"model"
/
"train_model"
print
(
f
"[模型扫描] 扫描目录: {model_dir}"
)
if
not
model_dir
.
exists
():
if
not
model_dir
.
exists
():
print
(
f
"[模型扫描] 目录不存在"
)
return
models
return
models
all_items
=
list
(
model_dir
.
iterdir
())
print
(
f
"[模型扫描] 目录存在: {model_dir.exists()}"
)
# 按数字排序子目录(降序,最新的在前)
all_subdirs
=
[
d
for
d
in
model_dir
.
iterdir
()
if
d
.
is_dir
()]
print
(
f
"[模型扫描] 找到子目录数量: {len(all_subdirs)}"
)
print
(
f
"[模型扫描] 子目录列表: {[d.name for d in all_subdirs]}"
)
sorted_subdirs
=
sorted
(
model_dir
.
iterdir
(),
key
=
lambda
x
:
x
.
name
if
x
.
is_dir
()
else
''
)
digit_subdirs
=
[
d
for
d
in
all_subdirs
if
d
.
name
.
isdigit
()]
print
(
f
"[模型扫描] 数字子目录数量: {len(digit_subdirs)}"
)
sorted_subdirs
=
sorted
(
digit_subdirs
,
key
=
lambda
x
:
int
(
x
.
name
),
reverse
=
True
)
print
(
f
"[模型扫描] 数字子目录: {[d.name for d in sorted_subdirs]}"
)
for
subdir
in
sorted_subdirs
:
for
subdir
in
sorted_subdirs
:
if
subdir
.
is_dir
():
print
(
f
"[模型扫描] 处理子目录: {subdir.name}"
)
subdir_files
=
list
(
subdir
.
iterdir
())
dat_files
=
list
(
subdir
.
glob
(
"*.dat"
))
# 尝试读取 config.yaml 获取详细信息
config_file
=
subdir
/
"config.yaml"
model_config
=
None
for
model_file
in
sorted
(
dat_files
):
if
config_file
.
exists
():
model_info
=
{
try
:
'name'
:
f
"模型-{subdir.name}-{model_file.stem}"
,
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
'path'
:
str
(
model_file
),
model_config
=
yaml
.
safe_load
(
f
)
'subdir'
:
subdir
.
name
,
except
Exception
as
e
:
'source'
:
'scan'
,
print
(
f
"[模型扫描] 读取config.yaml失败: {e}"
)
'format'
:
'dat'
}
# 检查是否有weights子目录(优先检查train/weights,然后weights)
models
.
append
(
model_info
)
train_weights_dir
=
subdir
/
"train"
/
"weights"
weights_dir
=
subdir
/
"weights"
if
train_weights_dir
.
exists
():
search_dir
=
train_weights_dir
print
(
f
"[模型扫描] 找到train/weights目录: {search_dir}"
)
elif
weights_dir
.
exists
():
search_dir
=
weights_dir
print
(
f
"[模型扫描] 找到weights目录: {search_dir}"
)
else
:
search_dir
=
subdir
print
(
f
"[模型扫描] 使用根目录: {search_dir}"
)
print
(
f
"[模型扫描] 搜索目录: {search_dir}"
)
# 按优先级查找模型文件:best > last > epoch1
# 支持的扩展名:.dat, .pt, .template_*, 无扩展名
selected_model
=
None
# 优先级1: best模型
for
pattern
in
[
'best.*.dat'
,
'best.*.pt'
,
'best.template_*'
,
'best.*'
]:
if
selected_model
:
break
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
():
# 检查文件名是否匹配模式
if
file
.
name
.
startswith
(
'best.'
)
and
not
file
.
name
.
endswith
(
'.pt'
):
selected_model
=
file
print
(
f
"[模型扫描] 找到best模型: {file.name}"
)
break
# 优先级2: last模型(如果没有best)
if
not
selected_model
:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'last.'
)
and
not
file
.
name
.
endswith
(
'.pt'
):
selected_model
=
file
print
(
f
"[模型扫描] 找到last模型: {file.name}"
)
break
pt_files
=
list
(
subdir
.
glob
(
"*.pt"
))
# 优先级3: epoch1模型(如果没有best和last)
if
not
selected_model
:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'epoch1.'
)
and
not
file
.
name
.
endswith
(
'.pt'
):
selected_model
=
file
print
(
f
"[模型扫描] 找到epoch1模型: {file.name}"
)
break
# 如果都没找到,尝试查找任何非.pt文件
if
not
selected_model
:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
not
file
.
name
.
endswith
(
'.pt'
)
and
not
file
.
name
.
endswith
(
'.txt'
)
and
not
file
.
name
.
endswith
(
'.yaml'
):
selected_model
=
file
print
(
f
"[模型扫描] 找到其他模型: {file.name}"
)
break
# 如果找到了模型文件,添加到列表
if
selected_model
:
# 从 config.yaml 获取信息,或使用默认值
if
model_config
:
model_name
=
model_config
.
get
(
'name'
,
f
"训练模型-{subdir.name}"
)
description
=
model_config
.
get
(
'description'
,
''
)
training_date
=
model_config
.
get
(
'training_date'
,
''
)
epochs
=
model_config
.
get
(
'epochs'
,
''
)
else
:
model_name
=
selected_model
.
stem
# 使用文件名(不含扩展名)作为模型名
description
=
f
"来自目录 {subdir.name}"
training_date
=
''
epochs
=
''
# 获取文件格式
file_ext
=
selected_model
.
suffix
.
lstrip
(
'.'
)
if
not
file_ext
:
# 处理无扩展名的情况(如 best.template_6543)
if
'.'
in
selected_model
.
name
:
file_ext
=
selected_model
.
name
.
split
(
'.'
)[
-
1
]
else
:
file_ext
=
'unknown'
for
model_file
in
sorted
(
pt_files
):
model_info
=
{
model_info
=
{
'name'
:
f
"模型-{subdir.name}-{model_file.stem}"
,
'name'
:
model_name
,
'path'
:
str
(
model_file
),
'path'
:
str
(
selected_model
),
'subdir'
:
subdir
.
name
,
'subdir'
:
subdir
.
name
,
'source'
:
'scan'
,
'source'
:
'train_model'
,
'format'
:
'pt'
'format'
:
file_ext
,
'description'
:
description
,
'training_date'
:
training_date
,
'epochs'
:
epochs
,
'file_name'
:
selected_model
.
name
}
}
models
.
append
(
model_info
)
models
.
append
(
model_info
)
print
(
f
"[模型扫描] 添加模型: {model_name} ({selected_model.name})"
)
else
:
print
(
f
"[模型扫描] 子目录 {subdir.name} 中未找到有效模型"
)
except
Exception
as
e
:
except
Exception
as
e
:
import
traceback
import
traceback
traceback
.
print_exc
()
traceback
.
print_exc
()
print
(
f
"[模型扫描] 扫描异常: {e}"
)
print
(
f
"[模型扫描] 总共找到 {len(models)} 个模型"
)
return
models
return
models
def
_mergeModelInfo
(
self
,
channel_models
,
scanned_models
):
def
_mergeModelInfo
(
self
,
channel_models
,
scanned_models
):
...
@@ -581,9 +674,19 @@ class ModelSetHandler:
...
@@ -581,9 +674,19 @@ class ModelSetHandler:
all_models
.
append
(
model
)
all_models
.
append
(
model
)
seen_paths
.
add
(
path
)
seen_paths
.
add
(
path
)
# 确保有一个默认模型
# 确保只有一个默认模型
has_default
=
any
(
model
.
get
(
'is_default'
,
False
)
for
model
in
all_models
)
default_model_index
=
-
1
if
not
has_default
and
len
(
all_models
)
>
0
:
for
i
,
model
in
enumerate
(
all_models
):
if
model
.
get
(
'is_default'
,
False
):
if
default_model_index
==
-
1
:
# 第一个默认模型,保留
default_model_index
=
i
else
:
# 已经有默认模型了,取消这个模型的默认标记
model
[
'is_default'
]
=
False
# 如果没有默认模型,将第一个模型设为默认
if
default_model_index
==
-
1
and
len
(
all_models
)
>
0
:
all_models
[
0
][
'is_default'
]
=
True
all_models
[
0
][
'is_default'
]
=
True
return
all_models
return
all_models
...
@@ -803,18 +906,66 @@ class ModelSetHandler:
...
@@ -803,18 +906,66 @@ class ModelSetHandler:
# 从参数中删除
# 从参数中删除
if
model_name
in
self
.
modelSetPage
.
_model_params
:
if
model_name
in
self
.
modelSetPage
.
_model_params
:
model_params
=
self
.
modelSetPage
.
_model_params
[
model_name
]
model_path
=
model_params
.
get
(
'path'
,
''
)
# 删除模型文件和所在目录
if
model_path
and
os
.
path
.
exists
(
model_path
):
try
:
import
shutil
# 获取模型所在的目录(train_model/{数字ID}/)
model_dir
=
os
.
path
.
dirname
(
model_path
)
# 检查是否是train_model目录下的子目录
if
'train_model'
in
model_dir
:
# 删除整个模型目录
if
os
.
path
.
exists
(
model_dir
):
shutil
.
rmtree
(
model_dir
)
print
(
f
"[删除模型] 已删除模型目录: {model_dir}"
)
except
Exception
as
delete_error
:
print
(
f
"[删除模型] 删除模型文件失败: {delete_error}"
)
del
self
.
modelSetPage
.
_model_params
[
model_name
]
del
self
.
modelSetPage
.
_model_params
[
model_name
]
# 从配置文件中删除模型配置
# 从配置文件中删除模型配置
self
.
_removeModelFromConfig
(
model_name
)
self
.
_removeModelFromConfig
(
model_name
)
# 刷新训练页面的模型测试下拉框
self
.
_refreshTrainingPageModelList
()
return
True
return
True
return
False
return
False
except
Exception
as
e
:
except
Exception
as
e
:
print
(
f
"[删除模型] 删除失败: {e}"
)
import
traceback
traceback
.
print_exc
()
return
False
return
False
def
_refreshTrainingPageModelList
(
self
):
"""刷新训练页面的模型测试下拉框"""
try
:
# 通过主窗口访问训练页面
if
hasattr
(
self
,
'main_window'
):
# 尝试多种可能的属性名
training_page
=
None
if
hasattr
(
self
.
main_window
,
'trainingPage'
):
training_page
=
self
.
main_window
.
trainingPage
elif
hasattr
(
self
.
main_window
,
'training_page'
):
training_page
=
self
.
main_window
.
training_page
elif
hasattr
(
self
.
main_window
,
'testModelPage'
):
training_page
=
self
.
main_window
.
testModelPage
if
training_page
and
hasattr
(
training_page
,
'_loadTestModelOptions'
):
training_page
.
_loadTestModelOptions
()
print
(
"[删除模型] 已刷新训练页面的模型列表"
)
else
:
print
(
"[删除模型] 未找到训练页面或_loadTestModelOptions方法"
)
except
Exception
as
e
:
print
(
f
"[删除模型] 刷新训练页面失败: {e}"
)
def
getAllModelParams
(
self
):
def
getAllModelParams
(
self
):
"""获取所有模型参数"""
"""获取所有模型参数"""
if
hasattr
(
self
,
'modelSetPage'
)
and
self
.
modelSetPage
:
if
hasattr
(
self
,
'modelSetPage'
)
and
self
.
modelSetPage
:
...
...
handlers/modelpage/model_training_handler.py
View file @
21a67d1a
...
@@ -394,38 +394,27 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -394,38 +394,27 @@ class ModelTrainingHandler(ModelTestHandler):
self
.
_appendLog
(
" 模型升级成功完成!
\n
"
)
self
.
_appendLog
(
" 模型升级成功完成!
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
#
升级完成后进行pt转dat转换
#
获取weights目录和转换结果
weights_dir
=
None
weights_dir
=
None
converted_files
=
[]
if
self
.
training_worker
and
hasattr
(
self
.
training_worker
,
"training_report"
):
if
self
.
training_worker
and
hasattr
(
self
.
training_worker
,
"training_report"
):
weights_dir
=
self
.
training_worker
.
training_report
.
get
(
"weights_dir"
)
weights_dir
=
self
.
training_worker
.
training_report
.
get
(
"weights_dir"
)
converted_files
=
self
.
training_worker
.
training_report
.
get
(
"converted_dat_files"
,
[])
converted_files
=
[]
# 显示转换结果(转换已在TrainingWorker中完成)
if
self
.
current_exp_name
or
weights_dir
:
self
.
_appendLog
(
"
\n
正在将升级结果转换为dat格式...
\n
"
)
try
:
# 使用新的立即转换和清理方法
if
weights_dir
:
converted_files
=
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
else
:
converted_files
=
self
.
_convertTrainingmission_resultsToDat
(
self
.
current_exp_name
,
weights_dir
)
if
converted_files
:
if
converted_files
:
self
.
_appendLog
(
f
" 已将训练结果
转换为dat格式: {len(converted_files)}个文件
\n
"
)
self
.
_appendLog
(
f
"
\n
模型已
转换为dat格式: {len(converted_files)}个文件
\n
"
)
for
f
in
converted_files
:
for
f
in
converted_files
:
self
.
_appendLog
(
f
" - {f}
\n
"
)
self
.
_appendLog
(
f
" - {os.path.basename(f)}
\n
"
)
else
:
self
.
_appendLog
(
"转换完成但没有文件被转换
\n
"
)
except
Exception
as
convert_error
:
self
.
_appendLog
(
f
"[WARNING] 模型转换失败: {str(convert_error)}
\n
"
)
import
traceback
traceback
.
print_exc
()
# 保存训练日志到weights目录
# 保存训练日志到weights目录
if
weights_dir
:
self
.
_appendLog
(
"
\n
正在保存训练日志...
\n
"
)
self
.
_appendLog
(
"
\n
正在保存训练日志...
\n
"
)
log_file
=
self
.
_saveTrainingLogToWeightsDir
(
self
.
current_exp_name
,
weights_dir
)
log_file
=
self
.
_saveTrainingLogToWeightsDir
(
self
.
current_exp_name
,
weights_dir
)
if
log_file
:
if
log_file
:
self
.
_appendLog
(
f
" 训练日志已保存: {os.path.basename(log_file)}
\n
"
)
self
.
_appendLog
(
f
" 训练日志已保存: {os.path.basename(log_file)}
\n
"
)
else
:
self
.
_appendLog
(
" [WARNING] 训练日志保存失败
\n
"
)
# 将转换结果写入 training_report.json
# 将转换结果写入 training_report.json
try
:
try
:
...
@@ -508,28 +497,27 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -508,28 +497,27 @@ class ModelTrainingHandler(ModelTestHandler):
self
.
_appendLog
(
"训练已暂停
\n
"
)
self
.
_appendLog
(
"训练已暂停
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
self
.
_appendLog
(
"="
*
70
+
"
\n
"
)
#
转换已产生的模型为 dat 格式
#
获取weights目录和转换结果(转换已在TrainingWorker中完成)
converted_files
=
[]
converted_files
=
[]
weights_dir
=
None
weights_dir
=
None
if
self
.
current_exp_name
:
if
self
.
training_worker
and
hasattr
(
self
.
training_worker
,
"training_report"
)
:
train_root
=
get_train_dir
(
)
weights_dir
=
self
.
training_worker
.
training_report
.
get
(
"weights_dir"
)
weights_dir
=
os
.
path
.
join
(
train_root
,
"runs"
,
"train"
,
self
.
current_exp_name
,
"weights"
)
converted_files
=
self
.
training_worker
.
training_report
.
get
(
"converted_dat_files"
,
[]
)
# 转换模型文件 - 使用新的立即转换和清理方法
# 显示转换结果
self
.
_appendLog
(
"
\n
正在转换已产生的模型为dat格式...
\n
"
)
converted_files
=
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
if
converted_files
:
if
converted_files
:
self
.
_appendLog
(
f
"成功转换 {len(converted_files)} 个模型到.dat格式
\n
"
)
self
.
_appendLog
(
f
"
\n
模型已转换为dat格式: {len(converted_files)}个文件
\n
"
)
for
f
in
converted_files
:
for
f
in
converted_files
:
self
.
_appendLog
(
f
" - {f}
\n
"
)
self
.
_appendLog
(
f
" - {os.path.basename(f)}
\n
"
)
else
:
self
.
_appendLog
(
"[WARNING] 未找到可转换的模型文件
\n
"
)
# 保存训练日志到weights目录
# 保存训练日志到weights目录
if
weights_dir
:
self
.
_appendLog
(
"
\n
正在保存训练日志...
\n
"
)
self
.
_appendLog
(
"
\n
正在保存训练日志...
\n
"
)
log_file
=
self
.
_saveTrainingLogToWeightsDir
(
self
.
current_exp_name
,
weights_dir
)
log_file
=
self
.
_saveTrainingLogToWeightsDir
(
self
.
current_exp_name
,
weights_dir
)
if
log_file
:
if
log_file
:
self
.
_appendLog
(
f
"训练日志已保存: {os.path.basename(log_file)}
\n
"
)
self
.
_appendLog
(
f
"训练日志已保存: {os.path.basename(log_file)}
\n
"
)
else
:
self
.
_appendLog
(
"[WARNING] 训练日志保存失败
\n
"
)
# 用户停止训练时也要同步模型到模型集
# 用户停止训练时也要同步模型到模型集
self
.
_appendLog
(
"
\n
正在将模型同步到模型集管理...
\n
"
)
self
.
_appendLog
(
"
\n
正在将模型同步到模型集管理...
\n
"
)
...
@@ -2208,51 +2196,79 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2208,51 +2196,79 @@ class ModelTrainingHandler(ModelTestHandler):
traceback
.
print_exc
()
traceback
.
print_exc
()
def
_saveModelToConfig
(
self
,
model_name
,
model_params
):
def
_saveModelToConfig
(
self
,
model_name
,
model_params
):
"""保存模型
到配置文件
"""
"""保存模型
配置文件(模型已经在train_model目录中)
"""
try
:
try
:
from
pathlib
import
Path
from
pathlib
import
Path
import
yaml
import
yaml
# 获取项目根目录
self
.
_appendLog
(
f
"
\n
开始保存模型配置...
\n
"
)
project_root
=
get_project_root
()
self
.
_appendLog
(
f
" 模型名称: {model_name}
\n
"
)
self
.
_appendLog
(
f
" 模型路径: {model_params['path']}
\n
"
)
# 模型目录结构:database/model/train_model/{model_name}/
# 获取模型所在目录(应该已经在train_model/{数字ID}/weights/中)
train_model_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'train_model'
)
model_path
=
model_params
[
'path'
]
os
.
makedirs
(
train_model_dir
,
exist_ok
=
True
)
if
not
os
.
path
.
exists
(
model_path
):
raise
FileNotFoundError
(
f
"模型文件不存在: {model_path}"
)
# 查找下一个可用的数字目录
# 获取weights目录
existing_dirs
=
[]
weights_dir
=
os
.
path
.
dirname
(
model_path
)
for
item
in
os
.
listdir
(
train_model_dir
):
# 获取模型ID目录(weights的父目录)
item_path
=
os
.
path
.
join
(
train_model_dir
,
item
)
model_dir
=
os
.
path
.
dirname
(
weights_dir
)
if
os
.
path
.
isdir
(
item_path
)
and
item
.
isdigit
():
existing_dirs
.
append
(
int
(
item
))
next_id
=
max
(
existing_dirs
)
+
1
if
existing_dirs
else
1
self
.
_appendLog
(
f
" 模型目录: {model_dir}
\n
"
)
model_dir
=
os
.
path
.
join
(
train_model_dir
,
str
(
next_id
))
os
.
makedirs
(
model_dir
,
exist_ok
=
True
)
# 复制模型文件到新目录,保持原有的命名规则
# 查找可用的模型文件
source_model_path
=
model_params
[
'path'
]
model_files_found
=
[]
source_filename
=
os
.
path
.
basename
(
source_model_path
)
dest_model_path
=
os
.
path
.
join
(
model_dir
,
source_filename
)
import
shutil
# 1. 查找best模型
shutil
.
copy2
(
source_model_path
,
dest_model_path
)
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
startswith
(
'best.'
)
and
not
filename
.
endswith
(
'.pt'
):
model_files_found
.
append
(
f
"best模型: {filename}"
)
# 更新路径指向best模型
if
'best'
in
filename
:
model_params
[
'path'
]
=
os
.
path
.
join
(
weights_dir
,
filename
)
break
# 更新模型参数中的路径
# 2. 查找last模型
model_params
[
'path'
]
=
dest_model_path
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
startswith
(
'last.'
)
and
not
filename
.
endswith
(
'.pt'
):
model_files_found
.
append
(
f
"last模型: {filename}"
)
break
# 3. 查找epoch1模型
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
startswith
(
'epoch1.'
)
and
not
filename
.
endswith
(
'.pt'
):
model_files_found
.
append
(
f
"第一轮模型: {filename}"
)
break
# 4. 查找训练日志
log_files_found
=
[]
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
startswith
(
'training_log_'
)
and
filename
.
endswith
(
'.txt'
):
log_files_found
.
append
(
filename
)
# 保存模型配置到 YAML 文件
# 保存模型配置到 YAML 文件
config_file
=
os
.
path
.
join
(
model_dir
,
'config.yaml'
)
config_file
=
os
.
path
.
join
(
model_dir
,
'config.yaml'
)
self
.
_appendLog
(
f
" 保存配置文件: {config_file}
\n
"
)
with
open
(
config_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
with
open
(
config_file
,
'w'
,
encoding
=
'utf-8'
)
as
f
:
yaml
.
dump
(
model_params
,
f
,
allow_unicode
=
True
,
default_flow_style
=
False
)
yaml
.
dump
(
model_params
,
f
,
allow_unicode
=
True
,
default_flow_style
=
False
)
self
.
_appendLog
(
f
" 模型已保存到: {model_dir}
\n
"
)
# 输出总结信息
self
.
_appendLog
(
f
"
\n
✅ 模型配置已保存: {model_dir}
\n
"
)
self
.
_appendLog
(
f
" 找到的文件:
\n
"
)
for
info
in
model_files_found
:
self
.
_appendLog
(
f
" - {info}
\n
"
)
if
log_files_found
:
for
log_file
in
log_files_found
:
self
.
_appendLog
(
f
" - 训练日志: {log_file}
\n
"
)
self
.
_appendLog
(
f
" - 配置文件: config.yaml
\n
"
)
except
Exception
as
e
:
except
Exception
as
e
:
self
.
_appendLog
(
f
"[ERROR] 保存模型配置失败: {str(e)}
\n
"
)
self
.
_appendLog
(
f
"
❌
[ERROR] 保存模型配置失败: {str(e)}
\n
"
)
import
traceback
import
traceback
traceback
.
print_exc
()
traceback
.
print_exc
()
self
.
_appendLog
(
f
"
\n
完整错误信息:
\n
"
)
self
.
_appendLog
(
traceback
.
format_exc
())
def
_refreshModelTestPage
(
self
):
def
_refreshModelTestPage
(
self
):
"""刷新模型测试页面的模型列表"""
"""刷新模型测试页面的模型列表"""
...
@@ -2619,70 +2635,6 @@ class ModelTrainingHandler(ModelTestHandler):
...
@@ -2619,70 +2635,6 @@ class ModelTrainingHandler(ModelTestHandler):
import
traceback
import
traceback
traceback
.
print_exc
()
traceback
.
print_exc
()
def
_addTrainedModelToModelSet
(
self
,
converted_files
,
weights_dir
):
"""将训练好的模型添加到模型集配置"""
try
:
if
not
converted_files
:
print
(
"[信息] 没有转换的模型文件,跳过添加到模型集"
)
return
print
(
f
"[信息] 正在将 {len(converted_files)} 个训练模型添加到模型集..."
)
# 获取项目根目录
project_root
=
get_project_root
()
# 为每个转换的模型文件创建模型信息
for
dat_file
in
converted_files
:
try
:
# 获取模型文件信息
model_path
=
Path
(
dat_file
)
model_name
=
f
"训练模型-{model_path.parent.name}-{model_path.stem}"
# 检查模型是否已存在
if
hasattr
(
self
,
'main_window'
)
and
hasattr
(
self
.
main_window
,
'modelSetPage'
):
existing_models
=
self
.
main_window
.
modelSetPage
.
getAllModels
()
if
any
(
model_name
in
model
for
model
in
existing_models
):
print
(
f
"[信息] 模型 {model_name} 已存在,跳过添加"
)
continue
# 创建模型参数
model_params
=
{
'name'
:
model_name
,
'type'
:
'YOLOv11'
,
'path'
:
str
(
dat_file
),
'config_path'
:
''
,
'description'
:
f
'训练于 {weights_dir}'
if
weights_dir
else
'自动训练生成的模型'
,
'size'
:
self
.
_getFileSize
(
str
(
dat_file
)),
'classes'
:
3
,
# liquid, foam, air
'input'
:
'640x640'
,
'confidence'
:
0.5
,
'iou'
:
0.45
,
'device'
:
'CUDA:0 (GPU)'
,
'batch_size'
:
16
,
'blur_training'
:
100
,
'epochs'
:
300
,
'workers'
:
8
,
'model_type'
:
'训练模型'
,
'source'
:
'training'
}
# 添加到模型集页面
if
hasattr
(
self
,
'main_window'
)
and
hasattr
(
self
.
main_window
,
'modelSetPage'
):
self
.
main_window
.
modelSetPage
.
_model_params
[
model_name
]
=
model_params
self
.
main_window
.
modelSetPage
.
addModelToList
(
model_name
)
print
(
f
"[信息] 已添加模型: {model_name}"
)
except
Exception
as
model_error
:
print
(
f
"[错误] 添加模型 {dat_file} 失败: {model_error}"
)
continue
print
(
"[信息] 训练模型添加到模型集完成"
)
except
Exception
as
e
:
print
(
f
"[错误] 添加训练模型到模型集失败: {e}"
)
import
traceback
traceback
.
print_exc
()
def
_getFileSize
(
self
,
file_path
):
def
_getFileSize
(
self
,
file_path
):
"""获取文件大小(格式化字符串)"""
"""获取文件大小(格式化字符串)"""
try
:
try
:
...
...
handlers/modelpage/model_trainingworker_handler.py
View file @
21a67d1a
...
@@ -39,6 +39,19 @@ except (ImportError, ValueError):
...
@@ -39,6 +39,19 @@ except (ImportError, ValueError):
sys
.
path
.
insert
(
0
,
str
(
project_root
))
sys
.
path
.
insert
(
0
,
str
(
project_root
))
from
database.config
import
get_project_root
,
get_temp_models_dir
,
get_train_dir
from
database.config
import
get_project_root
,
get_temp_models_dir
,
get_train_dir
# 导入模型转换工具
try
:
from
.tools.convert_pt_to_dat
import
FileConverter
as
PtToDatConverter
except
(
ImportError
,
ValueError
):
try
:
from
handlers.modelpage.tools.convert_pt_to_dat
import
FileConverter
as
PtToDatConverter
except
ImportError
:
import
sys
from
pathlib
import
Path
project_root
=
Path
(
__file__
)
.
parent
.
parent
.
parent
sys
.
path
.
insert
(
0
,
str
(
project_root
))
from
handlers.modelpage.tools.convert_pt_to_dat
import
FileConverter
as
PtToDatConverter
MODEL_FILE_SIGNATURE
=
b
'LDS_MODEL_FILE'
MODEL_FILE_SIGNATURE
=
b
'LDS_MODEL_FILE'
MODEL_FILE_VERSION
=
1
MODEL_FILE_VERSION
=
1
MODEL_ENCRYPTION_KEY
=
"liquid_detection_system_2024"
MODEL_ENCRYPTION_KEY
=
"liquid_detection_system_2024"
...
@@ -693,6 +706,27 @@ class TrainingWorker(QThread):
...
@@ -693,6 +706,27 @@ class TrainingWorker(QThread):
if
device_str
.
lower
()
in
[
'cpu'
,
'-1'
]:
if
device_str
.
lower
()
in
[
'cpu'
,
'-1'
]:
workers
=
0
# CPU模式下禁用多线程数据加载
workers
=
0
# CPU模式下禁用多线程数据加载
# 获取下一个可用的模型ID并创建目录
project_root
=
get_project_root
()
train_model_dir
=
os
.
path
.
join
(
project_root
,
'database'
,
'model'
,
'train_model'
)
os
.
makedirs
(
train_model_dir
,
exist_ok
=
True
)
# 查找下一个可用的数字ID
existing_dirs
=
[]
for
item
in
os
.
listdir
(
train_model_dir
):
item_path
=
os
.
path
.
join
(
train_model_dir
,
item
)
if
os
.
path
.
isdir
(
item_path
)
and
item
.
isdigit
():
existing_dirs
.
append
(
int
(
item
))
next_id
=
max
(
existing_dirs
)
+
1
if
existing_dirs
else
1
model_output_dir
=
os
.
path
.
join
(
train_model_dir
,
str
(
next_id
))
# 预先设置weights_dir(YOLO会在project目录下创建weights子目录)
weights_dir
=
os
.
path
.
join
(
model_output_dir
,
"weights"
)
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"模型将直接保存到: {model_output_dir}
\n
"
)
# 开始训练
# 开始训练
try
:
try
:
mission_results
=
model
.
train
(
mission_results
=
model
.
train
(
...
@@ -705,8 +739,8 @@ class TrainingWorker(QThread):
...
@@ -705,8 +739,8 @@ class TrainingWorker(QThread):
optimizer
=
self
.
training_params
[
'optimizer'
],
optimizer
=
self
.
training_params
[
'optimizer'
],
close_mosaic
=
self
.
training_params
[
'close_mosaic'
],
close_mosaic
=
self
.
training_params
[
'close_mosaic'
],
resume
=
self
.
training_params
[
'resume'
],
resume
=
self
.
training_params
[
'resume'
],
project
=
'database/train/runs/train'
,
project
=
model_output_dir
,
name
=
self
.
training_params
[
'exp_name'
],
name
=
''
,
# 空名称,直接使用project目录
single_cls
=
self
.
training_params
[
'single_cls'
],
single_cls
=
self
.
training_params
[
'single_cls'
],
cache
=
False
,
cache
=
False
,
pretrained
=
self
.
training_params
[
'pretrained'
],
pretrained
=
self
.
training_params
[
'pretrained'
],
...
@@ -726,6 +760,50 @@ class TrainingWorker(QThread):
...
@@ -726,6 +760,50 @@ class TrainingWorker(QThread):
self
.
log_output
.
emit
(
"等待当前epoch完成并保存模型...
\n
"
)
self
.
log_output
.
emit
(
"等待当前epoch完成并保存模型...
\n
"
)
time
.
sleep
(
2
)
# 给YOLO时间完成保存
time
.
sleep
(
2
)
# 给YOLO时间完成保存
training_success
=
True
# 标记为成功,因为这是用户主动停止
training_success
=
True
# 标记为成功,因为这是用户主动停止
# 获取实际保存目录并转换模型
try
:
save_dir
=
getattr
(
getattr
(
model
,
"trainer"
,
None
),
"save_dir"
,
None
)
self
.
log_output
.
emit
(
f
"
\n
[调试] model.trainer 存在: {hasattr(model, 'trainer')}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] save_dir 值: {save_dir}
\n
"
)
if
save_dir
:
save_dir_abs
=
os
.
path
.
abspath
(
str
(
save_dir
))
weights_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
save_dir_abs
,
"weights"
))
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 实际保存目录: {save_dir_abs}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] 实际权重目录: {weights_dir}
\n
"
)
# 立即转换PT文件为DAT格式并删除PT文件
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
else
:
# 备用方案:使用预设的model_output_dir
self
.
log_output
.
emit
(
"
\n
[WARNING] 无法从trainer获取保存目录,使用预设目录
\n
"
)
if
'model_output_dir'
in
locals
():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs
=
[
os
.
path
.
join
(
model_output_dir
,
"train"
,
"weights"
),
os
.
path
.
join
(
model_output_dir
,
"weights"
)
]
for
possible_dir
in
possible_weights_dirs
:
if
os
.
path
.
exists
(
possible_dir
):
weights_dir
=
possible_dir
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 找到权重目录: {weights_dir}
\n
"
)
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
break
else
:
self
.
log_output
.
emit
(
f
"[ERROR] 未找到权重目录,跳过转换
\n
"
)
else
:
self
.
log_output
.
emit
(
f
"[ERROR] model_output_dir 未定义,跳过转换
\n
"
)
except
Exception
as
convert_err
:
self
.
log_output
.
emit
(
f
"
\n
[ERROR] 转换过程出错: {convert_err}
\n
"
)
import
traceback
self
.
log_output
.
emit
(
traceback
.
format_exc
())
break
# 跳出重试循环
break
# 跳出重试循环
except
Exception
as
e
:
except
Exception
as
e
:
...
@@ -750,6 +828,50 @@ class TrainingWorker(QThread):
...
@@ -750,6 +828,50 @@ class TrainingWorker(QThread):
self
.
log_output
.
emit
(
"等待当前epoch完成并保存模型...
\n
"
)
self
.
log_output
.
emit
(
"等待当前epoch完成并保存模型...
\n
"
)
time
.
sleep
(
2
)
# 给YOLO时间完成保存
time
.
sleep
(
2
)
# 给YOLO时间完成保存
training_success
=
True
training_success
=
True
# 获取实际保存目录并转换模型
try
:
save_dir
=
getattr
(
getattr
(
model
,
"trainer"
,
None
),
"save_dir"
,
None
)
self
.
log_output
.
emit
(
f
"
\n
[调试] model.trainer 存在: {hasattr(model, 'trainer')}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] save_dir 值: {save_dir}
\n
"
)
if
save_dir
:
save_dir_abs
=
os
.
path
.
abspath
(
str
(
save_dir
))
weights_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
save_dir_abs
,
"weights"
))
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 实际保存目录: {save_dir_abs}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] 实际权重目录: {weights_dir}
\n
"
)
# 立即转换PT文件为DAT格式并删除PT文件
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
else
:
# 备用方案:使用预设的model_output_dir
self
.
log_output
.
emit
(
"
\n
[WARNING] 无法从trainer获取保存目录,使用预设目录
\n
"
)
if
'model_output_dir'
in
locals
():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs
=
[
os
.
path
.
join
(
model_output_dir
,
"train"
,
"weights"
),
os
.
path
.
join
(
model_output_dir
,
"weights"
)
]
for
possible_dir
in
possible_weights_dirs
:
if
os
.
path
.
exists
(
possible_dir
):
weights_dir
=
possible_dir
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"[调试] 找到权重目录: {weights_dir}
\n
"
)
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
break
else
:
self
.
log_output
.
emit
(
f
"[ERROR] 未找到权重目录,跳过转换
\n
"
)
else
:
self
.
log_output
.
emit
(
f
"[ERROR] model_output_dir 未定义,跳过转换
\n
"
)
except
Exception
as
convert_err
:
self
.
log_output
.
emit
(
f
"
\n
[ERROR] 转换过程出错: {convert_err}
\n
"
)
import
traceback
self
.
log_output
.
emit
(
traceback
.
format_exc
())
break
break
# 训练成功
# 训练成功
...
@@ -763,11 +885,18 @@ class TrainingWorker(QThread):
...
@@ -763,11 +885,18 @@ class TrainingWorker(QThread):
weights_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
save_dir_abs
,
"weights"
))
weights_dir
=
os
.
path
.
abspath
(
os
.
path
.
join
(
save_dir_abs
,
"weights"
))
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
training_report
[
"weights_dir"
]
=
weights_dir
self
.
log_output
.
emit
(
f
"
\n
[调试] 实际保存目录: {save_dir_abs}
\n
"
)
self
.
log_output
.
emit
(
f
"[调试] 实际权重目录: {weights_dir}
\n
"
)
# 立即转换PT文件为DAT格式并删除PT文件
# 立即转换PT文件为DAT格式并删除PT文件
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
log_output
.
emit
(
"
\n
正在转换模型文件为DAT格式...
\n
"
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
self
.
_convertPtToDatAndCleanup
(
weights_dir
)
except
:
else
:
pass
self
.
log_output
.
emit
(
"
\n
[WARNING] 无法获取模型保存目录,跳过转换
\n
"
)
except
Exception
as
convert_err
:
self
.
log_output
.
emit
(
f
"
\n
[ERROR] 转换过程出错: {convert_err}
\n
"
)
import
traceback
self
.
log_output
.
emit
(
traceback
.
format_exc
())
break
# 跳出重试循环
break
# 跳出重试循环
except
RuntimeError
as
runtime_error
:
except
RuntimeError
as
runtime_error
:
...
@@ -1050,6 +1179,88 @@ class TrainingWorker(QThread):
...
@@ -1050,6 +1179,88 @@ class TrainingWorker(QThread):
except
Exception
as
e
:
except
Exception
as
e
:
pass
pass
def
_convertPtToDatAndCleanup
(
self
,
weights_dir
):
"""
转换PT文件为DAT格式并删除原始PT文件
Args:
weights_dir: 权重目录路径
"""
try
:
if
not
os
.
path
.
exists
(
weights_dir
):
self
.
log_output
.
emit
(
f
"[转换] 权重目录不存在: {weights_dir}
\n
"
)
return
self
.
log_output
.
emit
(
f
"[转换] 开始扫描权重目录: {weights_dir}
\n
"
)
# 创建转换器
converter
=
PtToDatConverter
(
key
=
MODEL_ENCRYPTION_KEY
)
# 查找所有.pt文件
pt_files
=
[]
for
filename
in
os
.
listdir
(
weights_dir
):
if
filename
.
endswith
(
'.pt'
):
pt_file_path
=
os
.
path
.
join
(
weights_dir
,
filename
)
pt_files
.
append
(
pt_file_path
)
if
not
pt_files
:
self
.
log_output
.
emit
(
f
"[转换] 未找到.pt文件
\n
"
)
return
self
.
log_output
.
emit
(
f
"[转换] 找到 {len(pt_files)} 个.pt文件
\n
"
)
# 转换每个.pt文件
converted_files
=
[]
for
pt_file
in
pt_files
:
try
:
filename
=
os
.
path
.
basename
(
pt_file
)
self
.
log_output
.
emit
(
f
"[转换] 正在转换: {filename}
\n
"
)
# 生成输出文件名(使用.dat扩展名)
exp_name
=
self
.
training_params
.
get
(
'exp_name'
,
''
)
base_name
=
os
.
path
.
splitext
(
filename
)[
0
]
if
exp_name
:
# 例如: best.pt -> best.template_1234.dat
output_filename
=
f
"{base_name}.{exp_name}.dat"
else
:
# 例如: best.pt -> best.dat
output_filename
=
f
"{base_name}.dat"
output_path
=
os
.
path
.
join
(
weights_dir
,
output_filename
)
self
.
log_output
.
emit
(
f
"[转换] 输入: {pt_file}
\n
"
)
self
.
log_output
.
emit
(
f
"[转换] 输出: {output_path}
\n
"
)
# 执行转换
converted_path
=
converter
.
convert_file
(
pt_file
,
output_path
)
converted_files
.
append
(
converted_path
)
self
.
log_output
.
emit
(
f
"[转换] ✓ 已转换: {output_filename}
\n
"
)
# 删除原始.pt文件
try
:
os
.
remove
(
pt_file
)
self
.
log_output
.
emit
(
f
"[转换] ✓ 已删除原始文件: {filename}
\n
"
)
except
Exception
as
del_error
:
self
.
log_output
.
emit
(
f
"[转换] ⚠ 删除原始文件失败: {filename} - {del_error}
\n
"
)
except
Exception
as
convert_error
:
self
.
log_output
.
emit
(
f
"[转换] ✗ 转换失败: {filename} - {convert_error}
\n
"
)
import
traceback
self
.
log_output
.
emit
(
f
"[转换] 错误详情:
\n
{traceback.format_exc()}
\n
"
)
continue
# 更新训练报告
self
.
training_report
[
"converted_dat_files"
]
=
converted_files
self
.
log_output
.
emit
(
f
"
\n
[转换] 完成!共转换 {len(converted_files)} 个文件
\n
"
)
except
Exception
as
e
:
self
.
log_output
.
emit
(
f
"[转换] 转换过程出错: {e}
\n
"
)
import
traceback
traceback
.
print_exc
()
def
stop_training
(
self
):
def
stop_training
(
self
):
"""停止训练"""
"""停止训练"""
self
.
is_running
=
False
self
.
is_running
=
False
widgets/modelpage/modelset_page.py
View file @
21a67d1a
...
@@ -104,11 +104,8 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -104,11 +104,8 @@ class ModelSetPage(QtWidgets.QWidget):
self
.
_connectSignals
()
self
.
_connectSignals
()
self
.
_setupShortcuts
()
self
.
_setupShortcuts
()
# 移除自动加载,改为手动加载模式
# 标记是否已加载过模型
# 用户可以通过以下方式加载模型:
self
.
_models_loaded
=
False
# 1. 按 F5 刷新
# 2. 右键点击空白区域刷新
# 3. 点击刷新按钮(如果有)
def
_initUI
(
self
):
def
_initUI
(
self
):
"""初始化UI - 简约labelme风格(三栏充实布局版)"""
"""初始化UI - 简约labelme风格(三栏充实布局版)"""
...
@@ -765,12 +762,14 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -765,12 +762,14 @@ class ModelSetPage(QtWidgets.QWidget):
# 重新添加所有模型
# 重新添加所有模型
for
model_name
in
self
.
_model_params
.
keys
():
for
model_name
in
self
.
_model_params
.
keys
():
self
.
model_set_list
.
addItem
(
model_name
)
# 构建显示文本
display_text
=
model_name
# 如果是默认模型,添加标记
# 如果是默认模型,添加标记
(使用统一格式)
if
model_name
==
self
.
_current_default_model
:
if
model_name
==
self
.
_current_default_model
:
item
=
self
.
model_set_list
.
item
(
self
.
model_set_list
.
count
()
-
1
)
display_text
=
f
"{model_name}(默认)"
item
.
setText
(
f
"{model_name} [默认]"
)
self
.
model_set_list
.
addItem
(
display_text
)
# 更新统计信息
# 更新统计信息
self
.
_updateStats
()
self
.
_updateStats
()
...
@@ -1088,6 +1087,15 @@ class ModelSetPage(QtWidgets.QWidget):
...
@@ -1088,6 +1087,15 @@ class ModelSetPage(QtWidgets.QWidget):
import
traceback
import
traceback
traceback
.
print_exc
()
traceback
.
print_exc
()
def
showEvent
(
self
,
event
):
"""页面显示时自动加载模型(仅首次)"""
super
(
ModelSetPage
,
self
)
.
showEvent
(
event
)
# 只在首次显示时加载模型
if
not
self
.
_models_loaded
:
self
.
loadModelsFromConfig
()
self
.
_models_loaded
=
True
def
loadModelsFromConfig
(
self
):
def
loadModelsFromConfig
(
self
):
"""从配置文件和模型目录加载所有模型(委托给 handler)"""
"""从配置文件和模型目录加载所有模型(委托给 handler)"""
try
:
try
:
...
...
widgets/modelpage/training_page.py
View file @
21a67d1a
...
@@ -1008,39 +1008,112 @@ class TrainingPage(QtWidgets.QWidget):
...
@@ -1008,39 +1008,112 @@ class TrainingPage(QtWidgets.QWidget):
return
models
return
models
def
_scanDetectionModelDirectory
(
self
,
project_root
):
def
_scanDetectionModelDirectory
(
self
,
project_root
):
"""扫描 train_model 目录获取所有测试模型文件"""
"""扫描 train_model 目录获取所有测试模型文件
(增强版:按优先级选择模型)
"""
models
=
[]
models
=
[]
try
:
try
:
import
yaml
# 🔥 修改:从 train_model 文件夹读取
# 🔥 修改:从 train_model 文件夹读取
model_dir
=
Path
(
project_root
)
/
"database"
/
"model"
/
"train_model"
model_dir
=
Path
(
project_root
)
/
"database"
/
"model"
/
"train_model"
if
not
model_dir
.
exists
():
if
not
model_dir
.
exists
():
return
models
return
models
# 遍历所有子目录
# 按数字排序子目录(降序,最新的在前)
for
subdir
in
sorted
(
model_dir
.
iterdir
()):
sorted_subdirs
=
sorted
(
if
subdir
.
is_dir
():
[
d
for
d
in
model_dir
.
iterdir
()
if
d
.
is_dir
()
and
d
.
name
.
isdigit
()],
# 优先查找 .dat 文件
key
=
lambda
x
:
int
(
x
.
name
),
for
model_file
in
sorted
(
subdir
.
glob
(
"*.dat"
)):
reverse
=
True
models
.
append
({
)
'name'
:
f
"{subdir.name}-{model_file.stem}"
,
'path'
:
str
(
model_file
),
for
subdir
in
sorted_subdirs
:
'source'
:
'train_model'
,
# 尝试读取 config.yaml 获取详细信息
'format'
:
'dat'
config_file
=
subdir
/
"config.yaml"
})
model_config
=
None
if
config_file
.
exists
():
try
:
with
open
(
config_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
model_config
=
yaml
.
safe_load
(
f
)
except
Exception
as
e
:
pass
# 检查是否有weights子目录(优先检查train/weights,然后weights)
train_weights_dir
=
subdir
/
"train"
/
"weights"
weights_dir
=
subdir
/
"weights"
if
train_weights_dir
.
exists
():
search_dir
=
train_weights_dir
elif
weights_dir
.
exists
():
search_dir
=
weights_dir
else
:
search_dir
=
subdir
# 按优先级查找模型文件:best > last > epoch1
# 支持的扩展名:.dat, .pt, .template_*, 无扩展名
selected_model
=
None
# 优先级1: best模型
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'best.'
)
and
not
file
.
name
.
endswith
(
'.pt'
):
selected_model
=
file
break
# 优先级2: last模型(如果没有best)
if
not
selected_model
:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'last.'
)
and
not
file
.
name
.
endswith
(
'.pt'
):
selected_model
=
file
break
# 优先级3: epoch1模型(如果没有best和last)
if
not
selected_model
:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
file
.
name
.
startswith
(
'epoch1.'
)
and
not
file
.
name
.
endswith
(
'.pt'
):
selected_model
=
file
break
# 如果都没找到,尝试查找任何非.pt文件
if
not
selected_model
:
for
file
in
search_dir
.
iterdir
():
if
file
.
is_file
()
and
not
file
.
name
.
endswith
(
'.pt'
)
and
not
file
.
name
.
endswith
(
'.txt'
)
and
not
file
.
name
.
endswith
(
'.yaml'
):
selected_model
=
file
break
# 如果找到了模型文件,添加到列表
if
selected_model
:
# 从 config.yaml 获取信息,或使用默认值
if
model_config
:
model_name
=
model_config
.
get
(
'name'
,
f
"训练模型-{subdir.name}"
)
description
=
model_config
.
get
(
'description'
,
''
)
training_date
=
model_config
.
get
(
'training_date'
,
''
)
else
:
model_name
=
selected_model
.
stem
# 使用文件名(不含扩展名)作为模型名
description
=
f
"来自目录 {subdir.name}"
training_date
=
''
# 获取文件格式
file_ext
=
selected_model
.
suffix
.
lstrip
(
'.'
)
if
not
file_ext
:
# 处理无扩展名的情况(如 best.template_6543)
if
'.'
in
selected_model
.
name
:
file_ext
=
selected_model
.
name
.
split
(
'.'
)[
-
1
]
else
:
file_ext
=
'unknown'
# 然后查找 .pt 文件
for
model_file
in
sorted
(
subdir
.
glob
(
"*.pt"
)):
models
.
append
({
models
.
append
({
'name'
:
f
"{subdir.name}-{model_file.stem}"
,
'name'
:
model_name
,
'path'
:
str
(
model_file
),
'path'
:
str
(
selected_model
),
'source'
:
'train_model'
,
'source'
:
'train_model'
,
'format'
:
'pt'
'format'
:
file_ext
,
'description'
:
description
,
'training_date'
:
training_date
,
'file_name'
:
selected_model
.
name
})
})
except
Exception
as
e
:
except
Exception
as
e
:
pass
import
traceback
traceback
.
print_exc
()
return
models
return
models
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment