Commit 21a67d1a by Yuhaibo

1

parent e491a327
......@@ -104,11 +104,8 @@ class ModelSetPage(QtWidgets.QWidget):
self._connectSignals()
self._setupShortcuts()
# 移除自动加载,改为手动加载模式
# 用户可以通过以下方式加载模型:
# 1. 按 F5 刷新
# 2. 右键点击空白区域刷新
# 3. 点击刷新按钮(如果有)
# 标记是否已加载过模型
self._models_loaded = False
def _initUI(self):
"""初始化UI - 简约labelme风格(三栏充实布局版)"""
......@@ -765,12 +762,14 @@ class ModelSetPage(QtWidgets.QWidget):
# 重新添加所有模型
for model_name in self._model_params.keys():
self.model_set_list.addItem(model_name)
# 构建显示文本
display_text = model_name
# 如果是默认模型,添加标记
# 如果是默认模型,添加标记(使用统一格式)
if model_name == self._current_default_model:
item = self.model_set_list.item(self.model_set_list.count() - 1)
item.setText(f"{model_name} [默认]")
display_text = f"{model_name}(默认)"
self.model_set_list.addItem(display_text)
# 更新统计信息
self._updateStats()
......@@ -1088,6 +1087,15 @@ class ModelSetPage(QtWidgets.QWidget):
import traceback
traceback.print_exc()
def showEvent(self, event):
"""页面显示时自动加载模型(仅首次)"""
super(ModelSetPage, self).showEvent(event)
# 只在首次显示时加载模型
if not self._models_loaded:
self.loadModelsFromConfig()
self._models_loaded = True
def loadModelsFromConfig(self):
"""从配置文件和模型目录加载所有模型(委托给 handler)"""
try:
......
......@@ -1008,39 +1008,112 @@ class TrainingPage(QtWidgets.QWidget):
return models
def _scanDetectionModelDirectory(self, project_root):
"""扫描 train_model 目录获取所有测试模型文件"""
"""扫描 train_model 目录获取所有测试模型文件(增强版:按优先级选择模型)"""
models = []
try:
import yaml
# 🔥 修改:从 train_model 文件夹读取
model_dir = Path(project_root) / "database" / "model" / "train_model"
if not model_dir.exists():
return models
# 遍历所有子目录
for subdir in sorted(model_dir.iterdir()):
if subdir.is_dir():
# 优先查找 .dat 文件
for model_file in sorted(subdir.glob("*.dat")):
models.append({
'name': f"{subdir.name}-{model_file.stem}",
'path': str(model_file),
'source': 'train_model',
'format': 'dat'
})
# 按数字排序子目录(降序,最新的在前)
sorted_subdirs = sorted(
[d for d in model_dir.iterdir() if d.is_dir() and d.name.isdigit()],
key=lambda x: int(x.name),
reverse=True
)
for subdir in sorted_subdirs:
# 尝试读取 config.yaml 获取详细信息
config_file = subdir / "config.yaml"
model_config = None
if config_file.exists():
try:
with open(config_file, 'r', encoding='utf-8') as f:
model_config = yaml.safe_load(f)
except Exception as e:
pass
# 检查是否有weights子目录(优先检查train/weights,然后weights)
train_weights_dir = subdir / "train" / "weights"
weights_dir = subdir / "weights"
if train_weights_dir.exists():
search_dir = train_weights_dir
elif weights_dir.exists():
search_dir = weights_dir
else:
search_dir = subdir
# 按优先级查找模型文件:best > last > epoch1
# 支持的扩展名:.dat, .pt, .template_*, 无扩展名
selected_model = None
# 优先级1: best模型
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('best.') and not file.name.endswith('.pt'):
selected_model = file
break
# 优先级2: last模型(如果没有best)
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('last.') and not file.name.endswith('.pt'):
selected_model = file
break
# 优先级3: epoch1模型(如果没有best和last)
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('epoch1.') and not file.name.endswith('.pt'):
selected_model = file
break
# 如果都没找到,尝试查找任何非.pt文件
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and not file.name.endswith('.pt') and not file.name.endswith('.txt') and not file.name.endswith('.yaml'):
selected_model = file
break
# 如果找到了模型文件,添加到列表
if selected_model:
# 从 config.yaml 获取信息,或使用默认值
if model_config:
model_name = model_config.get('name', f"训练模型-{subdir.name}")
description = model_config.get('description', '')
training_date = model_config.get('training_date', '')
else:
model_name = selected_model.stem # 使用文件名(不含扩展名)作为模型名
description = f"来自目录 {subdir.name}"
training_date = ''
# 然后查找 .pt 文件
for model_file in sorted(subdir.glob("*.pt")):
models.append({
'name': f"{subdir.name}-{model_file.stem}",
'path': str(model_file),
'source': 'train_model',
'format': 'pt'
})
# 获取文件格式
file_ext = selected_model.suffix.lstrip('.')
if not file_ext:
# 处理无扩展名的情况(如 best.template_6543)
if '.' in selected_model.name:
file_ext = selected_model.name.split('.')[-1]
else:
file_ext = 'unknown'
models.append({
'name': model_name,
'path': str(selected_model),
'source': 'train_model',
'format': file_ext,
'description': description,
'training_date': training_date,
'file_name': selected_model.name
})
except Exception as e:
pass
import traceback
traceback.print_exc()
return models
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment