Commit 7e155e88 by wangbing2

123

parent 1e35f3f8
......@@ -705,7 +705,7 @@ class ModelSetHandler:
if model_path.endswith('.pt'):
model_type = "PyTorch (YOLOv5/v8)"
elif model_path.endswith('.dat'):
model_type = "加密模型 (.dat)"
model_type = ".dat"
else:
model_type = "未知格式"
......
......@@ -136,6 +136,7 @@ class ModelTestThread(QThread):
# 读取标注数据
self.progress_updated.emit(45, "正在读取标注数据...")
import yaml
with open(annotation_file, 'r', encoding='utf-8') as f:
try:
annotation_data = yaml.safe_load(f)
......@@ -157,20 +158,31 @@ class ModelTestThread(QThread):
# 配置标注数据
self.progress_updated.emit(70, "正在配置检测参数...")
detection_engine.set_annotation_data(annotation_data)
# 从annotation_data中提取配置参数
boxes = annotation_data.get('boxes', [])
fixed_bottoms = annotation_data.get('fixed_bottoms', [])
fixed_tops = annotation_data.get('fixed_tops', [])
actual_heights = annotation_data.get('actual_heights', [20.0] * len(boxes))
detection_engine.configure(boxes, fixed_bottoms, fixed_tops, actual_heights)
# 执行检测
self.progress_updated.emit(80, "正在执行液位检测...")
detection_result = detection_engine.detect_liquid_level(test_frame)
detection_result = detection_engine.detect(test_frame)
if detection_result is None:
raise RuntimeError("检测结果为空")
self._detection_result = detection_result
# 转换检测结果格式以兼容现有代码
converted_result = self._convertDetectionResult(detection_result)
self._detection_result = converted_result
self.progress_updated.emit(90, "正在保存测试结果...")
self._saveImageTestResults(model_path, test_frame, detection_result, annotation_file)
# 通过handler调用保存方法
handler = self.test_params['handler']
handler._saveImageTestResults(model_path, test_frame, converted_result, annotation_file)
except Exception as e:
raise
......@@ -197,6 +209,59 @@ class ModelTestThread(QThread):
def get_detection_result(self):
"""获取检测结果"""
return self._detection_result
def _convertDetectionResult(self, detection_result):
"""
转换检测结果格式以兼容现有代码
Args:
detection_result: detect方法返回的结果格式
{
'liquid_line_positions': {
0: {'y': y坐标, 'height_mm': 高度毫米, 'height_px': 高度像素},
1: {...},
...
},
'success': bool
}
Returns:
dict: 转换后的结果格式,包含 liquid_level_mm 字段
"""
try:
if not detection_result or not detection_result.get('success', False):
return {'liquid_level_mm': 0.0, 'success': False}
liquid_positions = detection_result.get('liquid_line_positions', {})
if not liquid_positions:
return {'liquid_level_mm': 0.0, 'success': False}
# 取第一个检测区域的液位高度
first_position = next(iter(liquid_positions.values()))
liquid_level_mm = first_position.get('height_mm', 0.0)
# 构建兼容格式的结果
converted_result = {
'liquid_level_mm': liquid_level_mm,
'success': True,
'areas': {}
}
# 转换所有区域的数据
for idx, position_data in liquid_positions.items():
area_name = f'区域{idx + 1}'
converted_result['areas'][area_name] = {
'liquid_height': position_data.get('height_mm', 0.0),
'y_position': position_data.get('y', 0),
'height_px': position_data.get('height_px', 0)
}
return converted_result
except Exception as e:
print(f"[结果转换] 转换检测结果失败: {e}")
return {'liquid_level_mm': 0.0, 'success': False}
class ModelTestHandler:
......@@ -1777,7 +1842,6 @@ class ModelTestHandler:
<h4 style="margin: 0 0 10px 0; color: #333333; font-size: 16px; font-weight: 500;">检测结果视频</h4>
<video width="100%" height="auto" controls style="border: none; border-radius: 6px; max-height: 400px; background: #f8f9fa;">
<source src="file:///{video_path_formatted}" type="video/mp4">
您的浏览器不支持视频播放
</video>
</div>
......
......@@ -563,18 +563,17 @@ class ModelTrainingHandler(ModelTestHandler):
except Exception as add_error:
self._appendLog(f"[WARNING] 添加到模型集失败: {str(add_error)}\n")
# 刷新模型集管理页面
# 同步到所有相关页面
try:
self._refreshModelSetPage()
except Exception as refresh_error:
pass
try:
self._refreshModelTestPage()
except Exception as test_refresh_error:
pass
# 修复:使用非阻塞式通知,避免卡住UI
self._syncAllModelPages()
except Exception as sync_error:
self._appendLog(f"[WARNING] 同步页面失败: {str(sync_error)}\n")
# 获取 last.dat 路径(转换后应该是 dat 文件)
last_checkpoint_path = None
if self.current_exp_name and weights_dir:
last_dat_path = os.path.join(weights_dir, "last.dat")
last_pt_path = os.path.join(weights_dir, "last.pt")
self._appendLog("\n" + "="*70 + "\n")
self._appendLog(" 训练完成通知\n")
self._appendLog("="*70 + "\n")
......@@ -2388,12 +2387,12 @@ class ModelTrainingHandler(ModelTestHandler):
# 获取训练笔记
training_notes = self._getTrainingNotes()
# 将模型移动到detection_model目录
detection_model_path = self._moveModelToDetectionDir(best_model_path, model_name, weights_dir, training_notes)
# 模型已直接保存在detection_model目录中,进行后处理
detection_model_path = self._processModelInDetectionDir(best_model_path, model_name, weights_dir, training_notes)
if detection_model_path:
# 更新模型参数中的路径
model_params['path'] = detection_model_path
self._appendLog(f"模型已移动到detection_model目录: {os.path.basename(detection_model_path)}\n")
self._appendLog(f"模型已保存在detection_model目录: {os.path.basename(detection_model_path)}\n")
# 保存到模型集配置
self._saveModelToConfig(model_name, model_params)
......@@ -2524,93 +2523,81 @@ class ModelTrainingHandler(ModelTestHandler):
self._appendLog(f"[ERROR] 获取最新模型目录失败: {str(e)}\n")
return None
def _moveModelToDetectionDir(self, model_path, model_name, weights_dir, training_notes=""):
"""将训练完成的模型移动到detection_model目录"""
def _processModelInDetectionDir(self, model_path, model_name, weights_dir, training_notes=""):
"""处理已保存在detection_model目录中的模型"""
try:
import shutil
from pathlib import Path
self._appendLog(f"\n开始移动模型到detection_model目录...\n")
# 获取项目根目录
project_root = get_project_root()
detection_model_dir = os.path.join(project_root, 'database', 'model', 'detection_model')
# 确保detection_model目录存在
os.makedirs(detection_model_dir, exist_ok=True)
self._appendLog(f"\n开始处理detection_model目录中的模型...\n")
# 获取下一个可用的数字ID
existing_dirs = []
for item in os.listdir(detection_model_dir):
item_path = os.path.join(detection_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
existing_dirs.append(int(item))
next_id = max(existing_dirs) + 1 if existing_dirs else 1
target_model_dir = os.path.join(detection_model_dir, str(next_id))
self._appendLog(f" 目标目录: {target_model_dir}\n")
# 创建目标目录结构
os.makedirs(target_model_dir, exist_ok=True)
target_weights_dir = os.path.join(target_model_dir, 'weights')
os.makedirs(target_weights_dir, exist_ok=True)
# 移动整个weights目录的内容
# 模型已经直接保存在detection_model目录中,获取模型目录
if weights_dir and os.path.exists(weights_dir):
self._appendLog(f" 复制weights目录内容...\n")
# weights_dir应该是 detection_model/{ID}/weights
target_model_dir = os.path.dirname(weights_dir)
model_id = os.path.basename(target_model_dir)
# 复制所有文件到目标weights目录
for filename in os.listdir(weights_dir):
source_file = os.path.join(weights_dir, filename)
target_file = os.path.join(target_weights_dir, filename)
if os.path.isfile(source_file):
shutil.copy2(source_file, target_file)
self._appendLog(f" 复制: {filename}\n")
# 复制训练目录的其他文件(如config.yaml, results.csv等)
train_exp_dir = os.path.dirname(weights_dir)
if os.path.exists(train_exp_dir):
for filename in os.listdir(train_exp_dir):
if filename != 'weights': # 跳过weights目录
source_file = os.path.join(train_exp_dir, filename)
target_file = os.path.join(target_model_dir, filename)
if os.path.isfile(source_file):
shutil.copy2(source_file, target_file)
self._appendLog(f" 复制配置: {filename}\n")
# 保存训练笔记(如果有)
if training_notes:
notes_file = os.path.join(target_model_dir, 'training_notes.txt')
try:
with open(notes_file, 'w', encoding='utf-8') as f:
# 添加时间戳和模型信息
f.write(f"训练笔记 - {model_name}\n")
f.write(f"训练时间: {self._getCurrentTimestamp()}\n")
f.write(f"模型ID: {next_id}\n")
f.write("="*50 + "\n\n")
f.write(training_notes)
self._appendLog(f" 保存训练笔记: training_notes.txt\n")
except Exception as e:
self._appendLog(f" 保存训练笔记失败: {str(e)}\n")
# 确定最终的模型文件路径
model_filename = os.path.basename(model_path)
final_model_path = os.path.join(target_weights_dir, model_filename)
if os.path.exists(final_model_path):
self._appendLog(f"✅ 模型已成功移动到detection_model/{next_id}/weights/{model_filename}\n")
self._appendLog(f" 模型目录: {target_model_dir}\n")
self._appendLog(f" 模型ID: {model_id}\n")
# 清理weights目录中的所有pt文件
self._appendLog(f" 清理pt文件...\n")
self._forceCleanupPtFiles(weights_dir)
# 验证pt文件已被清理
remaining_pt_files = [f for f in os.listdir(weights_dir) if f.endswith('.pt')]
if remaining_pt_files:
self._appendLog(f" 警告: 仍有pt文件未清理: {remaining_pt_files}\n")
else:
self._appendLog(f" ✅ 所有pt文件已清理完成\n")
# 保存训练笔记(如果有)
if training_notes:
self._appendLog(f"✅ 训练笔记已保存到detection_model/{next_id}/training_notes.txt\n")
return final_model_path
notes_file = os.path.join(target_model_dir, 'training_notes.txt')
try:
with open(notes_file, 'w', encoding='utf-8') as f:
# 添加时间戳和模型信息
f.write(f"训练笔记 - {model_name}\n")
f.write(f"训练时间: {self._getCurrentTimestamp()}\n")
f.write(f"模型ID: {model_id}\n")
f.write("="*50 + "\n\n")
f.write(training_notes)
self._appendLog(f" 保存训练笔记: training_notes.txt\n")
except Exception as e:
self._appendLog(f" 保存训练笔记失败: {str(e)}\n")
# 查找实际的模型文件(应该是.dat格式)
final_model_path = None
if os.path.exists(weights_dir):
for filename in os.listdir(weights_dir):
if filename.startswith('best.') and filename.endswith('.dat'):
final_model_path = os.path.join(weights_dir, filename)
break
# 如果没找到best.dat,查找任何.dat文件
if not final_model_path:
for filename in os.listdir(weights_dir):
if filename.endswith('.dat'):
final_model_path = os.path.join(weights_dir, filename)
break
if final_model_path and os.path.exists(final_model_path):
self._appendLog(f"✅ 模型已保存在detection_model/{model_id}/weights/{os.path.basename(final_model_path)}\n")
if training_notes:
self._appendLog(f"✅ 训练笔记已保存到detection_model/{model_id}/training_notes.txt\n")
# 自动生成模型描述文件
self._generateModelDescription(target_model_dir, model_id, final_model_path, model_name)
return final_model_path
else:
self._appendLog(f"❌ 未找到模型文件在: {weights_dir}\n")
return None
else:
self._appendLog(f"❌ 模型移动失败,目标文件不存在: {final_model_path}\n")
self._appendLog(f"❌ weights目录不存在: {weights_dir}\n")
return None
except Exception as e:
self._appendLog(f"❌ [ERROR] 移动模型到detection_model失败: {str(e)}\n")
self._appendLog(f"❌ [ERROR] 处理detection_model目录中的模型失败: {str(e)}\n")
import traceback
traceback.print_exc()
return None
......@@ -3037,7 +3024,7 @@ class ModelTrainingHandler(ModelTestHandler):
traceback.print_exc()
def _refreshModelTestPage(self):
"""刷新模型测试页面的模型列表"""
"""刷新模型测试页面"""
try:
# 获取主窗口的模型测试页面
if hasattr(self, 'main_window') and hasattr(self.main_window, 'testModelPage'):
......@@ -3046,9 +3033,8 @@ class ModelTrainingHandler(ModelTestHandler):
if hasattr(self.main_window, 'modelSetPage'):
self.main_window.testModelPage.loadModelsFromModelSetPage(self.main_window.modelSetPage)
else:
# 否则直接刷新模型测试页面
if hasattr(self.main_window.testModelPage, 'loadModelsFromConfig'):
self.main_window.testModelPage.loadModelsFromConfig()
# 备用方案:直接加载模型
self.main_window.testModelPage.loadModels()
print("[信息] 模型测试页面刷新完成")
else:
print("[警告] 无法访问模型测试页面")
......@@ -3057,6 +3043,60 @@ class ModelTrainingHandler(ModelTestHandler):
import traceback
traceback.print_exc()
def _syncAllModelPages(self):
"""同步所有模型相关页面"""
try:
self._appendLog("\n正在同步所有模型页面...\n")
# 1. 刷新模型集管理页面
try:
self._refreshModelSetPage()
self._appendLog(" ✅ 模型集管理页面已同步\n")
except Exception as e:
self._appendLog(f" ❌ 模型集管理页面同步失败: {str(e)}\n")
# 2. 刷新模型测试页面
try:
self._refreshModelTestPage()
self._appendLog(" ✅ 模型测试页面已同步\n")
except Exception as e:
self._appendLog(f" ❌ 模型测试页面同步失败: {str(e)}\n")
# 3. 刷新训练页面的模型下拉菜单
try:
self._refreshTrainingPageModels()
self._appendLog(" ✅ 训练页面模型列表已同步\n")
except Exception as e:
self._appendLog(f" ❌ 训练页面模型列表同步失败: {str(e)}\n")
# 4. 通知主窗口更新相关UI
try:
if hasattr(self, 'main_window') and hasattr(self.main_window, 'updateModelRelatedUI'):
self.main_window.updateModelRelatedUI()
self._appendLog(" ✅ 主窗口UI已同步\n")
except Exception as e:
self._appendLog(f" ❌ 主窗口UI同步失败: {str(e)}\n")
self._appendLog("页面同步完成\n")
except Exception as e:
self._appendLog(f"❌ 页面同步失败: {str(e)}\n")
import traceback
traceback.print_exc()
def _refreshTrainingPageModels(self):
"""刷新训练页面的模型下拉菜单"""
try:
if hasattr(self, 'training_panel') and hasattr(self.training_panel, 'refreshBaseModelList'):
self.training_panel.refreshBaseModelList()
elif hasattr(self, 'main_window') and hasattr(self.main_window, 'trainingPage'):
if hasattr(self.main_window.trainingPage, 'refreshBaseModelList'):
self.main_window.trainingPage.refreshBaseModelList()
except Exception as e:
print(f"[错误] 刷新训练页面模型列表失败: {e}")
import traceback
traceback.print_exc()
def _getFileSize(self, file_path):
"""获取文件大小(格式化字符串)"""
try:
......@@ -3249,3 +3289,117 @@ class ModelTrainingHandler(ModelTestHandler):
import traceback
self._appendLog(traceback.format_exc())
return None
def _generateModelDescription(self, target_model_dir, model_id, model_path, model_name):
"""自动生成模型描述文件"""
try:
import time
from pathlib import Path
self._appendLog(f"\n开始生成模型描述文件...\n")
# 新结构:查找根目录下的dat模型文件
actual_model_path = None
# 优先查找best.dat文件(在根目录)
for filename in os.listdir(target_model_dir):
if filename.startswith('best.') and filename.endswith('.dat'):
actual_model_path = os.path.join(target_model_dir, filename)
break
# 如果没找到best.dat,查找任何.dat文件(在根目录)
if not actual_model_path:
for filename in os.listdir(target_model_dir):
if filename.endswith('.dat'):
actual_model_path = os.path.join(target_model_dir, filename)
break
# 使用找到的dat文件路径,如果没找到则使用原路径
if actual_model_path and os.path.exists(actual_model_path):
model_path = actual_model_path
self._appendLog(f" 使用模型文件: {os.path.basename(model_path)}\n")
# 获取模型文件信息
model_file = Path(model_path)
if os.path.exists(model_path):
file_size_bytes = model_file.stat().st_size
file_size_mb = round(file_size_bytes / (1024 * 1024), 1)
mod_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(model_file.stat().st_mtime))
else:
file_size_mb = "未知"
mod_time = self._getCurrentTimestamp()
# 确定模型类型(现在应该都是.dat格式)
model_type = ".dat"
# 生成模型描述内容
description_content = f"""【模型信息】{model_id}/{model_file.stem}
基础信息
------------------------------
模型名称: {model_id}/{model_file.stem}
模型类型: {model_type}
模型路径: {model_path}
文件大小: {file_size_mb} MB
最后修改: {mod_time}
模型说明
------------------------------
这是一个液位检测专用模型,用于识别和检测液位线位置。
模型采用深度学习技术,能够准确识别各种容器中的液位状态。
本模型通过训练升级生成,具有更好的检测精度和稳定性。
技术特性
------------------------------
- 支持多种液体类型检测
- 高精度液位线定位
- 实时检测能力
- 适用于工业自动化场景
- 经过训练优化的检测算法
通用模型特性
------------------------------
- 基于深度学习技术
- 专门针对液位检测优化
- 支持实时处理
- 适用于工业环境
使用说明
------------------------------
1. 模型已经过训练升级,可直接用于液位检测任务
2. 支持图像输入,输出液位线坐标和置信度
3. 建议在良好光照条件下使用以获得最佳效果
4. 如需进一步提升特定场景的检测效果,可继续进行模型微调训练
注意事项
------------------------------
- 确保输入图像清晰度足够
- 避免强反光或阴影干扰
- 定期验证模型检测精度
- 如发现检测效果下降,建议重新训练或更新模型
- 新训练的模型建议先在测试环境中验证效果
训练信息
------------------------------
训练时间: {self._getCurrentTimestamp()}
模型版本: {model_id}
训练状态: 已完成
备注: 通过系统自动训练升级生成"""
# 保存模型描述文件到training_results目录
training_results_dir = os.path.join(target_model_dir, 'training_results')
os.makedirs(training_results_dir, exist_ok=True)
description_file = os.path.join(training_results_dir, '模型描述.txt')
with open(description_file, 'w', encoding='utf-8') as f:
f.write(description_content)
self._appendLog(f"✅ 模型描述文件已生成: {description_file}\n")
self._appendLog(f" 文件包含完整的模型信息和使用说明\n")
return description_file
except Exception as e:
self._appendLog(f"❌ 生成模型描述文件失败: {str(e)}\n")
import traceback
traceback.print_exc()
return None
......@@ -709,26 +709,38 @@ class TrainingWorker(QThread):
if device_str.lower() in ['cpu', '-1']:
workers = 0 # CPU模式下禁用多线程数据加载
# 获取下一个可用的模型ID并创建目录
# 获取下一个可用的模型ID并创建目录(直接保存到detection_model)
project_root = get_project_root()
train_model_dir = os.path.join(project_root, 'database', 'model', 'train_model')
os.makedirs(train_model_dir, exist_ok=True)
detection_model_dir = os.path.join(project_root, 'database', 'model', 'detection_model')
os.makedirs(detection_model_dir, exist_ok=True)
# 查找下一个可用的数字ID
existing_dirs = []
for item in os.listdir(train_model_dir):
item_path = os.path.join(train_model_dir, item)
for item in os.listdir(detection_model_dir):
item_path = os.path.join(detection_model_dir, item)
if os.path.isdir(item_path) and item.isdigit():
existing_dirs.append(int(item))
next_id = max(existing_dirs) + 1 if existing_dirs else 1
model_output_dir = os.path.join(train_model_dir, str(next_id))
model_output_dir = os.path.join(detection_model_dir, str(next_id))
# 预先设置weights_dir(YOLO会在project目录下创建weights子目录)
weights_dir = os.path.join(model_output_dir, "weights")
# 创建新的目录结构
os.makedirs(model_output_dir, exist_ok=True)
training_results_dir = os.path.join(model_output_dir, "training_results")
test_results_dir = os.path.join(model_output_dir, "test_results")
os.makedirs(training_results_dir, exist_ok=True)
os.makedirs(test_results_dir, exist_ok=True)
# YOLO训练时使用临时目录,训练完成后移动文件
temp_training_dir = os.path.join(model_output_dir, "temp_training")
weights_dir = os.path.join(temp_training_dir, "weights")
self.training_report["weights_dir"] = weights_dir
self.training_report["model_output_dir"] = model_output_dir
self.training_report["training_results_dir"] = training_results_dir
self.log_output.emit(f"模型将直接保存到: {model_output_dir}\n")
self.log_output.emit(f"模型将保存到: {model_output_dir}\n")
self.log_output.emit(f"训练结果将保存到: {training_results_dir}\n")
self.log_output.emit(f"测试结果将保存到: {test_results_dir}\n")
# 开始训练
try:
......@@ -742,7 +754,7 @@ class TrainingWorker(QThread):
optimizer=self.training_params['optimizer'],
close_mosaic=self.training_params['close_mosaic'],
resume=self.training_params['resume'],
project=model_output_dir,
project=temp_training_dir,
name='', # 空名称,直接使用project目录
single_cls=self.training_params['single_cls'],
cache=False,
......@@ -772,36 +784,26 @@ class TrainingWorker(QThread):
if save_dir:
save_dir_abs = os.path.abspath(str(save_dir))
weights_dir = os.path.abspath(os.path.join(save_dir_abs, "weights"))
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 实际保存目录: {save_dir_abs}\n")
self.log_output.emit(f"[调试] 实际权重目录: {weights_dir}\n")
# 立即转换PT文件为DAT格式并删除PT文件
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
# 整理训练结果到新的目录结构
model_output_dir = self.training_report.get("model_output_dir")
training_results_dir = self.training_report.get("training_results_dir")
if model_output_dir and training_results_dir:
self._organizeTrainingResults(save_dir_abs, model_output_dir, training_results_dir)
else:
# 备用方案:使用预设的model_output_dir
# 备用方案:使用预设的temp_training_dir
self.log_output.emit("\n[WARNING] 无法从trainer获取保存目录,使用预设目录\n")
if 'model_output_dir' in locals():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs = [
os.path.join(model_output_dir, "train", "weights"),
os.path.join(model_output_dir, "weights")
]
for possible_dir in possible_weights_dirs:
if os.path.exists(possible_dir):
weights_dir = possible_dir
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 找到权重目录: {weights_dir}\n")
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
break
model_output_dir = self.training_report.get("model_output_dir")
training_results_dir = self.training_report.get("training_results_dir")
if model_output_dir and training_results_dir and 'temp_training_dir' in locals():
if os.path.exists(temp_training_dir):
self._organizeTrainingResults(temp_training_dir, model_output_dir, training_results_dir)
else:
self.log_output.emit(f"[ERROR] 未找到权重目录,跳过转换\n")
self.log_output.emit(f"[ERROR] 临时训练目录不存在: {temp_training_dir}\n")
else:
self.log_output.emit(f"[ERROR] model_output_dir 未定义,跳过转换\n")
self.log_output.emit(f"[ERROR] 必要的目录信息未定义,跳过整理\n")
except Exception as convert_err:
self.log_output.emit(f"\n[ERROR] 转换过程出错: {convert_err}\n")
import traceback
......@@ -840,36 +842,26 @@ class TrainingWorker(QThread):
if save_dir:
save_dir_abs = os.path.abspath(str(save_dir))
weights_dir = os.path.abspath(os.path.join(save_dir_abs, "weights"))
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 实际保存目录: {save_dir_abs}\n")
self.log_output.emit(f"[调试] 实际权重目录: {weights_dir}\n")
# 立即转换PT文件为DAT格式并删除PT文件
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
# 整理训练结果到新的目录结构
model_output_dir = self.training_report.get("model_output_dir")
training_results_dir = self.training_report.get("training_results_dir")
if model_output_dir and training_results_dir:
self._organizeTrainingResults(save_dir_abs, model_output_dir, training_results_dir)
else:
# 备用方案:使用预设的model_output_dir
# 备用方案:使用预设的temp_training_dir
self.log_output.emit("\n[WARNING] 无法从trainer获取保存目录,使用预设目录\n")
if 'model_output_dir' in locals():
# 查找实际的weights目录(可能在train子目录下)
possible_weights_dirs = [
os.path.join(model_output_dir, "train", "weights"),
os.path.join(model_output_dir, "weights")
]
for possible_dir in possible_weights_dirs:
if os.path.exists(possible_dir):
weights_dir = possible_dir
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"[调试] 找到权重目录: {weights_dir}\n")
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
break
model_output_dir = self.training_report.get("model_output_dir")
training_results_dir = self.training_report.get("training_results_dir")
if model_output_dir and training_results_dir and 'temp_training_dir' in locals():
if os.path.exists(temp_training_dir):
self._organizeTrainingResults(temp_training_dir, model_output_dir, training_results_dir)
else:
self.log_output.emit(f"[ERROR] 未找到权重目录,跳过转换\n")
self.log_output.emit(f"[ERROR] 临时训练目录不存在: {temp_training_dir}\n")
else:
self.log_output.emit(f"[ERROR] model_output_dir 未定义,跳过转换\n")
self.log_output.emit(f"[ERROR] 必要的目录信息未定义,跳过整理\n")
except Exception as convert_err:
self.log_output.emit(f"\n[ERROR] 转换过程出错: {convert_err}\n")
import traceback
......@@ -885,15 +877,14 @@ class TrainingWorker(QThread):
save_dir = getattr(getattr(model, "trainer", None), "save_dir", None)
if save_dir:
save_dir_abs = os.path.abspath(str(save_dir))
weights_dir = os.path.abspath(os.path.join(save_dir_abs, "weights"))
self.training_report["weights_dir"] = weights_dir
self.log_output.emit(f"\n[调试] 实际保存目录: {save_dir_abs}\n")
self.log_output.emit(f"[调试] 实际权重目录: {weights_dir}\n")
# 立即转换PT文件为DAT格式并删除PT文件
self.log_output.emit("\n正在转换模型文件为DAT格式...\n")
self._convertPtToDatAndCleanup(weights_dir)
# 整理训练结果到新的目录结构
model_output_dir = self.training_report.get("model_output_dir")
training_results_dir = self.training_report.get("training_results_dir")
if model_output_dir and training_results_dir:
self._organizeTrainingResults(save_dir_abs, model_output_dir, training_results_dir)
else:
self.log_output.emit("\n[WARNING] 无法获取模型保存目录,跳过转换\n")
except Exception as convert_err:
......@@ -1271,3 +1262,89 @@ class TrainingWorker(QThread):
def has_training_started(self):
"""检查训练是否已经真正开始"""
return self.training_actually_started
def _organizeTrainingResults(self, temp_training_dir, model_output_dir, training_results_dir):
"""
整理训练结果到新的目录结构
Args:
temp_training_dir: 临时训练目录
model_output_dir: 模型输出根目录
training_results_dir: 训练结果目录
"""
try:
import shutil
self.log_output.emit("\n正在整理训练结果文件...\n")
# 1. 处理权重文件 - 转换并移动到根目录
weights_dir = os.path.join(temp_training_dir, "weights")
if os.path.exists(weights_dir):
self.log_output.emit("正在转换并移动模型文件...\n")
# 转换PT文件为DAT格式
self._convertPtToDatAndCleanup(weights_dir)
# 移动DAT文件到根目录
for filename in os.listdir(weights_dir):
if filename.endswith('.dat'):
src_path = os.path.join(weights_dir, filename)
dst_path = os.path.join(model_output_dir, filename)
shutil.move(src_path, dst_path)
self.log_output.emit(f" 移动模型文件: {filename}\n")
# 2. 移动训练结果文件到training_results目录
training_files = [
'results.csv', # 训练结果
'args.yaml', # 训练参数
'config.yaml', # 配置文件
'labels.jpg', # 标签分布图
'labels_correlogram.jpg', # 标签相关图
'train_batch*.jpg', # 训练批次图片
'val_batch*.jpg', # 验证批次图片
'confusion_matrix.png', # 混淆矩阵
'F1_curve.png', # F1曲线
'P_curve.png', # 精确率曲线
'R_curve.png', # 召回率曲线
'PR_curve.png', # PR曲线
'results.png' # 结果图表
]
# 移动单个文件
for pattern in training_files:
if '*' in pattern:
# 处理通配符模式
import glob
base_pattern = pattern.replace('*', '*')
matches = glob.glob(os.path.join(temp_training_dir, base_pattern))
for match in matches:
if os.path.isfile(match):
filename = os.path.basename(match)
dst_path = os.path.join(training_results_dir, filename)
shutil.move(match, dst_path)
self.log_output.emit(f" 移动训练文件: {filename}\n")
else:
src_path = os.path.join(temp_training_dir, pattern)
if os.path.exists(src_path):
dst_path = os.path.join(training_results_dir, pattern)
shutil.move(src_path, dst_path)
self.log_output.emit(f" 移动训练文件: {pattern}\n")
# 3. 移动plots目录(如果存在)
plots_src = os.path.join(temp_training_dir, "plots")
if os.path.exists(plots_src):
plots_dst = os.path.join(training_results_dir, "plots")
shutil.move(plots_src, plots_dst)
self.log_output.emit(" 移动训练图表目录: plots/\n")
# 4. 清理临时目录
if os.path.exists(temp_training_dir):
shutil.rmtree(temp_training_dir)
self.log_output.emit(" 清理临时目录\n")
self.log_output.emit("训练结果整理完成!\n")
except Exception as e:
self.log_output.emit(f"[错误] 整理训练结果失败: {e}\n")
import traceback
traceback.print_exc()
......@@ -120,7 +120,7 @@ class ModelSetPage(QtWidgets.QWidget):
# 标题
title = QtWidgets.QLabel("模型集管理")
title.setStyleSheet("font-size: 13pt; font-weight: bold;")
FontManager.applyToWidget(title, weight=FontManager.WEIGHT_BOLD) # 使用全局默认字体,加粗
header_layout.addWidget(title)
header_layout.addStretch()
......@@ -246,10 +246,26 @@ class ModelSetPage(QtWidgets.QWidget):
content_layout.addWidget(left_widget)
# ========== 中间主内容区(左右分栏:模型列表 + 文本显示) ==========
center_widget = QtWidgets.QWidget()
center_main_layout = QtWidgets.QHBoxLayout(center_widget)
center_main_layout.setContentsMargins(0, 0, 0, 0)
center_main_layout.setSpacing(8)
# 使用QSplitter实现可缩放的分隔符
center_splitter = QtWidgets.QSplitter(Qt.Horizontal)
center_splitter.setChildrenCollapsible(False) # 防止面板被完全折叠
# 设置splitter样式
center_splitter.setStyleSheet("""
QSplitter::handle {
background-color: #e0e0e0;
border: 1px solid #c0c0c0;
width: 6px;
margin: 2px;
border-radius: 3px;
}
QSplitter::handle:hover {
background-color: #d0d0d0;
}
QSplitter::handle:pressed {
background-color: #b0b0b0;
}
""")
# ===== 左侧:模型列表区域 =====
left_list_widget = QtWidgets.QWidget()
......@@ -338,17 +354,22 @@ class ModelSetPage(QtWidgets.QWidget):
""")
self.model_info_text.setPlaceholderText("选择模型后将显示模型文件夹内的txt文件内容...")
# 应用全局字体管理
FontManager.applyToWidget(self.model_info_text, size=10)
FontManager.applyToWidget(self.model_info_text) # 使用全局默认字体
right_text_layout.addWidget(self.model_info_text)
# 将左右两部分添加到中间主布局(1:1比例)
center_main_layout.addWidget(left_list_widget, 1)
center_main_layout.addWidget(right_text_widget, 1)
# 将左右两部分添加到splitter
center_splitter.addWidget(left_list_widget)
center_splitter.addWidget(right_text_widget)
# 设置初始比例(1:1)
center_splitter.setSizes([400, 400])
center_splitter.setStretchFactor(0, 1) # 左侧可伸缩
center_splitter.setStretchFactor(1, 1) # 右侧可伸缩
# 🔥 设置中间区域的宽度(主要内容区)- 响应式布局
center_widget.setMinimumWidth(scale_w(800))
center_splitter.setMinimumWidth(scale_w(800))
content_layout.addWidget(center_widget, 3)
content_layout.addWidget(center_splitter, 3)
main_layout.addLayout(content_layout)
......@@ -558,7 +579,7 @@ class ModelSetPage(QtWidgets.QWidget):
if '预训练模型' not in model_type and 'YOLO' not in model_type:
model_type = "PyTorch (YOLOv8)"
elif model_file.endswith('.dat'):
model_type = "加密模型 (.dat)"
model_type = ".dat"
elif model_file.endswith('.onnx'):
model_type = "ONNX"
......@@ -1216,7 +1237,6 @@ class ModelSetPage(QtWidgets.QWidget):
border-radius: 4px;
}
QLabel {
font-size: 10pt;
padding: 3px;
}
""")
......@@ -1294,6 +1314,9 @@ class ModelSetPage(QtWidgets.QWidget):
layout.addLayout(button_layout)
# 应用全局字体管理器到对话框及其所有子控件
FontManager.applyToWidgetRecursive(dialog)
# 显示对话框
dialog.exec_()
......@@ -1381,8 +1404,8 @@ class ModelSetPage(QtWidgets.QWidget):
# 2. 从配置文件提取通道模型
channel_models = self._extractChannelModels(config)
# 3. 扫描模型目录
scanned_models = self._scanModelDirectory()
# 3. 扫描模型目录(使用统一的getDetectionModels方法确保名称一致)
scanned_models = self.getDetectionModels()
# 4. 合并所有模型信息
all_models = self._mergeModelInfo(channel_models, scanned_models)
......@@ -1469,8 +1492,11 @@ class ModelSetPage(QtWidgets.QWidget):
# 3. 检查模型路径是否存在
if model_path and os.path.exists(model_path):
# 使用统一的命名逻辑:从模型路径推导出标准名称
model_name = self._getModelNameFromPath(model_path)
models.append({
'name': f"{channel_name}模型",
'name': model_name, # 使用统一的模型名称
'path': model_path,
'channel': channel_key,
'channel_name': channel_name,
......@@ -1480,19 +1506,64 @@ class ModelSetPage(QtWidgets.QWidget):
return models
def _getModelNameFromPath(self, model_path):
"""从模型路径推导出统一的模型名称"""
try:
from pathlib import Path
path = Path(model_path)
# 检查是否在detection_model目录下
if 'detection_model' in path.parts:
# 找到detection_model目录的索引
parts = path.parts
detection_index = -1
for i, part in enumerate(parts):
if part == 'detection_model':
detection_index = i
break
if detection_index >= 0 and detection_index + 1 < len(parts):
# 获取模型ID目录名
model_id = parts[detection_index + 1]
# 尝试读取config.yaml获取模型名称
model_dir = path.parent
config_locations = [
model_dir / "training_results" / "config.yaml",
model_dir / "config.yaml"
]
for config_file in config_locations:
if config_file.exists():
try:
import yaml
with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
return config_data.get('model_name', f"模型_{model_id}")
except:
continue
# 如果没有配置文件,使用默认格式
return f"模型_{model_id}"
# 如果不在detection_model目录下,使用文件名
return path.stem
except Exception:
# 出错时返回文件名
return Path(model_path).stem
def _scanModelDirectory(self):
"""扫描模型目录获取所有模型文件(优先detection_model)"""
"""扫描模型目录获取所有模型文件(只从detection_model加载)"""
models = []
try:
# 获取模型目录路径
current_dir = Path(__file__).parent.parent.parent
# 扫描多个模型目录(优先detection_model)
# 只扫描detection_model目录,确保数据一致性
model_dirs = [
(current_dir / "database" / "model" / "detection_model", "检测模型", True), # 优先级最高
(current_dir / "database" / "model" / "train_model", "训练模型", False),
(current_dir / "database" / "model" / "test_model", "测试模型", False)
(current_dir / "database" / "model" / "detection_model", "检测模型", True), # 唯一数据源
]
for model_dir, dir_type, is_primary in model_dirs:
......@@ -1604,6 +1675,104 @@ class ModelSetPage(QtWidgets.QWidget):
return models
@staticmethod
def getDetectionModels():
"""静态方法:获取detection_model目录下的所有模型,供其他页面使用"""
models = []
try:
# 获取模型目录路径
current_dir = Path(__file__).parent.parent.parent
detection_model_dir = current_dir / "database" / "model" / "detection_model"
if not detection_model_dir.exists():
return models
# 遍历所有子目录,包含数字和非数字目录
all_subdirs = [d for d in detection_model_dir.iterdir() if d.is_dir()]
# 分离数字目录和非数字目录
digit_subdirs = [d for d in all_subdirs if d.name.isdigit()]
non_digit_subdirs = [d for d in all_subdirs if not d.name.isdigit()]
# 数字目录按数字降序排列,非数字目录按字母排序
sorted_digit_subdirs = sorted(digit_subdirs, key=lambda x: int(x.name), reverse=True)
sorted_non_digit_subdirs = sorted(non_digit_subdirs, key=lambda x: x.name)
# 合并:数字目录在前(最新的在前),非数字目录在后
sorted_subdirs = sorted_digit_subdirs + sorted_non_digit_subdirs
for subdir in sorted_subdirs:
# 新结构:.dat文件直接在根目录下
# 查找.dat文件(优先在根目录,兼容旧的weights子目录)
dat_files = list(subdir.glob("*.dat"))
# 如果根目录没有.dat文件,检查weights子目录(兼容旧结构)
if not dat_files:
weights_dir = subdir / "weights"
if weights_dir.exists():
dat_files = list(weights_dir.glob("*.dat"))
if not dat_files:
continue
# 优先选择best.dat,然后是其他.dat文件
model_file = None
for dat_file in dat_files:
if dat_file.name.startswith('best.'):
model_file = dat_file
break
if not model_file:
model_file = dat_files[0]
# 尝试读取config.yaml获取模型名称
# 新结构:config.yaml在training_results子目录中
model_name = None
config_locations = [
subdir / "training_results" / "config.yaml", # 新结构
subdir / "config.yaml" # 兼容旧结构
]
for config_file in config_locations:
if config_file.exists():
try:
import yaml
with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
model_name = config_data.get('model_name', f"模型_{subdir.name}")
break
except:
continue
if not model_name:
model_name = f"模型_{subdir.name}"
# 获取文件大小
file_size = "未知"
try:
size_bytes = model_file.stat().st_size
if size_bytes < 1024 * 1024:
file_size = f"{size_bytes / 1024:.1f} KB"
else:
file_size = f"{size_bytes / (1024 * 1024):.1f} MB"
except:
pass
models.append({
'name': model_name,
'path': str(model_file),
'size': file_size,
'type': '.dat',
'source': '检测模型',
'model_id': subdir.name,
'is_primary': True
})
except Exception as e:
print(f"[错误] 获取detection_model模型失败: {e}")
return models
def _mergeModelInfo(self, channel_models, scanned_models):
"""合并模型信息,避免重复(优先detection_model)"""
all_models = []
......@@ -1669,7 +1838,7 @@ class ModelSetPage(QtWidgets.QWidget):
if model_path.endswith('.pt'):
model_type = "PyTorch (YOLOv5/v8)"
elif model_path.endswith('.dat'):
model_type = "加密模型 (.dat)"
model_type = ".dat"
else:
model_type = "未知格式"
......@@ -1793,7 +1962,7 @@ class ModelSetPage(QtWidgets.QWidget):
return False
def _loadModelTxtFiles(self, model_name):
"""读取并显示模型文件夹内的txt文件"""
"""显示模型最近一次升级的参数和训练效果信息"""
try:
# 清空文本显示区域
self.model_info_text.clear()
......@@ -1817,38 +1986,139 @@ class ModelSetPage(QtWidgets.QWidget):
self.model_info_text.setPlainText(f"模型目录不存在:\n{model_dir}")
return
# 查找目录中的所有txt文件
txt_files = list(model_dir.glob("*.txt"))
if not txt_files:
self.model_info_text.setPlainText(f"模型目录中没有找到txt文件:\n{model_dir}")
return
# 读取并显示所有txt文件的内容
# 构建模型训练信息
content_parts = []
content_parts.append(f"模型目录: {model_dir}\n")
content_parts.append("=" * 60 + "\n\n")
for txt_file in sorted(txt_files):
content_parts.append(f"【文件: {txt_file.name}】\n")
content_parts.append("-" * 60 + "\n")
# 首先检查是否有txt文件,如果有则优先显示txt文件内容
txt_files = []
# 新结构:搜索training_results目录的txt文件
training_results_dir = model_dir / "training_results"
if training_results_dir.exists():
training_results_files = list(training_results_dir.glob("*.txt"))
txt_files.extend(training_results_files)
# 兼容旧结构:搜索模型同级目录的txt文件
current_dir_files = list(model_dir.glob("*.txt"))
txt_files.extend(current_dir_files)
# 兼容旧结构:搜索上一级目录的txt文件
parent_dir = model_dir.parent
if parent_dir.exists():
parent_dir_files = list(parent_dir.glob("*.txt"))
txt_files.extend(parent_dir_files)
# 如果有txt文件,直接显示txt文件内容(优先显示模型描述文件)
if txt_files:
# 优先显示模型描述文件
description_files = [f for f in txt_files if '模型描述' in f.name]
other_files = [f for f in txt_files if '模型描述' not in f.name]
try:
# 尝试使用UTF-8编码读取
with open(txt_file, 'r', encoding='utf-8') as f:
file_content = f.read()
except UnicodeDecodeError:
# 如果UTF-8失败,尝试GBK编码
# 按优先级排序:模型描述文件在前
sorted_files = description_files + sorted(other_files)
for txt_file in sorted_files:
try:
with open(txt_file, 'r', encoding='gbk') as f:
file_content = f.read()
except Exception as e:
file_content = f"无法读取文件(编码错误): {str(e)}"
except Exception as e:
file_content = f"读取文件时出错: {str(e)}"
with open(txt_file, 'r', encoding='utf-8') as f:
file_content = f.read().strip()
if file_content:
if txt_file.parent == model_dir:
file_label = txt_file.name
else:
file_label = f"{txt_file.parent.name}/{txt_file.name}"
# 如果是模型描述文件,直接显示内容,不加标题
if '模型描述' in txt_file.name:
content_parts.append(f"{file_content}\n\n")
else:
content_parts.append(f"=== {file_label} ===\n")
content_parts.append(f"{file_content}\n\n")
except:
try:
with open(txt_file, 'r', encoding='gbk') as f:
file_content = f.read().strip()
if file_content:
if txt_file.parent == model_dir:
file_label = txt_file.name
else:
file_label = f"{txt_file.parent.name}/{txt_file.name}"
# 如果是模型描述文件,直接显示内容,不加标题
if '模型描述' in txt_file.name:
content_parts.append(f"{file_content}\n\n")
else:
content_parts.append(f"=== {file_label} ===\n")
content_parts.append(f"{file_content}\n\n")
except:
pass
# 如果找到了txt文件,直接显示,不再添加其他自动生成的信息
full_content = "".join(content_parts)
self.model_info_text.setPlainText(full_content)
# 滚动到顶部
cursor = self.model_info_text.textCursor()
cursor.movePosition(QtGui.QTextCursor.Start)
self.model_info_text.setTextCursor(cursor)
return
else:
# 如果没有txt文件,则显示基础信息
content_parts.append(f"【模型升级信息】 - {model_name}\n")
content_parts.append("=" * 60 + "\n\n")
content_parts.append("基础信息\n")
content_parts.append("-" * 30 + "\n")
content_parts.append(f"模型名称: {model_name}\n")
content_parts.append(f"模型类型: {model_info.get('type', '未知')}\n")
content_parts.append(f"模型路径: {model_path}\n")
content_parts.append(f"文件大小: {model_info.get('size', '未知')}\n")
content_parts.append(file_content)
content_parts.append("\n\n" + "=" * 60 + "\n\n")
# 获取文件修改时间
try:
if os.path.exists(model_path):
import time
mtime = os.path.getmtime(model_path)
mod_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime))
content_parts.append(f"最后修改: {mod_time}\n")
except:
pass
content_parts.append("\n")
# 2. 训练参数信息
training_info = self._getTrainingParameters(model_dir)
if training_info:
content_parts.append("训练参数\n")
content_parts.append("-" * 30 + "\n")
for key, value in training_info.items():
content_parts.append(f"{key}: {value}\n")
content_parts.append("\n")
# 3. 训练效果信息
results_info = self._getTrainingResults(model_dir)
if results_info:
content_parts.append("训练效果\n")
content_parts.append("-" * 30 + "\n")
for key, value in results_info.items():
content_parts.append(f"{key}: {value}\n")
content_parts.append("\n")
# 4. 模型评估
evaluation = self._evaluateModelPerformance(results_info)
if evaluation:
content_parts.append("效果评估\n")
content_parts.append("-" * 30 + "\n")
content_parts.append(f"{evaluation}\n\n")
# 如果没有找到训练信息,显示基础信息
if not training_info and not results_info:
content_parts.append("说明\n")
content_parts.append("-" * 30 + "\n")
content_parts.append("未找到详细的训练记录文件。\n")
content_parts.append("这可能是预训练模型或外部导入的模型。\n\n")
content_parts.append("模型仍可正常使用,如需查看训练效果,\n")
content_parts.append("建议重新进行模型升级训练。\n")
# 显示所有内容
full_content = "".join(content_parts)
......@@ -1860,7 +2130,199 @@ class ModelSetPage(QtWidgets.QWidget):
self.model_info_text.setTextCursor(cursor)
except Exception as e:
self.model_info_text.setPlainText(f"读取txt文件时出错:\n{str(e)}")
self.model_info_text.setPlainText(f"读取模型信息时出错:\n{str(e)}")
def _getTrainingParameters(self, model_dir):
"""获取训练参数信息"""
training_params = {}
try:
# 查找 args.yaml 文件(YOLO训练参数)
args_file = model_dir / "args.yaml"
if args_file.exists():
with open(args_file, 'r', encoding='utf-8') as f:
args_data = yaml.safe_load(f)
if args_data:
training_params["训练轮数"] = args_data.get('epochs', '未知')
training_params["批次大小"] = args_data.get('batch', '未知')
training_params["图像尺寸"] = args_data.get('imgsz', '未知')
training_params["学习率"] = args_data.get('lr0', '未知')
training_params["优化器"] = args_data.get('optimizer', '未知')
training_params["数据集"] = args_data.get('data', '未知')
training_params["设备"] = args_data.get('device', '未知')
training_params["工作进程"] = args_data.get('workers', '未知')
# 查找其他可能的配置文件
config_files = list(model_dir.glob("*config*.yaml")) + list(model_dir.glob("*config*.yml"))
for config_file in config_files:
if config_file.name != "args.yaml":
try:
with open(config_file, 'r', encoding='utf-8') as f:
config_data = yaml.safe_load(f)
if config_data and isinstance(config_data, dict):
# 提取一些关键参数
if 'epochs' in config_data:
training_params["训练轮数"] = config_data['epochs']
if 'batch_size' in config_data:
training_params["批次大小"] = config_data['batch_size']
except:
continue
except Exception as e:
pass
return training_params
def _getTrainingResults(self, model_dir):
"""获取训练结果信息"""
results = {}
try:
# 查找 results.csv 文件(YOLO训练结果)
results_file = model_dir / "results.csv"
if results_file.exists():
import pandas as pd
try:
df = pd.read_csv(results_file)
if not df.empty:
# 获取最后一行的结果(最终训练结果)
last_row = df.iloc[-1]
# 提取关键指标
if 'metrics/precision(B)' in df.columns:
results["精确率"] = f"{last_row['metrics/precision(B)']:.4f}"
elif 'precision' in df.columns:
results["精确率"] = f"{last_row['precision']:.4f}"
if 'metrics/recall(B)' in df.columns:
results["召回率"] = f"{last_row['metrics/recall(B)']:.4f}"
elif 'recall' in df.columns:
results["召回率"] = f"{last_row['recall']:.4f}"
if 'metrics/mAP50(B)' in df.columns:
results["mAP@0.5"] = f"{last_row['metrics/mAP50(B)']:.4f}"
elif 'mAP_0.5' in df.columns:
results["mAP@0.5"] = f"{last_row['mAP_0.5']:.4f}"
if 'metrics/mAP50-95(B)' in df.columns:
results["mAP@0.5:0.95"] = f"{last_row['metrics/mAP50-95(B)']:.4f}"
elif 'mAP_0.5:0.95' in df.columns:
results["mAP@0.5:0.95"] = f"{last_row['mAP_0.5:0.95']:.4f}"
if 'train/box_loss' in df.columns:
results["训练损失"] = f"{last_row['train/box_loss']:.4f}"
elif 'train_loss' in df.columns:
results["训练损失"] = f"{last_row['train_loss']:.4f}"
if 'val/box_loss' in df.columns:
results["验证损失"] = f"{last_row['val/box_loss']:.4f}"
elif 'val_loss' in df.columns:
results["验证损失"] = f"{last_row['val_loss']:.4f}"
# 训练轮数
if 'epoch' in df.columns:
results["完成轮数"] = f"{int(last_row['epoch']) + 1}"
except ImportError:
# 如果没有pandas,尝试手动解析CSV
with open(results_file, 'r', encoding='utf-8') as f:
lines = f.readlines()
if len(lines) > 1: # 有标题行和数据行
headers = lines[0].strip().split(',')
last_data = lines[-1].strip().split(',')
for i, header in enumerate(headers):
if i < len(last_data):
if 'precision' in header.lower():
results["精确率"] = last_data[i]
elif 'recall' in header.lower():
results["召回率"] = last_data[i]
elif 'map50' in header.lower():
results["mAP@0.5"] = last_data[i]
except Exception as e:
pass
# 查找其他结果文件
log_files = list(model_dir.glob("*.log")) + list(model_dir.glob("*train*.txt"))
for log_file in log_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
content = f.read()
# 简单提取一些关键信息
if 'Best mAP' in content or 'best map' in content.lower():
lines = content.split('\n')
for line in lines:
if 'map' in line.lower() and ('best' in line.lower() or 'final' in line.lower()):
results["最佳mAP"] = line.strip()
break
except:
continue
except Exception as e:
pass
return results
def _evaluateModelPerformance(self, results_info):
"""评估模型性能"""
if not results_info:
return "暂无训练结果数据,无法评估模型性能。"
evaluation_parts = []
try:
# 评估mAP指标
if "mAP@0.5" in results_info:
map50 = float(results_info["mAP@0.5"])
if map50 >= 0.9:
evaluation_parts.append("检测精度: 优秀 (mAP@0.5 ≥ 0.9)")
elif map50 >= 0.8:
evaluation_parts.append("检测精度: 良好 (mAP@0.5 ≥ 0.8)")
elif map50 >= 0.7:
evaluation_parts.append("检测精度: 一般 (mAP@0.5 ≥ 0.7)")
else:
evaluation_parts.append("检测精度: 较差 (mAP@0.5 < 0.7)")
# 评估精确率和召回率
if "精确率" in results_info and "召回率" in results_info:
precision = float(results_info["精确率"])
recall = float(results_info["召回率"])
if precision >= 0.85 and recall >= 0.85:
evaluation_parts.append("平衡性能: 优秀 (精确率和召回率均 ≥ 0.85)")
elif precision >= 0.75 and recall >= 0.75:
evaluation_parts.append("平衡性能: 良好 (精确率和召回率均 ≥ 0.75)")
else:
evaluation_parts.append("平衡性能: 需要改进")
# 评估损失值
if "训练损失" in results_info and "验证损失" in results_info:
train_loss = float(results_info["训练损失"])
val_loss = float(results_info["验证损失"])
if abs(train_loss - val_loss) < 0.1:
evaluation_parts.append("模型稳定性: 良好 (训练和验证损失接近)")
elif abs(train_loss - val_loss) < 0.2:
evaluation_parts.append("模型稳定性: 一般")
else:
evaluation_parts.append("模型稳定性: 可能存在过拟合")
# 综合评估
if len(evaluation_parts) == 0:
return "数据不足,无法进行详细评估。"
# 添加使用建议
if any("优秀" in part for part in evaluation_parts):
evaluation_parts.append("\n建议: 模型表现良好,可以投入使用。")
elif any("良好" in part for part in evaluation_parts):
evaluation_parts.append("\n建议: 模型表现尚可,建议在更多数据上测试。")
else:
evaluation_parts.append("\n建议: 模型需要进一步训练优化。")
except Exception as e:
return "评估过程中出现错误,请检查训练结果数据。"
return "\n".join(evaluation_parts)
if __name__ == "__main__":
......
......@@ -580,21 +580,28 @@ class TrainingPage(QtWidgets.QWidget):
def _clearCurve(self):
"""清空曲线数据并隐藏曲线面板"""
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget'):
# 清空数据
try:
# 总是清空数据(无论PyQtGraph是否可用)
self.curve_data_x = []
self.curve_data_y = []
print("[曲线] 已清空曲线数据")
# 清空曲线
if self.curve_line:
self.curve_plot_widget.removeItem(self.curve_line)
self.curve_line = None
# 如果PyQtGraph可用,清空图表
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget') and self.curve_plot_widget:
# 清空曲线
if hasattr(self, 'curve_line') and self.curve_line:
self.curve_plot_widget.removeItem(self.curve_line)
self.curve_line = None
print("[曲线] 已清空PyQtGraph曲线")
print("[曲线] 已清空曲线数据")
# 自动隐藏曲线面板,返回到初始显示状态
self.hideCurvePanel()
print("[曲线] 曲线面板已隐藏,返回初始显示状态")
# 自动隐藏曲线面板,返回到初始显示状态
self.hideCurvePanel()
print("[曲线] 曲线面板已隐藏,返回初始显示状态")
except Exception as e:
print(f"[曲线] 清空曲线失败: {e}")
import traceback
traceback.print_exc()
def addCurvePoint(self, frame_index, height_mm):
"""添加曲线数据点
......@@ -603,39 +610,48 @@ class TrainingPage(QtWidgets.QWidget):
frame_index: 帧序号
height_mm: 液位高度(毫米)
"""
if not PYQTGRAPH_AVAILABLE or not hasattr(self, 'curve_plot_widget'):
return
try:
# 添加数据点
# 确保曲线数据列表存在
if not hasattr(self, 'curve_data_x'):
self.curve_data_x = []
if not hasattr(self, 'curve_data_y'):
self.curve_data_y = []
# 总是添加数据点到列表中(无论PyQtGraph是否可用)
self.curve_data_x.append(frame_index)
self.curve_data_y.append(height_mm)
# 如果曲线不存在,创建曲线
if self.curve_line is None:
self.curve_line = self.curve_plot_widget.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
else:
# 更新曲线数据
self.curve_line.setData(self.curve_data_x, self.curve_data_y)
print(f"[曲线] 添加数据点: 帧{frame_index}, 液位{height_mm:.1f}mm,当前数据点数: {len(self.curve_data_x)}")
# 如果PyQtGraph可用且有curve_plot_widget,更新图表
if PYQTGRAPH_AVAILABLE and hasattr(self, 'curve_plot_widget') and self.curve_plot_widget:
# 如果曲线不存在,创建曲线
if not hasattr(self, 'curve_line') or self.curve_line is None:
self.curve_line = self.curve_plot_widget.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
else:
# 更新曲线数据
self.curve_line.setData(self.curve_data_x, self.curve_data_y)
except Exception as e:
print(f"[曲线] 添加数据点失败: {e}")
import traceback
traceback.print_exc()
def showCurvePanel(self):
"""显示曲线面板"""
if hasattr(self, 'display_layout'):
# 切换到曲线面板(索引3:hint, display_panel, video_panel, curve_panel)
self.display_layout.setCurrentIndex(3)
if hasattr(self, 'curve_panel') and hasattr(self, 'stacked_widget'):
self.stacked_widget.setCurrentWidget(self.curve_panel)
print("[曲线] 显示曲线面板")
def hideCurvePanel(self):
"""隐藏曲线面板,返回到显示面板"""
if hasattr(self, 'display_layout'):
self.display_layout.setCurrentIndex(1) # 显示 display_panel
"""隐藏曲线面板,返回到初始显示"""
if hasattr(self, 'display_panel') and hasattr(self, 'stacked_widget'):
self.stacked_widget.setCurrentWidget(self.display_panel)
print("[曲线] 隐藏曲线面板")
def saveCurveData(self, csv_path):
"""保存曲线数据为CSV文件
......@@ -1343,95 +1359,15 @@ class TrainingPage(QtWidgets.QWidget):
# self._loadTestFileList() # 🔥 不再需要刷新测试文件列表(改用浏览方式)
def _loadBaseModelOptions(self):
"""从模型集管理页面加载基础模型选项"""
# 清空现有选项
self.base_model_combo.clear()
try:
# 尝试从父窗口获取模型集页面
if hasattr(self._parent, 'modelSetPage'):
model_set_page = self._parent.modelSetPage
# 获取模型参数字典
if hasattr(model_set_page, '_model_params'):
model_params = model_set_page._model_params
if not model_params:
self.base_model_combo.addItem("未找到模型", None)
return
# 获取默认模型
default_model = None
if hasattr(model_set_page, '_current_default_model'):
default_model = model_set_page._current_default_model
# 添加所有模型到下拉框
default_index = 0
for idx, (model_name, params) in enumerate(model_params.items()):
model_path = params.get('path', '')
# 构建显示名称
display_name = model_name
if model_name == default_model:
display_name = f"{model_name} (默认)"
default_index = idx
# 添加到下拉框,使用模型路径作为数据
self.base_model_combo.addItem(display_name, model_path)
# 设置默认选择
self.base_model_combo.setCurrentIndex(default_index)
else:
self.base_model_combo.addItem("模型集页面未初始化", None)
else:
self.base_model_combo.addItem("未找到模型集页面", None)
except Exception as e:
print(f"[基础模型] 加载失败: {e}")
import traceback
traceback.print_exc()
self.base_model_combo.addItem("加载失败", None)
"""从detection_model目录加载基础模型选项"""
# 使用统一的模型刷新方法
self.refreshModelLists()
def _loadTestModelOptions(self):
"""加载测试模型选项(从 train_model 文件夹读取)"""
# 清空现有选项
self.test_model_combo.clear()
try:
from ...database.config import get_project_root
project_root = get_project_root()
except ImportError as e:
# 如果导入失败,使用相对路径
project_root = Path(__file__).parent.parent.parent
# 🔥 修改:只从 train_model 目录扫描模型
all_models = self._scanDetectionModelDirectory(project_root)
# 添加到下拉框
if not all_models:
self.test_model_combo.addItem("未找到测试模型")
return
# 获取记忆的测试模型路径
remembered_model = self._getRememberedTestModel(project_root)
default_index = 0
# 添加所有模型到下拉框
for idx, model in enumerate(all_models):
display_name = model['name']
model_path = model['path']
# 如果找到记忆的模型,设置为默认选择
if remembered_model and model_path == remembered_model:
default_index = idx
display_name = f"{display_name} (上次使用)"
self.test_model_combo.addItem(display_name, model_path)
# 设置默认选择
self.test_model_combo.setCurrentIndex(default_index)
"""从detection_model目录加载测试模型选项"""
# 测试模型和基础模型使用相同的数据源,无需单独加载
# refreshModelLists() 已经处理了测试模型下拉菜单
pass
def _loadModelsFromConfig(self, project_root):
"""从配置文件加载通道模型"""
......@@ -1823,8 +1759,8 @@ class TrainingPage(QtWidgets.QWidget):
self._showNoCurveMessage()
return
# 切换到曲线面板显示
self.showCurvePanel()
# 在当前测试页面中显示曲线,而不是切换面板
self._showCurveInTestPage()
# 显示曲线信息提示
data_count = len(self.curve_data_x)
......@@ -1835,7 +1771,7 @@ class TrainingPage(QtWidgets.QWidget):
self,
"曲线信息",
f"图片测试结果:\n液位高度: {liquid_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
f"曲线已显示在左侧测试页面中。"
)
else:
# 视频测试
......@@ -1849,7 +1785,7 @@ class TrainingPage(QtWidgets.QWidget):
f"数据点数: {data_count} 个\n"
f"液位范围: {min_level:.1f} - {max_level:.1f} mm\n"
f"平均液位: {avg_level:.1f} mm\n\n"
f"曲线已显示在左侧面板中。"
f"曲线已显示在左侧测试页面中。"
)
except Exception as e:
......@@ -1860,6 +1796,356 @@ class TrainingPage(QtWidgets.QWidget):
f"显示曲线时发生错误:\n{str(e)}"
)
def _showCurveInTestPage(self):
"""在测试页面中显示曲线"""
try:
print(f"[曲线显示] 开始显示曲线,PyQtGraph可用: {PYQTGRAPH_AVAILABLE}")
# 检查曲线数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
print("[曲线显示] 错误: 缺少曲线数据属性")
self._showCurveAsText()
return
if len(self.curve_data_x) == 0 or len(self.curve_data_y) == 0:
print(f"[曲线显示] 错误: 曲线数据为空,X数据点: {len(self.curve_data_x)}, Y数据点: {len(self.curve_data_y)}")
self._showCurveAsText()
return
print(f"[曲线显示] 曲线数据检查通过,数据点数: {len(self.curve_data_x)}")
if not PYQTGRAPH_AVAILABLE:
print("[曲线显示] PyQtGraph不可用,使用文本显示")
self._showCurveAsText()
return
# 生成曲线图表的HTML内容
print("[曲线显示] 开始生成HTML内容")
curve_html = self._generateCurveHTML()
if not curve_html:
print("[曲线显示] 错误: HTML内容生成失败")
self._showCurveAsText()
return
# 在显示面板中显示曲线HTML
if hasattr(self, 'display_panel') and hasattr(self, 'display_layout'):
self.display_panel.setHtml(curve_html)
self.display_layout.setCurrentWidget(self.display_panel)
print("[曲线显示] 曲线已显示在测试页面中")
else:
print("[曲线显示] 错误: 缺少display_panel或display_layout属性")
self._showCurveAsText()
except Exception as e:
print(f"[曲线显示] 在测试页面显示曲线失败: {e}")
import traceback
traceback.print_exc()
# 降级到文本显示
self._showCurveAsText()
def _generateCurveHTML(self):
"""生成曲线的HTML内容"""
try:
print("[曲线HTML] 开始生成HTML内容")
# 保存曲线图片到临时文件
import tempfile
import os
temp_dir = tempfile.gettempdir()
curve_image_path = os.path.join(temp_dir, "test_curve_display.png")
print(f"[曲线HTML] 临时图片路径: {curve_image_path}")
# 优先尝试使用matplotlib生成曲线(更可靠)
if self._createMatplotlibCurve(curve_image_path):
print("[曲线HTML] matplotlib曲线生成成功")
# 使用PyQtGraph导出曲线图片
elif hasattr(self, 'curve_plot_widget') and self.curve_plot_widget:
print("[曲线HTML] 使用现有的curve_plot_widget导出图片")
try:
# 使用样式管理器的配置
from widgets.style_manager import CurveDisplayStyleManager
chart_width, chart_height = CurveDisplayStyleManager.getChartSize()
exporter = pg.exporters.ImageExporter(self.curve_plot_widget.plotItem)
exporter.parameters()['width'] = chart_width
exporter.parameters()['height'] = chart_height
exporter.export(curve_image_path)
print("[曲线HTML] 现有widget图片导出成功")
except Exception as e:
print(f"[曲线HTML] 现有widget图片导出失败: {e}")
self._createTempCurvePlot(curve_image_path)
else:
print("[曲线HTML] 没有现有的curve_plot_widget,创建临时plot")
# 如果没有现有的plot widget,创建一个临时的
self._createTempCurvePlot(curve_image_path)
# 检查图片是否成功生成
if not os.path.exists(curve_image_path):
print(f"[曲线HTML] 错误: 图片文件未生成: {curve_image_path}")
return self._getFallbackCurveHTML()
print(f"[曲线HTML] 图片生成成功,文件大小: {os.path.getsize(curve_image_path)} bytes")
# 生成统计信息
data_count = len(self.curve_data_x)
stats_html = ""
if data_count == 1:
# 图片测试
liquid_level = self.curve_data_y[0]
stats_html = f"""
<div style="margin-bottom: 15px; padding: 10px; background: #e8f4fd; border: 1px solid #bee5eb; border-radius: 5px;">
<h4 style="margin: 0 0 8px 0; color: #0c5460;">图片测试结果</h4>
<p style="margin: 0; color: #0c5460;"><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
</div>
"""
else:
# 视频测试
min_level = min(self.curve_data_y)
max_level = max(self.curve_data_y)
avg_level = sum(self.curve_data_y) / len(self.curve_data_y)
stats_html = f"""
<div style="margin-bottom: 15px; padding: 10px; background: #e8f4fd; border: 1px solid #bee5eb; border-radius: 5px;">
<h4 style="margin: 0 0 8px 0; color: #0c5460;">视频测试结果统计</h4>
<p style="margin: 2px 0; color: #0c5460;"><strong>数据点数:</strong> {data_count} 个</p>
<p style="margin: 2px 0; color: #0c5460;"><strong>液位范围:</strong> {min_level:.1f} - {max_level:.1f} mm</p>
<p style="margin: 2px 0; color: #0c5460;"><strong>平均液位:</strong> {avg_level:.1f} mm</p>
</div>
"""
# 使用统一的样式管理器生成HTML内容
from widgets.style_manager import CurveDisplayStyleManager
html_content = CurveDisplayStyleManager.generateCurveHTML(curve_image_path, stats_html)
return html_content
except Exception as e:
print(f"[曲线HTML] 生成HTML失败: {e}")
return self._getFallbackCurveHTML()
def _createMatplotlibCurve(self, output_path):
"""使用matplotlib创建曲线图"""
try:
print("[matplotlib曲线] 开始使用matplotlib生成曲线")
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端
# 验证数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
print("[matplotlib曲线] 错误: 缺少曲线数据")
return False
if len(self.curve_data_x) == 0 or len(self.curve_data_y) == 0:
print(f"[matplotlib曲线] 错误: 曲线数据为空")
return False
print(f"[matplotlib曲线] 数据验证通过,X点数: {len(self.curve_data_x)}, Y点数: {len(self.curve_data_y)}")
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['Microsoft YaHei', 'SimHei', 'Arial']
plt.rcParams['axes.unicode_minus'] = False
# 使用样式管理器的配置
from widgets.style_manager import CurveDisplayStyleManager
chart_width, chart_height = CurveDisplayStyleManager.getChartSize()
chart_dpi = CurveDisplayStyleManager.getChartDPI()
bg_color = CurveDisplayStyleManager.getPlotBackgroundColor()
# 创建图形,使用统一的尺寸和DPI
fig, ax = plt.subplots(figsize=(chart_width/100, chart_height/100), dpi=chart_dpi)
# 绘制曲线
ax.plot(self.curve_data_x, self.curve_data_y, 'b-', linewidth=2.5, marker='o', markersize=5,
markerfacecolor='white', markeredgecolor='blue', markeredgewidth=1.5)
# 设置标签和标题
ax.set_xlabel('帧序号', fontsize=12)
ax.set_ylabel('液位高度 (mm)', fontsize=12)
ax.set_title('液位检测曲线', fontsize=14, fontweight='bold', pad=20)
# 设置网格
ax.grid(True, alpha=0.3, linestyle='--')
# 设置背景色,使用样式管理器的颜色
ax.set_facecolor(bg_color)
fig.patch.set_facecolor(bg_color)
# 优化布局,减少边距
plt.tight_layout(pad=1.0)
# 保存图片,优化参数
plt.savefig(output_path,
bbox_inches='tight',
pad_inches=0.2,
facecolor=bg_color,
edgecolor='none',
dpi=100)
plt.close(fig)
print(f"[matplotlib曲线] 曲线图片已保存: {output_path}")
return True
except Exception as e:
print(f"[matplotlib曲线] 创建失败: {e}")
import traceback
traceback.print_exc()
return False
def _createTempCurvePlot(self, output_path):
"""创建临时曲线图并保存"""
try:
print("[临时曲线] 开始创建临时曲线图")
import pyqtgraph as pg
# 验证数据
if not hasattr(self, 'curve_data_x') or not hasattr(self, 'curve_data_y'):
print("[临时曲线] 错误: 缺少曲线数据")
self._createPlaceholderImage(output_path)
return
if len(self.curve_data_x) == 0 or len(self.curve_data_y) == 0:
print(f"[临时曲线] 错误: 曲线数据为空")
self._createPlaceholderImage(output_path)
return
print(f"[临时曲线] 数据验证通过,X点数: {len(self.curve_data_x)}, Y点数: {len(self.curve_data_y)}")
# 创建临时的plot widget
temp_plot = pg.PlotWidget()
temp_plot.setBackground('#f8f9fa')
temp_plot.showGrid(x=True, y=True, alpha=0.3)
temp_plot.setLabel('left', '液位高度', units='mm')
temp_plot.setLabel('bottom', '帧序号')
temp_plot.setTitle('液位检测曲线', color='#495057', size='12pt')
print("[临时曲线] PlotWidget创建成功,开始绘制曲线")
# 绘制曲线
temp_plot.plot(
self.curve_data_x,
self.curve_data_y,
pen=pg.mkPen(color='#1f77b4', width=2),
name='液位高度'
)
print("[临时曲线] 曲线绘制完成,开始导出图片")
# 使用样式管理器的配置导出图片
from widgets.style_manager import CurveDisplayStyleManager
chart_width, chart_height = CurveDisplayStyleManager.getChartSize()
bg_color = CurveDisplayStyleManager.getPlotBackgroundColor()
# 导出图片(使用统一尺寸)
exporter = pg.exporters.ImageExporter(temp_plot.plotItem)
exporter.parameters()['width'] = chart_width
exporter.parameters()['height'] = chart_height
# 设置背景色
temp_plot.setBackground(bg_color)
exporter.export(output_path)
print(f"[临时曲线] 曲线图片已保存: {output_path}")
except Exception as e:
print(f"[临时曲线] 创建失败: {e}")
import traceback
traceback.print_exc()
# 创建一个简单的占位图片
self._createPlaceholderImage(output_path)
def _createPlaceholderImage(self, output_path):
"""创建占位图片"""
try:
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import numpy as np
# 如果有曲线数据,尝试用matplotlib绘制
if hasattr(self, 'curve_data_x') and hasattr(self, 'curve_data_y') and len(self.curve_data_x) > 0:
print("[占位图片] 尝试使用matplotlib绘制曲线")
try:
plt.figure(figsize=(10, 5), dpi=80)
plt.plot(self.curve_data_x, self.curve_data_y, 'b-', linewidth=2, marker='o', markersize=4)
plt.xlabel('帧序号')
plt.ylabel('液位高度 (mm)')
plt.title('液位检测曲线')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig(output_path, bbox_inches='tight', pad_inches=0.1, facecolor='#f8f9fa')
plt.close()
print(f"[占位图片] matplotlib曲线已保存: {output_path}")
return
except Exception as e:
print(f"[占位图片] matplotlib绘制失败: {e}")
# 创建简化的占位图片(更小的尺寸,减少白色背景)
img = Image.new('RGB', (600, 300), '#f8f9fa')
draw = ImageDraw.Draw(img)
# 绘制边框
draw.rectangle([0, 0, 599, 299], outline='#dee2e6', width=1)
# 绘制文本
try:
font = ImageFont.truetype("C:/Windows/Fonts/msyh.ttc", 18)
except:
font = ImageFont.load_default()
text = "曲线生成失败,请检查数据"
text_bbox = draw.textbbox((0, 0), text, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
x = (600 - text_width) // 2
y = (300 - text_height) // 2
draw.text((x, y), text, fill='#666666', font=font)
img.save(output_path)
print(f"[占位图片] 已创建: {output_path}")
except Exception as e:
print(f"[占位图片] 创建失败: {e}")
def _getFallbackCurveHTML(self):
"""获取降级的曲线HTML内容"""
data_count = len(self.curve_data_x)
if data_count == 1:
liquid_level = self.curve_data_y[0]
return f"""
<div style="font-family: Arial, sans-serif; padding: 20px; background: #ffffff; color: #333333;">
<h3>图片测试结果</h3>
<p><strong>液位高度:</strong> {liquid_level:.1f} mm</p>
<p style="color: #666; font-size: 12px;">注: 曲线图表生成失败,显示文本结果。</p>
</div>
"""
else:
min_level = min(self.curve_data_y)
max_level = max(self.curve_data_y)
avg_level = sum(self.curve_data_y) / len(self.curve_data_y)
return f"""
<div style="font-family: Arial, sans-serif; padding: 20px; background: #ffffff; color: #333333;">
<h3>视频测试结果统计</h3>
<p><strong>数据点数:</strong> {data_count} 个</p>
<p><strong>液位范围:</strong> {min_level:.1f} - {max_level:.1f} mm</p>
<p><strong>平均液位:</strong> {avg_level:.1f} mm</p>
<p style="color: #666; font-size: 12px;">注: 曲线图表生成失败,显示文本结果。</p>
</div>
"""
def _showCurveAsText(self):
"""以文本形式显示曲线结果"""
try:
fallback_html = self._getFallbackCurveHTML()
if hasattr(self, 'display_panel') and hasattr(self, 'display_layout'):
self.display_panel.setHtml(fallback_html)
self.display_layout.setCurrentWidget(self.display_panel)
print("[曲线显示] 以文本形式显示曲线结果")
except Exception as e:
print(f"[曲线显示] 文本显示也失败: {e}")
def _showNoCurveMessage(self):
"""显示无曲线数据的提示"""
QtWidgets.QMessageBox.information(
......@@ -1874,6 +2160,25 @@ class TrainingPage(QtWidgets.QWidget):
"测试完成后即可查看曲线结果。"
)
def _testCurveGeneration(self):
"""测试曲线生成功能(调试用)"""
try:
print("[曲线测试] 开始测试曲线生成功能")
# 创建测试数据
self.curve_data_x = [0, 1, 2, 3, 4]
self.curve_data_y = [25.0, 26.5, 24.8, 27.2, 25.9]
print(f"[曲线测试] 测试数据创建完成,X: {self.curve_data_x}, Y: {self.curve_data_y}")
# 测试曲线显示
self._showCurveInTestPage()
except Exception as e:
print(f"[曲线测试] 测试失败: {e}")
import traceback
traceback.print_exc()
def enableViewCurveButton(self):
"""启用查看曲线按钮(测试完成后调用)"""
try:
......@@ -1915,6 +2220,56 @@ class TrainingPage(QtWidgets.QWidget):
elif self.checkbox_template_3.isChecked():
return "template_3"
return None
def refreshModelLists(self):
"""刷新模型下拉菜单列表"""
try:
# 导入模型集页面以获取模型列表
from .modelset_page import ModelSetPage
# 获取detection_model目录下的所有模型
models = ModelSetPage.getDetectionModels()
# 刷新基础模型下拉菜单
if hasattr(self, 'base_model_combo'):
current_base = self.base_model_combo.currentText()
self.base_model_combo.clear()
for model in models:
display_name = model['name'] # 只显示模型名称,不显示文件大小
self.base_model_combo.addItem(display_name, model['path'])
# 尝试恢复之前的选择
if current_base:
index = self.base_model_combo.findText(current_base)
if index >= 0:
self.base_model_combo.setCurrentIndex(index)
# 刷新测试模型下拉菜单
if hasattr(self, 'test_model_combo'):
current_test = self.test_model_combo.currentText()
self.test_model_combo.clear()
for model in models:
display_name = model['name'] # 只显示模型名称,不显示文件大小
self.test_model_combo.addItem(display_name, model['path'])
# 尝试恢复之前的选择
if current_test:
index = self.test_model_combo.findText(current_test)
if index >= 0:
self.test_model_combo.setCurrentIndex(index)
print(f"[模型同步] 已刷新模型列表,共 {len(models)} 个模型")
except Exception as e:
print(f"[错误] 刷新模型列表失败: {e}")
import traceback
traceback.print_exc()
def refreshBaseModelList(self):
"""刷新基础模型下拉菜单(兼容旧接口)"""
self.refreshModelLists()
# 测试代码
......@@ -2073,3 +2428,4 @@ class TrainingNotesDialog(QtWidgets.QDialog):
def getNotesContent(self):
"""获取笔记内容"""
return self.text_edit.toPlainText().strip()
......@@ -941,8 +941,8 @@ class BackgroundStyleManager:
"""全局背景颜色管理器"""
# 全局背景颜色配置
GLOBAL_BACKGROUND_COLOR = "white" # 统一白色背景
GLOBAL_BACKGROUND_STYLE = "background-color: white;"
GLOBAL_BACKGROUND_COLOR = "#f8f9fa" # 统一浅灰色背景,消除纯白色
GLOBAL_BACKGROUND_STYLE = "background-color: #f8f9fa;"
@classmethod
def applyToWidget(cls, widget):
......@@ -1208,6 +1208,112 @@ def setMenuBarHoverColor(color):
# ============================================================================
# 曲线显示样式管理器 (Curve Display Style Manager)
# ============================================================================
class CurveDisplayStyleManager:
"""曲线显示样式管理器"""
# 统一的颜色配置
BACKGROUND_COLOR = "#f8f9fa" # 统一背景色
PLOT_BACKGROUND_COLOR = "#f8f9fa" # 图表背景色
CONTAINER_BACKGROUND_COLOR = "#ffffff" # 容器背景色
BORDER_COLOR = "#dee2e6" # 边框颜色
TEXT_COLOR = "#333333" # 文本颜色
# 图表尺寸配置
CHART_WIDTH = 600
CHART_HEIGHT = 350
CHART_DPI = 100
@classmethod
def getBackgroundColor(cls):
"""获取统一背景色"""
return cls.BACKGROUND_COLOR
@classmethod
def getPlotBackgroundColor(cls):
"""获取图表背景色"""
return cls.PLOT_BACKGROUND_COLOR
@classmethod
def getContainerBackgroundColor(cls):
"""获取容器背景色"""
return cls.CONTAINER_BACKGROUND_COLOR
@classmethod
def getBorderColor(cls):
"""获取边框颜色"""
return cls.BORDER_COLOR
@classmethod
def getTextColor(cls):
"""获取文本颜色"""
return cls.TEXT_COLOR
@classmethod
def getChartSize(cls):
"""获取图表尺寸"""
return cls.CHART_WIDTH, cls.CHART_HEIGHT
@classmethod
def getChartDPI(cls):
"""获取图表DPI"""
return cls.CHART_DPI
@classmethod
def generateCurveHTML(cls, curve_image_path, stats_html=""):
"""生成统一样式的曲线HTML"""
return f"""
<div style="font-family: 'Microsoft YaHei', 'SimHei', Arial, sans-serif; padding: 20px; background: {cls.BACKGROUND_COLOR}; color: {cls.TEXT_COLOR};">
<div style="margin-bottom: 20px;">
<h3 style="margin: 0 0 15px 0; color: {cls.TEXT_COLOR}; font-size: 18px; font-weight: 600;">液位检测曲线结果</h3>
</div>
{stats_html}
<div style="text-align: center; margin-bottom: 15px; background: {cls.CONTAINER_BACKGROUND_COLOR}; padding: 8px; border-radius: 6px; border: 1px solid {cls.BORDER_COLOR};">
<img src="file:///{curve_image_path.replace(chr(92), '/')}" style="max-width: 100%; height: auto; border-radius: 4px;">
</div>
<div style="padding: 10px; background: {cls.CONTAINER_BACKGROUND_COLOR}; border: 1px solid {cls.BORDER_COLOR}; border-radius: 5px; font-size: 12px; color: #666;">
<p style="margin: 0;"><strong>说明:</strong> 曲线显示了液位检测的结果变化趋势。X轴表示帧序号,Y轴表示液位高度(毫米)。</p>
</div>
</div>
"""
@classmethod
def setBackgroundColor(cls, color):
"""设置背景颜色"""
cls.BACKGROUND_COLOR = color
cls.PLOT_BACKGROUND_COLOR = color
@classmethod
def setChartSize(cls, width, height):
"""设置图表尺寸"""
cls.CHART_WIDTH = width
cls.CHART_HEIGHT = height
# 曲线显示样式管理器便捷函数
def getCurveBackgroundColor():
"""获取曲线背景色的便捷函数"""
return CurveDisplayStyleManager.getBackgroundColor()
def getCurvePlotBackgroundColor():
"""获取曲线图表背景色的便捷函数"""
return CurveDisplayStyleManager.getPlotBackgroundColor()
def getCurveChartSize():
"""获取曲线图表尺寸的便捷函数"""
return CurveDisplayStyleManager.getChartSize()
def generateCurveHTML(curve_image_path, stats_html=""):
"""生成曲线HTML的便捷函数"""
return CurveDisplayStyleManager.generateCurveHTML(curve_image_path, stats_html)
# ============================================================================
# 测试代码
# ============================================================================
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment