Spaces:
Sleeping
Sleeping
File size: 6,690 Bytes
513a1f2 3410966 274fa3d 3388ab8 4cdc4e2 a1599a6 bf9b592 513a1f2 4cdc4e2 bf9b592 4cdc4e2 bf9b592 3410966 513a1f2 bf9b592 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
import os
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import time
import io
from PIL import Image
import logging
# Import the functions from deepfundingoracle
from Oracle.deepfundingoracle import prepare_dataset, train_predict_weight, create_submission_csv, \
normalize_and_clip_weights
# Configure the root logger so INFO-level (and more severe) messages are
# emitted; analyze_file reports failures through the root logger.
logging.basicConfig(level=logging.INFO)
def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG and return it as a PIL image.

    The figure is closed afterwards even if rendering fails, so repeated
    requests do not accumulate open figures in pyplot's figure registry.
    """
    try:
        buf = io.BytesIO()
        fig.savefig(buf, format='png', dpi=100, bbox_inches='tight')
        buf.seek(0)
        img = Image.open(buf)
        img.load()  # decode eagerly so a bad render surfaces here, not in the UI
        return img
    finally:
        plt.close(fig)


def _plot_distributions(df, plot_cols):
    """Return a figure with one histogram per column in *plot_cols*.

    An empty 15x10 figure is returned when there is nothing to plot.
    """
    if not plot_cols:
        return plt.figure(figsize=(15, 10))
    # DataFrame.hist() builds its own figure when no `ax` is passed, so take
    # the figure from the returned axes instead of pre-creating one that
    # would stay empty (and closing the wrong figure afterwards).
    axes = df[plot_cols].hist(bins=20, figsize=(15, 10), color="skyblue", edgecolor="black")
    fig = np.ravel(axes)[0].get_figure()
    fig.suptitle("Feature Distributions", fontsize=16)
    fig.tight_layout()
    return fig


def _plot_correlations(df, plot_cols):
    """Return a figure with the correlation heatmap of *plot_cols*.

    The heatmap is only drawn when at least two columns are available;
    otherwise the 12x8 figure is returned empty.
    """
    fig = plt.figure(figsize=(12, 8))
    try:
        if len(plot_cols) > 1:
            ax = fig.add_subplot(111)
            sns.heatmap(df[plot_cols].corr(), annot=True, cmap="coolwarm",
                        fmt=".2f", linewidths=0.5, ax=ax)
            ax.set_title("Feature Correlation Matrix", fontsize=16)
    except Exception:
        plt.close(fig)  # do not leak the figure when plotting fails
        raise
    return fig


def analyze_file(file, progress=gr.Progress(track_tqdm=True)):
    """
    Analyzes the uploaded file and generates results.

    Runs prepare_dataset -> train_predict_weight -> normalize_and_clip_weights,
    writes the result to "submission.csv" via create_submission_csv, and
    renders a histogram figure and a correlation heatmap for the known
    numeric feature columns.

    Args:
        file: The uploaded file. Either an object exposing the path as
            ``.name`` or a plain path string is accepted; ``None`` (nothing
            uploaded) is reported as an error.
        progress: Gradio progress tracker used to report the current step.

    Returns:
        A 5-tuple ``(preview_text, csv_path, distribution_image,
        correlation_image, status_text)``. On failure the preview and status
        both carry the error message, ``csv_path`` is ``None`` and both
        images are blank white placeholders.
    """
    start_time = time.time()
    try:
        if file is None:
            raise ValueError("No file uploaded. Please upload a CSV file first.")
        # NOTE(review): the shape of the upload value presumably depends on
        # the Gradio version/config; accept both a path and a `.name` holder.
        file_path = getattr(file, "name", file)

        # Step 1: Prepare dataset
        progress(0, desc="Preparing dataset...")
        df = prepare_dataset(file_path)

        # Step 2: Train model and predict weights
        progress(0.3, desc="Training model and predicting weights...")
        df = train_predict_weight(df)

        # Step 3: Normalize weights
        progress(0.5, desc="Normalizing weights...")
        df = normalize_and_clip_weights(df)

        # Step 4: Save results
        progress(0.6, desc="Saving results to CSV...")
        output_filename = "submission.csv"
        create_submission_csv(df, output_filename)

        # Step 5: Generate visualizations
        progress(0.8, desc="Generating graphs...")
        wanted = ('stars', 'forks', 'watchers', 'contributors', 'pulls', 'final_weight')
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        plot_cols = [col for col in numeric_cols if col in wanted]
        dist_img = _figure_to_image(_plot_distributions(df, plot_cols))
        corr_img = _figure_to_image(_plot_correlations(df, plot_cols))

        # Prepare preview
        progress(1, desc="Done!")
        elapsed = time.time() - start_time
        summary_df = df[['repo', 'parent', 'final_weight']].head(10)
        preview = f"Top 10 Results:\n{summary_df.to_string(index=False)}\n\nTotal repositories analyzed: {len(df)}"

        # Emoji are written as escapes (U+2705 check mark, U+1F4E5 inbox
        # tray) so the source stays ASCII and cannot be mangled by a wrong
        # decoding again.
        status = (
            f"\u2705 Analysis completed successfully in {elapsed:.2f} seconds.\n"
            "\U0001F4E5 Results file ready for download!"
        )
        # The returned path feeds the download component.
        return preview, output_filename, dist_img, corr_img, status
    except Exception as e:
        # Top-level boundary of the request: log with traceback and report
        # the failure in the UI instead of raising.
        logging.exception("Error during analysis: %s", e)
        elapsed = time.time() - start_time
        error_msg = f"\u274c Error: {e}\nTime elapsed: {elapsed:.2f} seconds"  # U+274C cross mark
        # Return empty images and error message
        empty_img = Image.new('RGB', (800, 600), color='white')
        return error_msg, None, empty_img, empty_img, error_msg
# Custom CSS passed to gr.Blocks(css=...).
# `.download-button` is attached to the results download component and
# `.status-box` to the status textbox via their `elem_classes` arguments.
custom_css = """
.download-button {
background-color: #4CAF50 !important;
color: white !important;
font-weight: bold !important;
}
.status-box {
font-family: monospace;
padding: 10px;
border-radius: 5px;
}
"""
# Create Gradio interface with automatic download.
# Layout: upload + analyze button beside a status box, then a results
# preview, the downloadable CSV, and the two plots side by side.
# NOTE(review): the two remaining "π" characters below (page title and
# Analyze button) are mis-encoded emoji whose original code point cannot be
# recovered from this file; restore them from the upstream source.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as iface:
    gr.Markdown("""
    # π DeepFunding Oracle
    Upload a CSV file containing repository dependencies with 'repo' and 'parent' columns.
    The system will:
    1. **Fetch** GitHub metrics for each repository
    2. **Generate** importance weights using AI
    3. **Train** a model to predict final contribution weights
    4. **Normalize** weights so they sum to 1 per parent
    \u26a0\ufe0f **Note**: Set `GITHUB_API_TOKEN` environment variable for better API rate limits.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            file_input = gr.File(
                label="Upload CSV File",
                file_types=[".csv"],
                elem_id="file-upload",
            )
            analyze_btn = gr.Button("π Analyze", variant="primary", size="lg")
        with gr.Column(scale=2):
            status_output = gr.Textbox(
                label="Status",
                lines=3,
                elem_classes="status-box",
            )

    with gr.Row():
        preview_output = gr.Textbox(
            label="Preview of Results",
            lines=15,
            show_copy_button=True,
        )

    with gr.Row():
        download_output = gr.File(
            label="\U0001F4E5 Download Results CSV",  # U+1F4E5 inbox tray
            visible=True,
            elem_classes="download-button",
        )

    with gr.Row():
        with gr.Column():
            dist_plot = gr.Image(label="Feature Distributions")
        with gr.Column():
            corr_plot = gr.Image(label="Feature Correlation Matrix")

    # JavaScript for automatic download: after a short delay, click the link
    # rendered inside the element carrying the `download-button` class.
    download_js = """
    () => {
        setTimeout(() => {
            const downloadButton = document.querySelector('.download-button a');
            if (downloadButton) {
                downloadButton.click();
            }
        }, 500);
    }
    """

    # Both the button and the examples fill the same five components, in the
    # order analyze_file returns them.
    result_outputs = [preview_output, download_output, dist_plot, corr_plot, status_output]

    # Set up the event handler
    # NOTE(review): `_js` is the Gradio 3.x spelling; newer Gradio releases
    # are believed to name this keyword `js` -- confirm against the pinned
    # Gradio version before changing it.
    analyze_btn.click(
        fn=analyze_file,
        inputs=[file_input],
        outputs=result_outputs,
    ).then(
        fn=None,
        inputs=None,
        outputs=None,
        _js=download_js,  # This triggers automatic download
    )

    # Add example usage, but only when the example file is actually shipped:
    # it is optional, and registering a path that does not exist is expected
    # to fail while the interface is being built.
    example_path = "example_dependencies.csv"
    if os.path.isfile(example_path):
        gr.Examples(
            examples=[[example_path]],
            inputs=file_input,
            outputs=result_outputs,
            fn=analyze_file,
            cache_examples=False,
        )
if __name__ == "__main__":
    # Serve on the port given by the PORT environment variable, falling back
    # to 7860 when it is unset or empty (int("") would raise).
    port = int(os.environ.get("PORT") or 7860)
    iface.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=port,
        share=False,
        show_error=True,
    )