FelixPhilip's picture
updated the gradio ui
274fa3d
raw
history blame
2.32 kB
import os
import gradio as gr
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from Oracle.deepfundingoracle import prepare_dataset, train_predict_weight, create_submission_csv, SmolLM
from Oracle.DataSmolAgent import DataSmolAgent
def _render_dependency_graph(df_results, image_path):
    """Draw the parent -> repo graph coloured by ``final_weight`` and save it.

    Args:
        df_results: Results frame with ``parent``, ``repo`` and
            ``final_weight`` columns (one row per edge).
        image_path: Destination path for the rendered image.

    Returns:
        ``image_path``, so the caller can hand it straight to ``gr.Image``.
    """
    graph = nx.DiGraph()
    # Duplicate (parent, repo) pairs keep the last weight, as add_edge did.
    graph.add_weighted_edges_from(
        zip(df_results["parent"], df_results["repo"], df_results["final_weight"])
    )
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        pos = nx.spring_layout(graph)
        weights = [graph[u][v]["weight"] for u, v in graph.edges()]
        nx.draw(
            graph,
            pos,
            ax=ax,
            with_labels=True,
            node_size=500,
            node_color="lightblue",
            edge_color=weights,
            edge_cmap=plt.get_cmap("viridis"),
            width=2,
        )
        fig.savefig(image_path)
    finally:
        # Always release the figure: this runs once per request in a
        # long-lived server, and unclosed pyplot figures accumulate.
        plt.close(fig)
    return image_path


def _explanation_prompt(df_results, max_rows=10):
    """Build the SmolLM prompt, embedding a small sample of predicted edges.

    A fresh ``SmolLM()`` is created per request, so without this sample the
    model would be asked to explain data it never sees.
    """
    sample = df_results[["parent", "repo", "final_weight"]].head(max_rows)
    return (
        "Explain the dependency graph and weight assignments for the dataset.\n"
        "Sample of predicted edges (parent, repo, final_weight):\n"
        + sample.to_string(index=False)
    )


def pipeline(upload):
    """Run the DeepFunding Oracle end to end on an uploaded CSV.

    Steps: feature extraction via ``DataSmolAgent``, dataset preparation,
    weight prediction, submission CSV export, dependency-graph rendering,
    and a natural-language explanation from ``SmolLM``.

    Args:
        upload: Value supplied by the ``gr.File`` input. With
            ``type="filepath"`` this is a path string; a file-like object
            exposing ``.name`` is also accepted.

    Returns:
        A 4-tuple matching the interface outputs: a preview DataFrame (first
        rows of the results), the path returned by ``create_submission_csv``,
        the rendered graph image path, and the explanation text.
    """
    # gr.File(type="filepath") passes a plain str, which has no ``.name``.
    if isinstance(upload, (str, os.PathLike)):
        upload_path = upload
    else:
        upload_path = upload.name

    # Load and clean/extract features via SmolAgents.
    df_raw = pd.read_csv(upload_path)
    df_features = DataSmolAgent(df_raw).run(
        prompt="Clean and extract features from the uploaded data", output_csv=False
    )

    # prepare_dataset reads from a path, so persist the features first.
    processed_path = "processed_input.csv"
    pd.DataFrame(df_features).to_csv(processed_path, index=False)
    df_prepared = prepare_dataset(processed_path)

    # Assign base weights and predict final weights.
    df_results = train_predict_weight(df_prepared)

    csv_path = create_submission_csv(df_results, "submission.csv")
    graph_path = _render_dependency_graph(df_results, "graph.png")
    explanation = SmolLM().predict(_explanation_prompt(df_results))

    # gr.Dataframe renders a DataFrame directly; a list of dicts is not
    # displayed as a table.
    return df_results.head(), csv_path, graph_path, explanation
# Gradio UI: one file upload in, four outputs matching the 4-tuple that
# ``pipeline`` returns (preview table, submission CSV, graph image, explanation).
iface = gr.Interface(
    fn=pipeline,
    # Restrict the picker to CSV files: the pipeline feeds the upload
    # straight into pd.read_csv, which fails on other formats.
    inputs=gr.File(label="Upload CSV", type="filepath", file_types=[".csv"]),
    outputs=[
        gr.Dataframe(label="Preview of Results"),
        gr.File(label="Download CSV"),
        gr.Image(label="Dependency Graph"),
        gr.Textbox(label="Explanation"),
    ],
    title="DeepFunding Oracle",
    description=(
        "Upload a CSV to extract features, assign base weights via LLama, predict final weights with RandomForest, "
        "and visualize the dependency graph with explanations."
    ),
)
# Script entry point: serve the interface on all network interfaces, using
# the PORT environment variable when set and 7860 otherwise.
if __name__ == "__main__":
    iface.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )