Spaces:
Sleeping
Sleeping
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix | |
import numpy as np | |
import json | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# Used during training/evaluation: turns raw predictions and label ids into summary metrics
def compute_metrics(pred):
    """Compute accuracy and support-weighted F1 from evaluation predictions.

    Args:
        pred: Object exposing ``label_ids`` (true class ids) and
            ``predictions`` (per-class scores). ``predictions`` may be a
            single array or a tuple of arrays; in the tuple case the class
            scores are assumed to be the first element. The score array is
            expected to be 2-D (samples x classes), since the predicted class
            is taken with an argmax over axis 1.

    Returns:
        dict: ``{"accuracy": ..., "f1": ...}``, where ``f1`` is the F1 score
        averaged over classes weighted by their support.
    """
    labels = pred.label_ids
    scores = pred.predictions
    # Models that return extra outputs (e.g. hidden states/attentions) make the
    # evaluation loop hand back a tuple; argmax over the tuple itself would be
    # meaningless. Presumably the logits come first — confirm for new models.
    if isinstance(scores, tuple):
        scores = scores[0]
    preds = np.argmax(scores, axis=1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted"),
    }
# Save the per-class classification report as a JSON file
def save_metrics(y_true, y_pred, label_names, out_path="outputs/metrics/report.json", *, labels=None):
    """Write a scikit-learn classification report to *out_path* as indented JSON.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        label_names: Display names for the classes, passed as ``target_names``.
        out_path: Destination JSON file. Missing parent directories are created.
        labels: Optional explicit list of classes to report on, passed through
            to ``classification_report``. When omitted, sklearn infers the
            classes from ``y_true``/``y_pred`` and raises ``ValueError`` if their
            count differs from ``len(label_names)`` (e.g. a class that never
            occurs in the data). Passing something like
            ``list(range(len(label_names)))`` avoids that.
    """
    from pathlib import Path

    report = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=label_names,
        output_dict=True,
    )
    path = Path(out_path)
    # The default path lives in a nested directory that may not exist yet;
    # without this, open() fails with FileNotFoundError.
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        json.dump(report, f, indent=4)
# Save the confusion matrix as an annotated heatmap image
def save_confusion_matrix(y_true, y_pred, label_names, out_path="outputs/metrics/confusion_matrix.png", *, labels=None):
    """Render the confusion matrix as an annotated heatmap and save it to *out_path*.

    Rows are the actual classes and columns the predicted classes, matching the
    "Actual"/"Predicted" axis titles.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels.
        label_names: Tick labels used for both axes, in matrix row/column order.
        out_path: Destination image file. Missing parent directories are created.
        labels: Optional explicit class list passed to ``confusion_matrix`` so
            the matrix has one row/column per entry even for classes absent from
            the data. This keeps it aligned with *label_names*. When omitted, the
            classes are inferred from ``y_true``/``y_pred``.
    """
    from pathlib import Path

    cm = confusion_matrix(y_true, y_pred, labels=labels)
    path = Path(out_path)
    # The default path lives in a nested directory that may not exist yet.
    path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            xticklabels=label_names,
            yticklabels=label_names,
            ax=ax,
        )
        ax.set_xlabel("Predicted")
        ax.set_ylabel("Actual")
        ax.set_title("Confusion Matrix")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls accumulate memory and trigger matplotlib's
        # "too many open figures" warning.
        plt.close(fig)