Spaces:
Running
Running
# coding: utf-8 | |
import copy | |
import logging | |
import numpy as np | |
import datetime | |
from itertools import product | |
from typing import Any | |
def format_result_box_dual(step_num, param_name, candidate, fixed_params, dev_metrics, test_metrics, is_best=False):
    """Render one search step as a bordered text box with DEV and TEST metrics.

    Args:
        step_num: Step (or combination) number shown in the title line.
        param_name: Name of the parameter being varied, shown in the title.
        candidate: Value tried for ``param_name``, shown in the title.
        fixed_params: Mapping of the parameters held constant; each item is
            listed as ``key = value`` under the "fixed" heading.
        dev_metrics: Mapping of dev-set metrics. Only the keys ``uar``,
            ``war``, ``mf1``, ``wf1``, ``loss`` and ``mean`` are printed,
            in that order, and only when present.
        test_metrics: Mapping of test-set metrics, printed the same way.
        is_best: When true, the dev ``mean`` line gets a check-mark suffix.

    Returns:
        The box as a single newline-joined string. Every line has the same
        ``len()``; alignment is computed in code points, not display columns.
    """
    metric_order = ("uar", "war", "mf1", "wf1", "loss", "mean")

    def format_metrics_block(metrics, label):
        lines = [f" Результаты ({label.upper()}):"]
        # Only the dev block may carry the "best" marker.
        mark_best = is_best and label.lower() == "dev"
        for k in metric_order:
            if k not in metrics:
                continue
            val = metrics[k]
            # NumPy float scalars (e.g. np.float32) are not always subclasses
            # of the built-in float, so they are checked explicitly; without
            # this they would be printed unrounded. Integers and other types
            # are still printed as-is.
            if isinstance(val, (float, np.floating)):
                line = f" {k.upper():12} = {val:.4f}"
            else:
                line = f" {k.upper():12} = {val}"
            if mark_best and k == "mean":
                line += " ✅"
            lines.append(line)
        return lines

    title = f"Шаг {step_num}: {param_name} = {candidate}"
    content_lines = [title, " Фиксировано:"]
    content_lines += [f" {k} = {v}" for k, v in fixed_params.items()]

    # DEV block
    content_lines += format_metrics_block(dev_metrics, "dev")
    content_lines.append("")

    # TEST block
    content_lines += format_metrics_block(test_metrics, "test")

    # GAP: signed difference between the dev and test "mean" values,
    # emitted only when both sides report one.
    if "mean" in dev_metrics and "mean" in test_metrics:
        gap_val = dev_metrics["mean"] - test_metrics["mean"]
        content_lines.append(f" GAP = {gap_val:+.4f}")

    # The title is always present, so max() never sees an empty sequence.
    max_width = max(len(line) for line in content_lines)
    horizontal = "─" * (max_width + 2)

    box = ["┌" + horizontal + "┐"]
    box.extend(f"│ {line.ljust(max_width)} │" for line in content_lines)
    box.append("└" + horizontal + "┘")
    return "\n".join(box)
def greedy_search(
    base_config,
    train_loader,
    dev_loader,
    test_loader,
    train_fn,
    overrides_file: str,
    param_grid: dict[str, list],
    default_values: dict[str, Any],
    csv_prefix: str = None
):
    """Tune hyperparameters one at a time, keeping the best dev value of each.

    Parameters are processed in ``param_grid`` order. For the first parameter
    every candidate is trained. For each later parameter the current best
    combination is trained first as the baseline, then every candidate that
    differs from the current value. A candidate replaces the running best
    only when its dev mean is strictly greater.

    Each run appends a result box and per-dataset details to
    ``overrides_file``; a summary line follows each step and the final
    combination is appended at the end.

    Args:
        base_config: Configuration object; a deep copy is made per run and
            parameters are applied to the copy with ``setattr``.
        train_loader, dev_loader, test_loader: Passed through to ``train_fn``.
        train_fn: Called as ``train_fn(config, train_loader, dev_loader,
            test_loader, metrics_csv_path=...)``; must return
            ``(dev_mean, dev_metrics, test_metrics)``.
        overrides_file: Path of the text report, opened in append mode.
        param_grid: Mapping of parameter name to candidate values.
        default_values: Starting value for every parameter; deep-copied, so
            the caller's mapping is not modified.
        csv_prefix: When truthy, a per-run CSV path is built from it and
            passed as ``metrics_csv_path``; otherwise ``None`` is passed.

    Returns:
        ``(best_metric, best_params, best_dev_metrics)``, the same shape as
        ``exhaustive_search``. The metric and dev metrics are those of the
        best run of the last step; ``float("-inf")`` and ``{}`` when no run
        was selected (for example an empty ``param_grid``).

    Raises:
        KeyError: If a ``param_grid`` key has no entry in ``default_values``.
            Checked before anything is written or trained.
    """
    all_param_names = list(param_grid.keys())

    # Fail fast: otherwise the lookup of a missing default would only fail
    # when its step is reached, possibly after earlier steps have trained.
    missing = [name for name in all_param_names if name not in default_values]
    if missing:
        raise KeyError(
            "default_values has no entry for: " + ", ".join(map(str, missing))
        )

    current_best_params = copy.deepcopy(default_values)
    model_name = getattr(base_config, "model_name", "UNKNOWN_MODEL")

    def run_trial(step_num, param_name, value, log_message, baseline):
        """Train one configuration, append its report, return the dev result.

        ``baseline`` is the dev mean to beat; ``None`` marks the run as best
        unconditionally (used for the re-evaluated current combination).
        """
        config = copy.deepcopy(base_config)
        for k, v in current_best_params.items():
            setattr(config, k, v)
        setattr(config, param_name, value)

        logging.info(log_message)

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_filename = (
            f"{csv_prefix}_{model_name}_{param_name}_{value}_{timestamp}.csv"
            if csv_prefix else None
        )

        dev_mean, dev_metrics, test_metrics = train_fn(
            config,
            train_loader,
            dev_loader,
            test_loader,
            metrics_csv_path=csv_filename
        )

        is_best = baseline is None or dev_mean > baseline

        box_text = format_result_box_dual(
            step_num=step_num,
            param_name=param_name,
            candidate=value,
            fixed_params={k: v for k, v in current_best_params.items() if k != param_name},
            dev_metrics=dev_metrics,
            test_metrics=test_metrics,
            is_best=is_best
        )
        with open(overrides_file, "a", encoding="utf-8") as f:
            f.write("\n" + box_text + "\n")

        _log_dataset_metrics(dev_metrics, overrides_file, label="dev")
        _log_dataset_metrics(test_metrics, overrides_file, label="test")
        return dev_mean, dev_metrics, is_best

    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("=== Жадный (поэтапный) перебор гиперпараметров (Dev-based) ===\n")
        f.write(f"Модель: {model_name}\n")

    best_metric = float("-inf")
    best_dev_metrics = {}

    for step_num, param_name in enumerate(all_param_names, start=1):
        candidates = param_grid[param_name]
        tried_value = current_best_params[param_name]

        best_val_for_param = tried_value
        best_metric_for_param = float("-inf")
        best_dev_metrics_for_param = {}

        if step_num == 1:
            candidates_to_try = candidates
        else:
            candidates_to_try = [v for v in candidates if v != tried_value]
            # Not the first step: evaluate the current combination first so
            # the remaining candidates have a baseline to beat.
            best_metric_for_param, best_dev_metrics_for_param, _ = run_trial(
                step_num,
                param_name,
                tried_value,
                f"[ШАГ {step_num}] {param_name} = {tried_value} (ранее проверенный)",
                baseline=None,
            )

        for candidate in candidates_to_try:
            dev_mean, dev_metrics, is_better = run_trial(
                step_num,
                param_name,
                candidate,
                f"[ШАГ {step_num}] {param_name} = {candidate}, (остальные {current_best_params})",
                baseline=best_metric_for_param,
            )
            if is_better:
                best_val_for_param = candidate
                best_metric_for_param = dev_mean
                best_dev_metrics_for_param = dev_metrics

        current_best_params[param_name] = best_val_for_param
        best_metric = best_metric_for_param
        best_dev_metrics = best_dev_metrics_for_param

        with open(overrides_file, "a", encoding="utf-8") as f:
            f.write(f"\n>> [Итог Шаг{step_num}]: Лучший {param_name}={best_val_for_param}, dev_mean={best_metric_for_param:.4f}\n")

    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("\n=== Итоговая комбинация (Dev-based) ===\n")
        for k, v in current_best_params.items():
            f.write(f"{k} = {v}\n")

    logging.info("Готово! Лучшие параметры подобраны.")
    return best_metric, current_best_params, best_dev_metrics
