File size: 1,944 Bytes
99399ee
 
88c98d4
99399ee
 
88c98d4
99399ee
 
 
88c98d4
99399ee
 
 
 
 
88c98d4
 
99399ee
 
 
88c98d4
99399ee
 
 
 
88c98d4
99399ee
 
 
 
 
 
88c98d4
 
 
99399ee
 
 
 
 
88c98d4
99399ee
88c98d4
 
 
99399ee
 
 
88c98d4
99399ee
 
88c98d4
 
99399ee
88c98d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import matplotlib.pyplot as plt
import pandas as pd

from .utils import undo_hyperlink


def plot_avg_correlation(df1, df2):
    """
    Scatter-plot the "average" score of every model present in both DataFrames.

    Each model that appears in both ``df1["model"]`` and ``df2["model"]`` is
    drawn with its ``df1`` average on the x-axis and its ``df2`` average on the
    y-axis. Each point is annotated with the name returned by
    ``undo_hyperlink(model)``; a name of "No text found" is shown as "Random".

    Note: this updates the global ``plt.rcParams`` font sizes, which persists
    for figures created afterwards in the same process.

    Parameters:
    - df1: pandas DataFrame containing columns "model" and "average" (x-axis).
    - df2: pandas DataFrame containing columns "model" and "average" (y-axis).

    Returns:
    - The ``matplotlib.pyplot`` module, with the new figure as current figure.
    """
    # Apply font sizes before any figure/axes exist so every artist created
    # below picks them up. Updating after the axes were created made the first
    # call render differently from subsequent calls.
    plt.rcParams.update({"font.size": 12, "axes.labelsize": 14, "axes.titlesize": 14})

    # Sort the shared models so plotting order (and hence the color each model
    # gets from the color cycle) is stable across runs; raw set order depends on
    # string-hash randomization. key=str tolerates non-string/mixed entries.
    common_models = sorted(set(df1["model"]) & set(df2["model"]), key=str)

    plt.figure(figsize=(13, 6), constrained_layout=True)

    # Same fixed window on both axes so x and y scores are directly comparable.
    plt.xlim(0.475, 0.8)
    plt.ylim(0.475, 0.8)

    for model in common_models:
        x_vals = df1.loc[df1["model"] == model, "average"].to_numpy()
        y_vals = df2.loc[df2["model"] == model, "average"].to_numpy()

        plt.scatter(x_vals, y_vals, label=model)

        m_name = undo_hyperlink(model)
        if m_name == "No text found":
            m_name = "Random"

        # Annotate each point with scalar coordinates (passing whole arrays to
        # plt.text relied on implicit array->scalar conversion). The label sits
        # just left of the marker, right-aligned and vertically centred.
        for x, y in zip(x_vals, y_vals):
            plt.text(x - 0.005, y, m_name, horizontalalignment="right", verticalalignment="center")

    plt.xlabel("HERM Eval. Set Avg.", fontsize=16)
    plt.ylabel("Pref. Test Sets Avg.", fontsize=16)
    return plt