Spaces:
Runtime error
Runtime error
import plotly.express as px | |
import gradio as gr | |
import plotly.graph_objects as go | |
import seaborn as sns | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from matplotlib.ticker import PercentFormatter | |
def plot_wow_retention_by_type(wow_retention):
    """Build a line chart of weekly retention rate, one line per trader type.

    Args:
        wow_retention: DataFrame with ``week``, ``retention_rate`` and
            ``trader_type`` columns. ``week`` holds strings parsed with the
            ``"%b-%d-%Y"`` format (e.g. ``"Jan-01-2024"``). The caller's frame
            is not modified.

    Returns:
        A ``gr.Plot`` wrapping the Plotly line figure.
    """
    # Build a new frame instead of assigning into the caller's one, so the
    # caller's "week" column is not silently replaced by parsed datetimes.
    wow_retention = wow_retention.assign(
        week=pd.to_datetime(wow_retention["week"], format="%b-%d-%Y")
    ).sort_values(["trader_type", "week"])

    fig = px.line(
        wow_retention,
        x="week",
        y="retention_rate",
        color="trader_type",
        markers=True,
        title="Weekly Retention Rate by Trader Type",
        labels={
            "week": "Week",
            "retention_rate": "Retention Rate (%)",
            "trader_type": "Trader Type",
        },
        color_discrete_sequence=["purple", "goldenrod", "green"],
    )

    # NOTE(review): the "%" tick suffix and hover format assume retention_rate
    # is already expressed in percentage points — confirm with the producer.
    yaxis = dict(ticksuffix="%")
    # Series.max() skips NaN; the builtin max() may return NaN (yielding a
    # broken [0, nan] range) and raises on an empty column. Only pin the range
    # when there is a real maximum; otherwise let Plotly autorange.
    max_rate = wow_retention["retention_rate"].max()
    if pd.notna(max_rate):
        yaxis["range"] = [0, max_rate * 1.1]  # 10% headroom above the peak

    fig.update_layout(
        hovermode="x unified",
        legend=dict(
            yanchor="middle",
            y=0.5,
            xanchor="left",
            x=0.99,
            orientation="v",
        ),
        yaxis=yaxis,
        xaxis=dict(tickformat="%Y-%m-%d"),
        margin=dict(r=200),  # right margin leaves room for the legend
        width=600,
        height=500,
    )
    fig.update_traces(
        hovertemplate="<b>%{y:.1f}%</b><br>Week: %{x|%Y-%m-%d}<extra></extra>"
    )
    return gr.Plot(
        value=fig,
    )
def plot_cohort_retention_heatmap(retention_matrix: pd.DataFrame, cmap: str):
    """Render a cohort-retention heatmap and wrap it in a Gradio plot.

    Args:
        retention_matrix: Rows are cohorts and columns are weeks since first
            activity. The index must be parseable by ``pd.to_datetime``.
            Column values are used verbatim in the ``"Week {col}"`` tick
            labels. Cells are coloured on a fixed 0-100 scale, and NaN cells
            are masked (left blank). The caller's frame is not modified.
        cmap: Colormap name passed straight through to ``sns.heatmap``.

    Returns:
        A ``gr.Plot`` wrapping the matplotlib figure.
    """
    # Use the object-oriented Figure API rather than pyplot. pyplot registers
    # every figure globally until it is explicitly closed, which leaks memory
    # in a long-running app. pyplot's plt.title()/plt.xlabel() also act on
    # whichever figure is "current", which is unsafe if handlers run
    # concurrently.
    from matplotlib.figure import Figure

    # Copy so that relabelling the index does not touch the caller's frame.
    retention_matrix = retention_matrix.copy()
    # Cohort labels like "Mon-Jan 01". The year is not shown, so cohorts from
    # different years that share a weekday/month/day render identically.
    retention_matrix.index = pd.to_datetime(retention_matrix.index).strftime(
        "%a-%b %d"
    )

    fig = Figure(figsize=(12, 8))
    ax = fig.subplots()

    x_labels = [f"Week {week}" for week in retention_matrix.columns]
    sns.heatmap(
        data=retention_matrix,
        ax=ax,
        annot=True,  # print the value in each cell
        fmt=".1f",
        cmap=cmap,
        vmin=0,
        vmax=100,
        center=50,
        cbar_kws={"label": "Retention Rate (%)", "format": PercentFormatter()},
        mask=retention_matrix.isna(),  # leave NaN cells blank
        annot_kws={"size": 8},
        # Pass every label up front. With the default "auto", seaborn may
        # thin the ticks on wide matrices, and overwriting them afterwards
        # with a full-length list would not match the tick count.
        xticklabels=x_labels,
    )

    ax.set_title("Cohort Retention Analysis", pad=20, size=14)
    ax.set_xlabel("Weeks Since First Activity", size=12)
    ax.set_ylabel("Cohort First Day of the Week", size=12)
    ax.set_xticklabels(x_labels, rotation=45, ha="right")
    ax.tick_params(axis="y", labelrotation=0)  # keep cohort labels horizontal
    # Draw any gridlines beneath the cells (no grid is enabled here).
    ax.set_axisbelow(True)
    fig.tight_layout()  # prevent title/label cut-off
    return gr.Plot(value=fig)