def exhaustive_search(
    base_config,
    train_loader,
    dev_loader,
    test_loader,
    train_fn,
    overrides_file: str,
    param_grid: dict[str, list],
    csv_prefix: str = None
):
    """Train every combination in the Cartesian product of ``param_grid``.

    Each combination is applied to a deep copy of ``base_config`` with
    ``setattr``, trained through ``train_fn``, and reported in
    ``overrides_file`` (append mode). The combination with the strictly
    greatest dev mean is kept as the best.

    Args:
        base_config: Configuration object, deep-copied for every run.
        train_loader, dev_loader, test_loader: Passed through to ``train_fn``.
        train_fn: Called as ``train_fn(config, train_loader, dev_loader,
            test_loader, metrics_csv_path=...)``; must return
            ``(dev_mean, dev_metrics, test_metrics)``.
        overrides_file: Path of the text report.
        param_grid: Mapping of parameter name to candidate values.
        csv_prefix: When truthy, a per-run CSV path is built from it and
            passed as ``metrics_csv_path``; otherwise ``None`` is passed.

    Returns:
        ``(best_metric, best_config, best_metrics)``: the best dev mean, the
        winning ``{param: value}`` mapping and its dev metrics. When no
        combination was selected (a parameter has an empty candidate list, or
        no dev mean compared greater than ``-inf``, e.g. all NaN) the result
        is ``(float("-inf"), None, {})``.
    """
    all_param_names = list(param_grid.keys())
    model_name = getattr(base_config, "model_name", "UNKNOWN_MODEL")

    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("=== Полный перебор гиперпараметров (Dev-based) ===\n")
        f.write(f"Модель: {model_name}\n")

    best_config = None
    best_metric = float("-inf")
    best_metrics = {}

    value_lists = (param_grid[param] for param in all_param_names)
    for combo_id, combo in enumerate(product(*value_lists), start=1):
        param_combo = dict(zip(all_param_names, combo))

        config = copy.deepcopy(base_config)
        for k, v in param_combo.items():
            setattr(config, k, v)

        logging.info(f"\n[Комбинация #{combo_id}] {param_combo}")

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_filename = f"{csv_prefix}_{model_name}_combo{combo_id}_{timestamp}.csv" if csv_prefix else None

        dev_mean, dev_metrics, test_metrics = train_fn(
            config,
            train_loader,
            dev_loader,
            test_loader,
            metrics_csv_path=csv_filename
        )

        is_better = dev_mean > best_metric

        box_text = format_result_box_dual(
            step_num=combo_id,
            param_name=" + ".join(all_param_names),
            candidate=str(combo),
            fixed_params={},
            dev_metrics=dev_metrics,
            test_metrics=test_metrics,
            is_best=is_better
        )
        with open(overrides_file, "a", encoding="utf-8") as f:
            f.write("\n" + box_text + "\n")

        _log_dataset_metrics(dev_metrics, overrides_file, label="dev")
        _log_dataset_metrics(test_metrics, overrides_file, label="test")

        if is_better:
            best_metric = dev_mean
            best_config = param_combo
            best_metrics = dev_metrics

    # best_config is still None when the product was empty or no dev mean
    # ever compared greater than -inf; iterating it would raise
    # AttributeError after all the training work has been done.
    no_best_message = "Ни одна комбинация не была выбрана лучшей."

    with open(overrides_file, "a", encoding="utf-8") as f:
        f.write("\n=== Лучшая комбинация (Dev-based) ===\n")
        if best_config is None:
            f.write(no_best_message + "\n")
        else:
            for k, v in best_config.items():
                f.write(f"{k} = {v}\n")

    if best_config is None:
        logging.warning(no_best_message)
    else:
        logging.info("Полный перебор завершён! Лучшие параметры выбраны.")
    return best_metric, best_config, best_metrics
def _compute_combined_avg(dev_metrics): | |
if "by_dataset" not in dev_metrics: | |
return None | |
values = [] | |
for entry in dev_metrics["by_dataset"]: | |
for key in ["uar", "war", "mf1", "wf1"]: | |
if key in entry: | |
values.append(entry[key]) | |
return float(np.mean(values)) if values else None | |
def _log_dataset_metrics(metrics, file_path, label="dev"): | |
if "by_dataset" not in metrics: | |
return | |
with open(file_path, "a", encoding="utf-8") as f: | |
f.write(f"\n>>> Подробные метрики по каждому датасету ({label}):\n") | |
for ds in metrics["by_dataset"]: | |
name = ds.get("name", "unknown") | |
f.write(f" - {name}:\n") | |
for k in ["loss", "uar", "war", "mf1", "wf1", "mean"]: | |
if k in ds: | |
f.write(f" {k.upper():4} = {ds[k]:.4f}\n") | |
f.write(f"<<< Конец подробных метрик ({label})\n") | |