FelixPhilip committed on
Commit
4cdc4e2
·
1 Parent(s): 51b48a8

updated gradio to call the functions directly

Browse files
Files changed (1) hide show
  1. app.py +36 -40
app.py CHANGED
@@ -1,54 +1,50 @@
1
  import os
2
  import gradio as gr
 
3
  import pandas as pd
4
- import networkx as nx
5
  import matplotlib.pyplot as plt
6
- from Oracle.deepfundingoracle import prepare_dataset, train_predict_weight, create_submission_csv, SmolLM
7
- from Oracle.DataSmolAgent import DataSmolAgent
8
 
9
- def pipeline(upload):
10
- # Load and clean/extract features via SmolAgents
11
- df_raw = pd.read_csv(upload.name)
12
- df_features = DataSmolAgent(df_raw).run(
13
- prompt="Clean and extract features from the uploaded data", output_csv=False)
14
- # Save preprocessed features and run dataset preparation on them
15
- processed_path = "processed_input.csv"
16
- pd.DataFrame(df_features).to_csv(processed_path, index=False)
17
- df_prepared = prepare_dataset(processed_path)
18
- # Assign base weights and predict final weights
19
- df_results = train_predict_weight(df_prepared)
20
- # Create submission CSV
21
- csv_path = create_submission_csv(df_results, "submission.csv")
22
- # Build dependency graph
23
- G = nx.DiGraph()
24
- for _, row in df_results.iterrows():
25
- G.add_edge(row["parent"], row["repo"], weight=row["final_weight"])
26
- plt.figure(figsize=(10, 8))
27
- pos = nx.spring_layout(G)
28
- weights = [G[u][v]["weight"] for u, v in G.edges()]
29
- nx.draw(G, pos, with_labels=True, node_size=500, node_color="lightblue",
30
- edge_color=weights, edge_cmap=plt.get_cmap('viridis'), width=2)
31
- plt.savefig("graph.png")
32
- # Generate explanation via SmolLM
33
- explanation = SmolLM().predict(
34
- "Explain the dependency graph and weight assignments for the dataset.")
35
- # Return results
36
- return (df_results.head().to_dict("records"), csv_path, "graph.png", explanation)
37
 
38
  iface = gr.Interface(
39
- fn=pipeline,
40
- inputs=gr.File(label="Upload CSV", type="filepath"),
41
  outputs=[
42
- gr.Dataframe(label="Preview of Results"),
43
  gr.File(label="Download CSV"),
44
- gr.Image(label="Dependency Graph"),
45
- gr.Textbox(label="Explanation")
46
  ],
47
  title="DeepFunding Oracle",
48
- description=(
49
- "Upload a CSV to extract features, assign base weights via LLama, predict final weights with RandomForest, "
50
- "and visualize the dependency graph with explanations."
51
- )
52
  )
53
 
54
  if __name__ == "__main__":
 
1
  import os
2
  import gradio as gr
3
+ from Oracle.deepfundingoracle import prepare_dataset, train_predict_weight, create_submission_csv
4
  import pandas as pd
 
5
  import matplotlib.pyplot as plt
6
+ import time
7
+ import io
8
 
9
def analyze_file(file, progress=gr.Progress(track_tqdm=True)):
    """Run the weight-prediction pipeline on an uploaded CSV and report results.

    The upload is passed to ``prepare_dataset``, the result to
    ``train_predict_weight``, and the outcome is written out through
    ``create_submission_csv``. A histogram of the ``final_weight`` column is
    rendered when that column exists.

    Args:
        file: The value delivered by the ``gr.File`` input. Either a path
            string or an object exposing the path as ``.name``.
        progress: Gradio progress tracker used to report pipeline stages.

    Returns:
        A 4-tuple ``(preview, csv_path, graph_path, status)``:
        ``preview`` is the first rows of the result as CSV text,
        ``csv_path`` is whatever ``create_submission_csv`` returned,
        ``graph_path`` is the path of a PNG file holding the plot, and
        ``status`` is a human-readable timing message.

    Raises:
        gr.Error: If no file was uploaded.
    """
    # Local import keeps this function self-contained; the module's import
    # block does not bring in tempfile.
    import tempfile

    if file is None:
        raise gr.Error("Please upload a CSV file before running the analysis.")
    # gr.File may hand over a plain path string or a file-like wrapper with
    # a ``.name`` attribute depending on its ``type``/version; accept both.
    input_path = getattr(file, "name", file)

    start_time = time.time()
    progress(0, desc="Preparing dataset...")
    df = prepare_dataset(input_path)
    progress(0.3, desc="Predicting weights...")
    df = train_predict_weight(df)
    progress(0.6, desc="Saving results to CSV...")
    csv_path = create_submission_csv(df, "submission.csv")
    progress(0.8, desc="Generating graph...")

    fig, ax = plt.subplots()
    try:
        if 'final_weight' in df.columns:
            df['final_weight'].hist(ax=ax)
            ax.set_title('Distribution of Final Weights')
            ax.set_xlabel('Final Weight')
            ax.set_ylabel('Count')
        else:
            ax.text(0.5, 0.5, 'No final_weight column to plot', ha='center')
        # gr.Image cannot post-process an in-memory BytesIO; it accepts a file
        # path, so the PNG is written to a temporary file and its path returned.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            # Save this specific figure rather than pyplot's "current" one.
            fig.savefig(tmp, format='png')
            graph_path = tmp.name
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

    progress(1, desc="Done!")
    elapsed = time.time() - start_time
    preview = df.head().to_csv(index=False)
    return preview, csv_path, graph_path, f"Analysis completed in {elapsed:.2f} seconds."
 
 
35
 
# Output widgets, in the same order as the tuple returned by analyze_file.
_result_components = [
    gr.Textbox(label="Preview of Results"),
    gr.File(label="Download CSV"),
    gr.Image(label="Analysis Graph"),
    gr.Textbox(label="Status/Timing Info"),
]

# Single-page UI: one CSV upload in, four result widgets out.
iface = gr.Interface(
    fn=analyze_file,
    inputs=gr.File(label="Upload CSV"),
    outputs=_result_components,
    title="DeepFunding Oracle",
    description="Upload a CSV of repo-parent relationships; see analysis progress, get a graph, and download results as CSV.",
    allow_flagging="never",
)
49
 
50
  if __name__ == "__main__":