Commit 21a67d1a by Yuhaibo

1

parent e491a327
......@@ -301,11 +301,9 @@ class ModelSetHandler:
if reply == QtWidgets.QMessageBox.Yes:
# 调用 ModelSetPage 的删除方法
# deleteModel会触发deleteModelDataRequested信号,由deleteModelData处理
# deleteModelData会自动刷新训练页面
self.modelSetPage.deleteModel(model_name)
# 更新测试页面的模型列表
if hasattr(self, 'testModelPage'):
self.testModelPage.loadModelsFromModelSetPage(self.modelSetPage)
else:
QtWidgets.QMessageBox.warning(self, "错误", "模型集页面未初始化")
......@@ -513,7 +511,7 @@ class ModelSetHandler:
return models
def _scanModelDirectory(self):
"""扫描模型目录获取所有模型文件"""
"""扫描模型目录获取所有模型文件(增强版:支持config.yaml和按优先级选择模型)"""
models = []
try:
......@@ -521,45 +519,140 @@ class ModelSetHandler:
current_dir = Path(__file__).parent.parent.parent
model_dir = current_dir / "database" / "model" / "train_model"
print(f"[模型扫描] 扫描目录: {model_dir}")
if not model_dir.exists():
print(f"[模型扫描] 目录不存在")
return models
all_items = list(model_dir.iterdir())
print(f"[模型扫描] 目录存在: {model_dir.exists()}")
# 按数字排序子目录(降序,最新的在前)
all_subdirs = [d for d in model_dir.iterdir() if d.is_dir()]
print(f"[模型扫描] 找到子目录数量: {len(all_subdirs)}")
print(f"[模型扫描] 子目录列表: {[d.name for d in all_subdirs]}")
digit_subdirs = [d for d in all_subdirs if d.name.isdigit()]
print(f"[模型扫描] 数字子目录数量: {len(digit_subdirs)}")
sorted_subdirs = sorted(model_dir.iterdir(), key=lambda x: x.name if x.is_dir() else '')
sorted_subdirs = sorted(digit_subdirs, key=lambda x: int(x.name), reverse=True)
print(f"[模型扫描] 数字子目录: {[d.name for d in sorted_subdirs]}")
for subdir in sorted_subdirs:
if subdir.is_dir():
subdir_files = list(subdir.iterdir())
dat_files = list(subdir.glob("*.dat"))
for model_file in sorted(dat_files):
model_info = {
'name': f"模型-{subdir.name}-{model_file.stem}",
'path': str(model_file),
'subdir': subdir.name,
'source': 'scan',
'format': 'dat'
}
models.append(model_info)
print(f"[模型扫描] 处理子目录: {subdir.name}")
# 尝试读取 config.yaml 获取详细信息
config_file = subdir / "config.yaml"
model_config = None
if config_file.exists():
try:
with open(config_file, 'r', encoding='utf-8') as f:
model_config = yaml.safe_load(f)
except Exception as e:
print(f"[模型扫描] 读取config.yaml失败: {e}")
# 检查是否有weights子目录(优先检查train/weights,然后weights)
train_weights_dir = subdir / "train" / "weights"
weights_dir = subdir / "weights"
if train_weights_dir.exists():
search_dir = train_weights_dir
print(f"[模型扫描] 找到train/weights目录: {search_dir}")
elif weights_dir.exists():
search_dir = weights_dir
print(f"[模型扫描] 找到weights目录: {search_dir}")
else:
search_dir = subdir
print(f"[模型扫描] 使用根目录: {search_dir}")
print(f"[模型扫描] 搜索目录: {search_dir}")
# 按优先级查找模型文件:best > last > epoch1
# 支持的扩展名:.dat, .pt, .template_*, 无扩展名
selected_model = None
# 优先级1: best模型
for pattern in ['best.*.dat', 'best.*.pt', 'best.template_*', 'best.*']:
if selected_model:
break
for file in search_dir.iterdir():
if file.is_file():
# 检查文件名是否匹配模式
if file.name.startswith('best.') and not file.name.endswith('.pt'):
selected_model = file
print(f"[模型扫描] 找到best模型: {file.name}")
break
# 优先级2: last模型(如果没有best)
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('last.') and not file.name.endswith('.pt'):
selected_model = file
print(f"[模型扫描] 找到last模型: {file.name}")
break
# 优先级3: epoch1模型(如果没有best和last)
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('epoch1.') and not file.name.endswith('.pt'):
selected_model = file
print(f"[模型扫描] 找到epoch1模型: {file.name}")
break
# 如果都没找到,尝试查找任何非.pt文件
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and not file.name.endswith('.pt') and not file.name.endswith('.txt') and not file.name.endswith('.yaml'):
selected_model = file
print(f"[模型扫描] 找到其他模型: {file.name}")
break
# 如果找到了模型文件,添加到列表
if selected_model:
# 从 config.yaml 获取信息,或使用默认值
if model_config:
model_name = model_config.get('name', f"训练模型-{subdir.name}")
description = model_config.get('description', '')
training_date = model_config.get('training_date', '')
epochs = model_config.get('epochs', '')
else:
model_name = selected_model.stem # 使用文件名(不含扩展名)作为模型名
description = f"来自目录 {subdir.name}"
training_date = ''
epochs = ''
pt_files = list(subdir.glob("*.pt"))
# 获取文件格式
file_ext = selected_model.suffix.lstrip('.')
if not file_ext:
# 处理无扩展名的情况(如 best.template_6543)
if '.' in selected_model.name:
file_ext = selected_model.name.split('.')[-1]
else:
file_ext = 'unknown'
for model_file in sorted(pt_files):
model_info = {
'name': f"模型-{subdir.name}-{model_file.stem}",
'path': str(model_file),
'subdir': subdir.name,
'source': 'scan',
'format': 'pt'
}
models.append(model_info)
model_info = {
'name': model_name,
'path': str(selected_model),
'subdir': subdir.name,
'source': 'train_model',
'format': file_ext,
'description': description,
'training_date': training_date,
'epochs': epochs,
'file_name': selected_model.name
}
models.append(model_info)
print(f"[模型扫描] 添加模型: {model_name} ({selected_model.name})")
else:
print(f"[模型扫描] 子目录 {subdir.name} 中未找到有效模型")
except Exception as e:
import traceback
traceback.print_exc()
print(f"[模型扫描] 扫描异常: {e}")
print(f"[模型扫描] 总共找到 {len(models)} 个模型")
return models
def _mergeModelInfo(self, channel_models, scanned_models):
......@@ -581,9 +674,19 @@ class ModelSetHandler:
all_models.append(model)
seen_paths.add(path)
# 确保有一个默认模型
has_default = any(model.get('is_default', False) for model in all_models)
if not has_default and len(all_models) > 0:
# 确保只有一个默认模型
default_model_index = -1
for i, model in enumerate(all_models):
if model.get('is_default', False):
if default_model_index == -1:
# 第一个默认模型,保留
default_model_index = i
else:
# 已经有默认模型了,取消这个模型的默认标记
model['is_default'] = False
# 如果没有默认模型,将第一个模型设为默认
if default_model_index == -1 and len(all_models) > 0:
all_models[0]['is_default'] = True
return all_models
......@@ -803,18 +906,66 @@ class ModelSetHandler:
# 从参数中删除
if model_name in self.modelSetPage._model_params:
model_params = self.modelSetPage._model_params[model_name]
model_path = model_params.get('path', '')
# 删除模型文件和所在目录
if model_path and os.path.exists(model_path):
try:
import shutil
# 获取模型所在的目录(train_model/{数字ID}/)
model_dir = os.path.dirname(model_path)
# 检查是否是train_model目录下的子目录
if 'train_model' in model_dir:
# 删除整个模型目录
if os.path.exists(model_dir):
shutil.rmtree(model_dir)
print(f"[删除模型] 已删除模型目录: {model_dir}")
except Exception as delete_error:
print(f"[删除模型] 删除模型文件失败: {delete_error}")
del self.modelSetPage._model_params[model_name]
# 从配置文件中删除模型配置
self._removeModelFromConfig(model_name)
# 刷新训练页面的模型测试下拉框
self._refreshTrainingPageModelList()
return True
return False
except Exception as e:
print(f"[删除模型] 删除失败: {e}")
import traceback
traceback.print_exc()
return False
def _refreshTrainingPageModelList(self):
"""刷新训练页面的模型测试下拉框"""
try:
# 通过主窗口访问训练页面
if hasattr(self, 'main_window'):
# 尝试多种可能的属性名
training_page = None
if hasattr(self.main_window, 'trainingPage'):
training_page = self.main_window.trainingPage
elif hasattr(self.main_window, 'training_page'):
training_page = self.main_window.training_page
elif hasattr(self.main_window, 'testModelPage'):
training_page = self.main_window.testModelPage
if training_page and hasattr(training_page, '_loadTestModelOptions'):
training_page._loadTestModelOptions()
print("[删除模型] 已刷新训练页面的模型列表")
else:
print("[删除模型] 未找到训练页面或_loadTestModelOptions方法")
except Exception as e:
print(f"[删除模型] 刷新训练页面失败: {e}")
def getAllModelParams(self):
"""获取所有模型参数"""
if hasattr(self, 'modelSetPage') and self.modelSetPage:
......
......@@ -394,38 +394,27 @@ class ModelTrainingHandler(ModelTestHandler):
self._appendLog(" 模型升级成功完成!\n")
self._appendLog("="*70 + "\n")
# 升级完成后进行pt转dat转换
# 获取weights目录和转换结果
weights_dir = None
converted_files = []
if self.training_worker and hasattr(self.training_worker, "training_report"):
weights_dir = self.training_worker.training_report.get("weights_dir")
converted_files = self.training_worker.training_report.get("converted_dat_files", [])
converted_files = []
if self.current_exp_name or weights_dir:
self._appendLog("\n 正在将升级结果转换为dat格式...\n")
try:
# 使用新的立即转换和清理方法
if weights_dir:
converted_files = self._convertPtToDatAndCleanup(weights_dir)
else:
converted_files = self._convertTrainingmission_resultsToDat(self.current_exp_name, weights_dir)
if converted_files:
self._appendLog(f" 已将训练结果转换为dat格式: {len(converted_files)}个文件\n")
for f in converted_files:
self._appendLog(f" - {f}\n")
else:
self._appendLog("转换完成但没有文件被转换\n")
except Exception as convert_error:
self._appendLog(f"[WARNING] 模型转换失败: {str(convert_error)}\n")
import traceback
traceback.print_exc()
# 保存训练日志到weights目录
# 显示转换结果(转换已在TrainingWorker中完成)
if converted_files:
self._appendLog(f"\n 模型已转换为dat格式: {len(converted_files)}个文件\n")
for f in converted_files:
self._appendLog(f" - {os.path.basename(f)}\n")
# 保存训练日志到weights目录
if weights_dir:
self._appendLog("\n 正在保存训练日志...\n")
log_file = self._saveTrainingLogToWeightsDir(self.current_exp_name, weights_dir)
if log_file:
self._appendLog(f" 训练日志已保存: {os.path.basename(log_file)}\n")
else:
self._appendLog(" [WARNING] 训练日志保存失败\n")
# 将转换结果写入 training_report.json
try:
......@@ -508,28 +497,27 @@ class ModelTrainingHandler(ModelTestHandler):
self._appendLog("训练已暂停\n")
self._appendLog("="*70 + "\n")
# 转换已产生的模型为 dat 格式
# 获取weights目录和转换结果(转换已在TrainingWorker中完成)
converted_files = []
weights_dir = None
if self.current_exp_name:
train_root = get_train_dir()
weights_dir = os.path.join(train_root, "runs", "train", self.current_exp_name, "weights")
# 转换模型文件 - 使用新的立即转换和清理方法
self._appendLog("\n正在转换已产生的模型为dat格式...\n")
converted_files = self._convertPtToDatAndCleanup(weights_dir)
if converted_files:
self._appendLog(f"成功转换 {len(converted_files)} 个模型到.dat格式\n")
for f in converted_files:
self._appendLog(f" - {f}\n")
else:
self._appendLog("[WARNING] 未找到可转换的模型文件\n")
# 保存训练日志到weights目录
if self.training_worker and hasattr(self.training_worker, "training_report"):
weights_dir = self.training_worker.training_report.get("weights_dir")
converted_files = self.training_worker.training_report.get("converted_dat_files", [])
# 显示转换结果
if converted_files:
self._appendLog(f"\n模型已转换为dat格式: {len(converted_files)}个文件\n")
for f in converted_files:
self._appendLog(f" - {os.path.basename(f)}\n")
# 保存训练日志到weights目录
if weights_dir:
self._appendLog("\n正在保存训练日志...\n")
log_file = self._saveTrainingLogToWeightsDir(self.current_exp_name, weights_dir)
if log_file:
self._appendLog(f"训练日志已保存: {os.path.basename(log_file)}\n")
else:
self._appendLog("[WARNING] 训练日志保存失败\n")
# 用户停止训练时也要同步模型到模型集
self._appendLog("\n正在将模型同步到模型集管理...\n")
......@@ -2208,51 +2196,79 @@ class ModelTrainingHandler(ModelTestHandler):
traceback.print_exc()
def _saveModelToConfig(self, model_name, model_params):
"""保存模型到配置文件"""
"""保存模型配置文件(模型已经在train_model目录中)"""
try:
from pathlib import Path
import yaml
# 获取项目根目录
project_root = get_project_root()
self._appendLog(f"\n开始保存模型配置...\n")
self._appendLog(f" 模型名称: {model_name}\n")
self._appendLog(f" 模型路径: {model_params['path']}\n")
# 模型目录结构:database/model/train_model/{model_name}/
train_model_dir = os.path.join(project_root, 'database', 'model', 'train_model')
os.makedirs(train_model_dir, exist_ok=True)
# 获取模型所在目录(应该已经在train_model/{数字ID}/weights/中)
model_path = model_params['path']
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件不存在: {model_path}")
# 查找下一个可用的数字目录
existing_dirs = []
for item in os.listdir(train_model_dir):
item_path = os.path.join(train_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
existing_dirs.append(int(item))
# 获取weights目录
weights_dir = os.path.dirname(model_path)
# 获取模型ID目录(weights的父目录)
model_dir = os.path.dirname(weights_dir)
next_id = max(existing_dirs) + 1 if existing_dirs else 1
model_dir = os.path.join(train_model_dir, str(next_id))
os.makedirs(model_dir, exist_ok=True)
self._appendLog(f" 模型目录: {model_dir}\n")
# 复制模型文件到新目录,保持原有的命名规则
source_model_path = model_params['path']
source_filename = os.path.basename(source_model_path)
dest_model_path = os.path.join(model_dir, source_filename)
import shutil
shutil.copy2(source_model_path, dest_model_path)
# 更新模型参数中的路径
model_params['path'] = dest_model_path
# 查找可用的模型文件
model_files_found = []
# 1. 查找best模型
for filename in os.listdir(weights_dir):
if filename.startswith('best.') and not filename.endswith('.pt'):
model_files_found.append(f"best模型: {filename}")
# 更新路径指向best模型
if 'best' in filename:
model_params['path'] = os.path.join(weights_dir, filename)
break
# 2. 查找last模型
for filename in os.listdir(weights_dir):
if filename.startswith('last.') and not filename.endswith('.pt'):
model_files_found.append(f"last模型: {filename}")
break
# 3. 查找epoch1模型
for filename in os.listdir(weights_dir):
if filename.startswith('epoch1.') and not filename.endswith('.pt'):
model_files_found.append(f"第一轮模型: {filename}")
break
# 4. 查找训练日志
log_files_found = []
for filename in os.listdir(weights_dir):
if filename.startswith('training_log_') and filename.endswith('.txt'):
log_files_found.append(filename)
# 保存模型配置到 YAML 文件
config_file = os.path.join(model_dir, 'config.yaml')
self._appendLog(f" 保存配置文件: {config_file}\n")
with open(config_file, 'w', encoding='utf-8') as f:
yaml.dump(model_params, f, allow_unicode=True, default_flow_style=False)
self._appendLog(f" 模型已保存到: {model_dir}\n")
# 输出总结信息
self._appendLog(f"\n✅ 模型配置已保存: {model_dir}\n")
self._appendLog(f" 找到的文件:\n")
for info in model_files_found:
self._appendLog(f" - {info}\n")
if log_files_found:
for log_file in log_files_found:
self._appendLog(f" - 训练日志: {log_file}\n")
self._appendLog(f" - 配置文件: config.yaml\n")
except Exception as e:
self._appendLog(f"[ERROR] 保存模型配置失败: {str(e)}\n")
self._appendLog(f"[ERROR] 保存模型配置失败: {str(e)}\n")
import traceback
traceback.print_exc()
self._appendLog(f"\n完整错误信息:\n")
self._appendLog(traceback.format_exc())
def _refreshModelTestPage(self):
"""刷新模型测试页面的模型列表"""
......@@ -2619,70 +2635,6 @@ class ModelTrainingHandler(ModelTestHandler):
import traceback
traceback.print_exc()
def _addTrainedModelToModelSet(self, converted_files, weights_dir):
"""将训练好的模型添加到模型集配置"""
try:
if not converted_files:
print("[信息] 没有转换的模型文件,跳过添加到模型集")
return
print(f"[信息] 正在将 {len(converted_files)} 个训练模型添加到模型集...")
# 获取项目根目录
project_root = get_project_root()
# 为每个转换的模型文件创建模型信息
for dat_file in converted_files:
try:
# 获取模型文件信息
model_path = Path(dat_file)
model_name = f"训练模型-{model_path.parent.name}-{model_path.stem}"
# 检查模型是否已存在
if hasattr(self, 'main_window') and hasattr(self.main_window, 'modelSetPage'):
existing_models = self.main_window.modelSetPage.getAllModels()
if any(model_name in model for model in existing_models):
print(f"[信息] 模型 {model_name} 已存在,跳过添加")
continue
# 创建模型参数
model_params = {
'name': model_name,
'type': 'YOLOv11',
'path': str(dat_file),
'config_path': '',
'description': f'训练于 {weights_dir}' if weights_dir else '自动训练生成的模型',
'size': self._getFileSize(str(dat_file)),
'classes': 3, # liquid, foam, air
'input': '640x640',
'confidence': 0.5,
'iou': 0.45,
'device': 'CUDA:0 (GPU)',
'batch_size': 16,
'blur_training': 100,
'epochs': 300,
'workers': 8,
'model_type': '训练模型',
'source': 'training'
}
# 添加到模型集页面
if hasattr(self, 'main_window') and hasattr(self.main_window, 'modelSetPage'):
self.main_window.modelSetPage._model_params[model_name] = model_params
self.main_window.modelSetPage.addModelToList(model_name)
print(f"[信息] 已添加模型: {model_name}")
except Exception as model_error:
print(f"[错误] 添加模型 {dat_file} 失败: {model_error}")
continue
print("[信息] 训练模型添加到模型集完成")
except Exception as e:
print(f"[错误] 添加训练模型到模型集失败: {e}")
import traceback
traceback.print_exc()
def _getFileSize(self, file_path):
"""获取文件大小(格式化字符串)"""
try:
......
......@@ -39,6 +39,19 @@ except (ImportError, ValueError):
sys.path.insert(0, str(project_root))
from database.config import get_project_root, get_temp_models_dir, get_train_dir
# 导入模型转换工具
try:
from .tools.convert_pt_to_dat import FileConverter as PtToDatConverter
except (ImportError, ValueError):
try:
from handlers.modelpage.tools.convert_pt_to_dat import FileConverter as PtToDatConverter
except ImportError:
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from handlers.modelpage.tools.convert_pt_to_dat import FileConverter as PtToDatConverter
MODEL_FILE_SIGNATURE = b'LDS_MODEL_FILE'
MODEL_FILE_VERSION = 1
MODEL_ENCRYPTION_KEY = "liquid_detection_system_2024"
......@@ -693,6 +706,27 @@ class TrainingWorker(QThread):
if device_str.lower() in ['cpu', '-1']:
workers = 0 # CPU模式下禁用多线程数据加载
# 获取下一个可用的模型ID并创建目录
project_root = get_project_root()
train_model_dir = os.path.join(project_root, 'database', 'model', 'train_model')
os.makedirs(train_model_dir, exist_ok=True)
# 查找下一个可用的数字ID
existing_dirs = []
for item in os.listdir(train_model_dir):
item_path = os.path.join(train_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
existing_dirs.append(int(item))
next_id = max(existing_dirs) + 1 if existing_dirs else 1
model_output_dir = os.path.join(train_model_dir, str(next_id))
# 预先设置weights_dir(YOLO会在project目录下创建weights子目录)
weights_dir = os.path.join(model_output_dir, "weights")
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"模型将直接保存到: {model_output_dir}\n")
# 开始训练
try:
mission_results = model.train(
......@@ -705,8 +739,8 @@ class TrainingWorker(QThread):
optimizer=self.training_params['optimizer'],
close_mosaic=self.training_params['close_mosaic'],
resume=self.training_params['resume'],
project='database/train/runs/train',
name=self.training_params['exp_name'],
project=model_output_dir,
name='', # 空名称,直接使用project目录
single_cls=self.training_params['single_cls'],
cache=False,
pretrained=self.training_params['pretrained'],
......@@ -726,6 +760,50 @@ class TrainingWorker(QThread):
self.log_output.emit("等待当前epoch完成并保存模型...\n")
time.sleep(2) # 给YOLO时间完成保存
training_success = True # 标记为成功,因为这是用户主动停止
# 获取实际保存目录并转换模型
try:
save_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
self.log_output.emit(f"\n[调试] model.trainer 存在: {hasattr(model, 'trainer')}\n")
self.log_output.emit(f"[调试] save_dir 值: {save_dir}\n")
if save_dir:
save_dir_abs = os.path.abspath(str(save_dir))
weights_dir = os.path.abspath(os.path.join(save_dir_abs, "weights"))
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 实际保存目录: {save_dir_abs}\n")
self.log_output.emit(f"[调试] 实际权重目录: {weights_dir}\n")
# 立即转换PT文件为DAT格式并删除PT文件
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
else:
# 备用方案:使用预设的model_output_dir
self.log_output.emit("\n[WARNING] 无法从trainer获取保存目录,使用预设目录\n")
if 'model_output_dir' in locals():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs = [
os.path.join(model_output_dir, "train", "weights"),
os.path.join(model_output_dir, "weights")
]
for possible_dir in possible_weights_dirs:
if os.path.exists(possible_dir):
weights_dir = possible_dir
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 找到权重目录: {weights_dir}\n")
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
break
else:
self.log_output.emit(f"[ERROR] 未找到权重目录,跳过转换\n")
else:
self.log_output.emit(f"[ERROR] model_output_dir 未定义,跳过转换\n")
except Exception as convert_err:
self.log_output.emit(f"\n[ERROR] 转换过程出错: {convert_err}\n")
import traceback
self.log_output.emit(traceback.format_exc())
break # 跳出重试循环
except Exception as e:
......@@ -750,6 +828,50 @@ class TrainingWorker(QThread):
self.log_output.emit("等待当前epoch完成并保存模型...\n")
time.sleep(2) # 给YOLO时间完成保存
training_success = True
# 获取实际保存目录并转换模型
try:
save_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
self.log_output.emit(f"\n[调试] model.trainer 存在: {hasattr(model, 'trainer')}\n")
self.log_output.emit(f"[调试] save_dir 值: {save_dir}\n")
if save_dir:
save_dir_abs = os.path.abspath(str(save_dir))
weights_dir = os.path.abspath(os.path.join(save_dir_abs, "weights"))
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 实际保存目录: {save_dir_abs}\n")
self.log_output.emit(f"[调试] 实际权重目录: {weights_dir}\n")
# 立即转换PT文件为DAT格式并删除PT文件
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
else:
# 备用方案:使用预设的model_output_dir
self.log_output.emit("\n[WARNING] 无法从trainer获取保存目录,使用预设目录\n")
if 'model_output_dir' in locals():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs = [
os.path.join(model_output_dir, "train", "weights"),
os.path.join(model_output_dir, "weights")
]
for possible_dir in possible_weights_dirs:
if os.path.exists(possible_dir):
weights_dir = possible_dir
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 找到权重目录: {weights_dir}\n")
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
break
else:
self.log_output.emit(f"[ERROR] 未找到权重目录,跳过转换\n")
else:
self.log_output.emit(f"[ERROR] model_output_dir 未定义,跳过转换\n")
except Exception as convert_err:
self.log_output.emit(f"\n[ERROR] 转换过程出错: {convert_err}\n")
import traceback
self.log_output.emit(traceback.format_exc())
break
# 训练成功
......@@ -763,11 +885,18 @@ class TrainingWorker(QThread):
weights_dir = os.path.abspath(os.path.join(save_dir_abs, "weights"))
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"\n[调试] 实际保存目录: {save_dir_abs}\n")
self.log_output.emit(f"[调试] 实际权重目录: {weights_dir}\n")
# 立即转换PT文件为DAT格式并删除PT文件
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
except:
pass
else:
self.log_output.emit("\n[WARNING] 无法获取模型保存目录,跳过转换\n")
except Exception as convert_err:
self.log_output.emit(f"\n[ERROR] 转换过程出错: {convert_err}\n")
import traceback
self.log_output.emit(traceback.format_exc())
break # 跳出重试循环
except RuntimeError as runtime_error:
......@@ -1050,6 +1179,88 @@ class TrainingWorker(QThread):
except Exception as e:
pass
def _convertPtToDatAndCleanup(self, weights_dir):
"""
转换PT文件为DAT格式并删除原始PT文件
Args:
weights_dir: 权重目录路径
"""
try:
if not os.path.exists(weights_dir):
self.log_output.emit(f"[转换] 权重目录不存在: {weights_dir}\n")
return
self.log_output.emit(f"[转换] 开始扫描权重目录: {weights_dir}\n")
# 创建转换器
converter = PtToDatConverter(key=MODEL_ENCRYPTION_KEY)
# 查找所有.pt文件
pt_files = []
for filename in os.listdir(weights_dir):
if filename.endswith('.pt'):
pt_file_path = os.path.join(weights_dir, filename)
pt_files.append(pt_file_path)
if not pt_files:
self.log_output.emit(f"[转换] 未找到.pt文件\n")
return
self.log_output.emit(f"[转换] 找到 {len(pt_files)} 个.pt文件\n")
# 转换每个.pt文件
converted_files = []
for pt_file in pt_files:
try:
filename = os.path.basename(pt_file)
self.log_output.emit(f"[转换] 正在转换: {filename}\n")
# 生成输出文件名(使用.dat扩展名)
exp_name = self.training_params.get('exp_name', '')
base_name = os.path.splitext(filename)[0]
if exp_name:
# 例如: best.pt -> best.template_1234.dat
output_filename = f"{base_name}.{exp_name}.dat"
else:
# 例如: best.pt -> best.dat
output_filename = f"{base_name}.dat"
output_path = os.path.join(weights_dir, output_filename)
self.log_output.emit(f"[转换] 输入: {pt_file}\n")
self.log_output.emit(f"[转换] 输出: {output_path}\n")
# 执行转换
converted_path = converter.convert_file(pt_file, output_path)
converted_files.append(converted_path)
self.log_output.emit(f"[转换] ✓ 已转换: {output_filename}\n")
# 删除原始.pt文件
try:
os.remove(pt_file)
self.log_output.emit(f"[转换] ✓ 已删除原始文件: {filename}\n")
except Exception as del_error:
self.log_output.emit(f"[转换] ⚠ 删除原始文件失败: {filename} - {del_error}\n")
except Exception as convert_error:
self.log_output.emit(f"[转换] ✗ 转换失败: {filename} - {convert_error}\n")
import traceback
self.log_output.emit(f"[转换] 错误详情:\n{traceback.format_exc()}\n")
continue
# 更新训练报告
self.training_report["converted_dat_files"] = converted_files
self.log_output.emit(f"\n[转换] 完成!共转换 {len(converted_files)} 个文件\n")
except Exception as e:
self.log_output.emit(f"[转换] 转换过程出错: {e}\n")
import traceback
traceback.print_exc()
def stop_training(self):
"""停止训练"""
self.is_running = False
......@@ -104,11 +104,8 @@ class ModelSetPage(QtWidgets.QWidget):
self._connectSignals()
self._setupShortcuts()
# 移除自动加载,改为手动加载模式
# 用户可以通过以下方式加载模型:
# 1. 按 F5 刷新
# 2. 右键点击空白区域刷新
# 3. 点击刷新按钮(如果有)
# 标记是否已加载过模型
self._models_loaded = False
def _initUI(self):
"""初始化UI - 简约labelme风格(三栏充实布局版)"""
......@@ -765,12 +762,14 @@ class ModelSetPage(QtWidgets.QWidget):
# 重新添加所有模型
for model_name in self._model_params.keys():
self.model_set_list.addItem(model_name)
# 构建显示文本
display_text = model_name
# 如果是默认模型,添加标记
# 如果是默认模型,添加标记(使用统一格式)
if model_name == self._current_default_model:
item = self.model_set_list.item(self.model_set_list.count() - 1)
item.setText(f"{model_name} [默认]")
display_text = f"{model_name}(默认)"
self.model_set_list.addItem(display_text)
# 更新统计信息
self._updateStats()
......@@ -1088,6 +1087,15 @@ class ModelSetPage(QtWidgets.QWidget):
import traceback
traceback.print_exc()
def showEvent(self, event):
"""页面显示时自动加载模型(仅首次)"""
super(ModelSetPage, self).showEvent(event)
# 只在首次显示时加载模型
if not self._models_loaded:
self.loadModelsFromConfig()
self._models_loaded = True
def loadModelsFromConfig(self):
"""从配置文件和模型目录加载所有模型(委托给 handler)"""
try:
......
......@@ -1008,39 +1008,112 @@ class TrainingPage(QtWidgets.QWidget):
return models
def _scanDetectionModelDirectory(self, project_root):
"""扫描 train_model 目录获取所有测试模型文件"""
"""扫描 train_model 目录获取所有测试模型文件(增强版:按优先级选择模型)"""
models = []
try:
import yaml
# 🔥 修改:从 train_model 文件夹读取
model_dir = Path(project_root) / "database" / "model" / "train_model"
if not model_dir.exists():
return models
# 遍历所有子目录
for subdir in sorted(model_dir.iterdir()):
if subdir.is_dir():
# 优先查找 .dat 文件
for model_file in sorted(subdir.glob("*.dat")):
models.append({
'name': f"{subdir.name}-{model_file.stem}",
'path': str(model_file),
'source': 'train_model',
'format': 'dat'
})
# 按数字排序子目录(降序,最新的在前)
sorted_subdirs = sorted(
[d for d in model_dir.iterdir() if d.is_dir() and d.name.isdigit()],
key=lambda x: int(x.name),
reverse=True
)
for subdir in sorted_subdirs:
# 尝试读取 config.yaml 获取详细信息
config_file = subdir / "config.yaml"
model_config = None
if config_file.exists():
try:
with open(config_file, 'r', encoding='utf-8') as f:
model_config = yaml.safe_load(f)
except Exception as e:
pass
# 检查是否有weights子目录(优先检查train/weights,然后weights)
train_weights_dir = subdir / "train" / "weights"
weights_dir = subdir / "weights"
if train_weights_dir.exists():
search_dir = train_weights_dir
elif weights_dir.exists():
search_dir = weights_dir
else:
search_dir = subdir
# 按优先级查找模型文件:best > last > epoch1
# 支持的扩展名:.dat, .pt, .template_*, 无扩展名
selected_model = None
# 优先级1: best模型
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('best.') and not file.name.endswith('.pt'):
selected_model = file
break
# 优先级2: last模型(如果没有best)
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('last.') and not file.name.endswith('.pt'):
selected_model = file
break
# 优先级3: epoch1模型(如果没有best和last)
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and file.name.startswith('epoch1.') and not file.name.endswith('.pt'):
selected_model = file
break
# 如果都没找到,尝试查找任何非.pt文件
if not selected_model:
for file in search_dir.iterdir():
if file.is_file() and not file.name.endswith('.pt') and not file.name.endswith('.txt') and not file.name.endswith('.yaml'):
selected_model = file
break
# 如果找到了模型文件,添加到列表
if selected_model:
# 从 config.yaml 获取信息,或使用默认值
if model_config:
model_name = model_config.get('name', f"训练模型-{subdir.name}")
description = model_config.get('description', '')
training_date = model_config.get('training_date', '')
else:
model_name = selected_model.stem # 使用文件名(不含扩展名)作为模型名
description = f"来自目录 {subdir.name}"
training_date = ''
# 然后查找 .pt 文件
for model_file in sorted(subdir.glob("*.pt")):
models.append({
'name': f"{subdir.name}-{model_file.stem}",
'path': str(model_file),
'source': 'train_model',
'format': 'pt'
})
# 获取文件格式
file_ext = selected_model.suffix.lstrip('.')
if not file_ext:
# 处理无扩展名的情况(如 best.template_6543)
if '.' in selected_model.name:
file_ext = selected_model.name.split('.')[-1]
else:
file_ext = 'unknown'
models.append({
'name': model_name,
'path': str(selected_model),
'source': 'train_model',
'format': file_ext,
'description': description,
'training_date': training_date,
'file_name': selected_model.name
})
except Exception as e:
pass
import traceback
traceback.print_exc()
return models
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment