# coding: utf-8
import copy
import datetime
import logging
from itertools import product
from typing import Any

import numpy as np

def format_result_box_dual(step_num, param_name, candidate, fixed_params,
                           dev_metrics, test_metrics, is_best=False):
    """Render a boxed text summary of one search step with DEV and TEST metrics."""
    title = f"Step {step_num}: {param_name} = {candidate}"
    fixed_lines = [f"{k} = {v}" for k, v in fixed_params.items()]

    def format_metrics_block(metrics, label):
        lines = [f"  Results ({label.upper()}):"]
        for k in ["uar", "war", "mf1", "wf1", "loss", "mean"]:
            if k in metrics:
                val = metrics[k]
                line = f"    {k.upper():12} = {val:.4f}" if isinstance(val, float) else f"    {k.upper():12} = {val}"
                # Mark the DEV mean of the best candidate so far.
                if is_best and label.lower() == "dev" and k.lower() == "mean":
                    line += " ✅"
                lines.append(line)
        return lines

    content_lines = [title, "  Fixed:"]
    content_lines += [f"    {line}" for line in fixed_lines]
    # DEV block
    content_lines += format_metrics_block(dev_metrics, "dev")
    content_lines.append("")
    # TEST block
    content_lines += format_metrics_block(test_metrics, "test")
    # GAP between the DEV and TEST means
    if "mean" in dev_metrics and "mean" in test_metrics:
        gap_val = dev_metrics["mean"] - test_metrics["mean"]
        content_lines.append(f"  GAP = {gap_val:+.4f}")

    max_width = max(len(line) for line in content_lines)
    border_top = "┌" + "─" * (max_width + 2) + "┐"
    border_bot = "└" + "─" * (max_width + 2) + "┘"
    box = [border_top]
    for line in content_lines:
        box.append(f"│ {line.ljust(max_width)} │")
    box.append(border_bot)
    return "\n".join(box)
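
# A minimal usage sketch of the box renderer (hypothetical metric values, not
# taken from any real run):
#
#   print(format_result_box_dual(
#       step_num=1, param_name="lr", candidate=1e-4,
#       fixed_params={"batch_size": 32},
#       dev_metrics={"uar": 0.61, "mean": 0.60},
#       test_metrics={"uar": 0.58, "mean": 0.57},
#       is_best=True,
#   ))
#
# This prints a Unicode box listing the fixed parameters, the DEV and TEST
# metric blocks (with a checkmark next to the DEV mean, since is_best=True),
# and GAP = +0.0300.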

def greedy_search(
    base_config,
    train_loader,
    dev_loader,
    test_loader,
    train_fn,
    overrides_file: str,
    param_grid: dict[str, list],
    default_values: dict[str, Any],
    csv_prefix: str | None = None,
):
    """Greedy (stepwise) hyperparameter search, selecting by the dev mean."""
    current_best_params = copy.deepcopy(default_values)
    all_param_names = list(param_grid.keys())
    model_name = getattr(base_config, "model_name", "UNKNOWN_MODEL")
    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("=== Greedy (stepwise) hyperparameter search (Dev-based) ===\n")
        f.write(f"Model: {model_name}\n")

    for i, param_name in enumerate(all_param_names):
        candidates = param_grid[param_name]
        tried_value = current_best_params[param_name]
        if i == 0:
            candidates_to_try = candidates
        else:
            candidates_to_try = [v for v in candidates if v != tried_value]

        best_val_for_param = tried_value
        best_metric_for_param = float("-inf")

        # On every step after the first, re-evaluate the current best combination
        # so candidates are compared against an up-to-date baseline.
        if i != 0:
            config_default = copy.deepcopy(base_config)
            for k, v in current_best_params.items():
                setattr(config_default, k, v)
            logging.info(f"[STEP {i+1}] {param_name} = {tried_value} (previously evaluated)")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            csv_filename = f"{csv_prefix}_{model_name}_{param_name}_{tried_value}_{timestamp}.csv" if csv_prefix else None
            dev_mean_default, dev_metrics_default, test_metrics_default = train_fn(
                config_default,
                train_loader,
                dev_loader,
                test_loader,
                metrics_csv_path=csv_filename,
            )
            box_text = format_result_box_dual(
                step_num=i + 1,
                param_name=param_name,
                candidate=tried_value,
                fixed_params={k: v for k, v in current_best_params.items() if k != param_name},
                dev_metrics=dev_metrics_default,
                test_metrics=test_metrics_default,
                is_best=True,
            )
            with open(overrides_file, "a", encoding="utf-8") as f:
                f.write("\n" + box_text + "\n")
            _log_dataset_metrics(dev_metrics_default, overrides_file, label="dev")
            _log_dataset_metrics(test_metrics_default, overrides_file, label="test")
            best_metric_for_param = dev_mean_default

        for candidate in candidates_to_try:
            config = copy.deepcopy(base_config)
            for k, v in current_best_params.items():
                setattr(config, k, v)
            setattr(config, param_name, candidate)
            logging.info(f"[STEP {i+1}] {param_name} = {candidate} (others: {current_best_params})")
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            csv_filename = f"{csv_prefix}_{model_name}_{param_name}_{candidate}_{timestamp}.csv" if csv_prefix else None
            dev_mean, dev_metrics, test_metrics = train_fn(
                config,
                train_loader,
                dev_loader,
                test_loader,
                metrics_csv_path=csv_filename,
            )
            is_better = dev_mean > best_metric_for_param
            box_text = format_result_box_dual(
                step_num=i + 1,
                param_name=param_name,
                candidate=candidate,
                fixed_params={k: v for k, v in current_best_params.items() if k != param_name},
                dev_metrics=dev_metrics,
                test_metrics=test_metrics,
                is_best=is_better,
            )
            with open(overrides_file, "a", encoding="utf-8") as f:
                f.write("\n" + box_text + "\n")
            _log_dataset_metrics(dev_metrics, overrides_file, label="dev")
            _log_dataset_metrics(test_metrics, overrides_file, label="test")
            if is_better:
                best_val_for_param = candidate
                best_metric_for_param = dev_mean

        current_best_params[param_name] = best_val_for_param
        with open(overrides_file, "a", encoding="utf-8") as f:
            f.write(f"\n>> [Step {i+1} result]: best {param_name}={best_val_for_param}, dev_mean={best_metric_for_param:.4f}\n")

    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("\n=== Final combination (Dev-based) ===\n")
        for k, v in current_best_params.items():
            f.write(f"{k} = {v}\n")
    logging.info("Done! Best parameters selected.")
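
# greedy_search assumes `train_fn` returns a (dev_mean, dev_metrics,
# test_metrics) triple and accepts a `metrics_csv_path` keyword. A minimal
# stub honouring that contract (hypothetical, for wiring tests only; the real
# training function lives elsewhere in this repo):
#
#   def dummy_train_fn(config, train_loader, dev_loader, test_loader,
#                      metrics_csv_path=None):
#       dev_metrics = {"uar": 0.50, "war": 0.55, "mean": 0.52}
#       test_metrics = {"uar": 0.48, "war": 0.53, "mean": 0.50}
#       return dev_metrics["mean"], dev_metrics, test_metrics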

def exhaustive_search(
    base_config,
    train_loader,
    dev_loader,
    test_loader,
    train_fn,
    overrides_file: str,
    param_grid: dict[str, list],
    csv_prefix: str | None = None,
):
    """Full grid search over the Cartesian product of param_grid, selecting by the dev mean."""
    all_param_names = list(param_grid.keys())
    model_name = getattr(base_config, "model_name", "UNKNOWN_MODEL")
    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("=== Exhaustive hyperparameter search (Dev-based) ===\n")
        f.write(f"Model: {model_name}\n")

    best_config = None
    best_metric = float("-inf")
    best_metrics = {}
    combo_id = 0
    for combo in product(*(param_grid[param] for param in all_param_names)):
        combo_id += 1
        param_combo = dict(zip(all_param_names, combo))
        config = copy.deepcopy(base_config)
        for k, v in param_combo.items():
            setattr(config, k, v)
        logging.info(f"\n[Combination #{combo_id}] {param_combo}")
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_filename = f"{csv_prefix}_{model_name}_combo{combo_id}_{timestamp}.csv" if csv_prefix else None
        dev_mean, dev_metrics, test_metrics = train_fn(
            config,
            train_loader,
            dev_loader,
            test_loader,
            metrics_csv_path=csv_filename,
        )
        is_better = dev_mean > best_metric
        box_text = format_result_box_dual(
            step_num=combo_id,
            param_name=" + ".join(all_param_names),
            candidate=str(combo),
            fixed_params={},
            dev_metrics=dev_metrics,
            test_metrics=test_metrics,
            is_best=is_better,
        )
        with open(overrides_file, "a", encoding="utf-8") as f:
            f.write("\n" + box_text + "\n")
        _log_dataset_metrics(dev_metrics, overrides_file, label="dev")
        _log_dataset_metrics(test_metrics, overrides_file, label="test")
        if is_better:
            best_metric = dev_mean
            best_config = param_combo
            best_metrics = dev_metrics

    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("\n=== Best combination (Dev-based) ===\n")
        for k, v in best_config.items():
            f.write(f"{k} = {v}\n")
    logging.info("Exhaustive search complete! Best parameters selected.")
    return best_metric, best_config, best_metrics
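
# Cost note (a rough sketch, assuming a grid like {"lr": [1e-4, 5e-4, 1e-3],
# "batch_size": [16, 32, 64, 128]} with defaults drawn from the grid):
# exhaustive_search trains 3 * 4 = 12 models, while greedy_search trains
# 3 + (1 + 3) = 7: all candidates for the first parameter, then one
# re-evaluation of the running best plus the remaining candidates for the
# second. Greedy is cheaper but can miss interactions between parameters.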

def _compute_combined_avg(dev_metrics):
    """Average uar/war/mf1/wf1 over all per-dataset entries, or None if absent."""
    if "by_dataset" not in dev_metrics:
        return None
    values = []
    for entry in dev_metrics["by_dataset"]:
        for key in ["uar", "war", "mf1", "wf1"]:
            if key in entry:
                values.append(entry[key])
    return float(np.mean(values)) if values else None

def _log_dataset_metrics(metrics, file_path, label="dev"):
    """Append per-dataset metrics (if present) to the overrides log file."""
    if "by_dataset" not in metrics:
        return
    with open(file_path, "a", encoding="utf-8") as f:
        f.write(f"\n>>> Detailed per-dataset metrics ({label}):\n")
        for ds in metrics["by_dataset"]:
            name = ds.get("name", "unknown")
            f.write(f"  - {name}:\n")
            for k in ["loss", "uar", "war", "mf1", "wf1", "mean"]:
                if k in ds:
                    f.write(f"      {k.upper():4} = {ds[k]:.4f}\n")
        f.write(f"<<< End of detailed metrics ({label})\n")
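
# Both helpers above read a metrics dict of (roughly) this shape; the field
# names follow the keys accessed in the code, and the values and dataset
# names here are made up:
#
#   metrics = {
#       "mean": 0.58,
#       "by_dataset": [
#           {"name": "MELD", "loss": 1.02, "uar": 0.55, "war": 0.60,
#            "mf1": 0.54, "wf1": 0.59, "mean": 0.57},
#           {"name": "IEMOCAP", "uar": 0.62, "mean": 0.60},
#       ],
#   }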