# MadGuard / app.py
# Last commit 009ec65 by Priti0210:
#   "Removed tutorial.mov to meet Hugging Face size limits"
# File size at that commit: 10.9 kB
import gradio as gr
import nltk
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from typing import Tuple, Optional
from visuals.score_card import render_score_card
from visuals.layout import (
render_page_header,
render_core_reference,
render_pipeline,
render_pipeline_graph,
render_pipeline_warning,
render_strategy_alignment,
)
# Fetch NLTK's "punkt" tokenizer data at start-up (quietly). A failure is
# printed and start-up continues.
# NOTE(review): nothing else in this module visibly tokenizes with NLTK
# (calculate_ttr uses str.split), so this download may be unnecessary -- confirm.
try:
    nltk.download("punkt", quiet=True)
except Exception as nltk_error:
    print(f"Error downloading NLTK data: {nltk_error}")

# Sentence-embedding model used by calculate_similarity().
model = SentenceTransformer("all-MiniLM-L6-v2")

# Module-wide cache of the most recently parsed upload, stored under the key
# "data" by update_dropdowns() and read back by process_data().
# NOTE(review): one dict per process, so it is presumably shared by every
# browser session; concurrent users could overwrite each other's upload.
uploaded_df = dict()
# --- Core Metrics ---
def calculate_ttr(text: str) -> float:
    """Return the type-token ratio (lexical diversity) of *text*.

    Tokens come from ``str.split()`` (runs of whitespace) and are compared
    case-sensitively, so "The" and "the" are distinct types.

    Args:
        text: Input text.

    Returns:
        ``distinct tokens / total tokens``, or ``0.0`` when *text* contains
        no tokens (empty or whitespace-only).
    """
    tokens = text.split()
    if not tokens:
        return 0.0
    return len(frozenset(tokens)) / len(tokens)
def calculate_similarity(text1: str, text2: str) -> float:
    """Return the cosine similarity between the embeddings of two texts.

    Both texts are embedded in a single batch with the module-level
    ``model`` and compared with scikit-learn's ``cosine_similarity``.

    Args:
        text1: First text.
        text2: Second text.

    Returns:
        The cosine similarity as a built-in ``float``.
    """
    first_vec, second_vec = model.encode([text1, text2])
    # cosine_similarity takes 2-D inputs and returns a (1, 1) matrix here;
    # float() unwraps the NumPy scalar so the annotated return type holds.
    return float(cosine_similarity([first_vec], [second_vec])[0][0])
def calculate_mad_score(ttr: float, similarity: float) -> float:
    """Blend lexical repetition and semantic similarity into one MAD score.

    ``score = 0.3 * (1 - ttr) + 0.7 * similarity``. Lower diversity (smaller
    *ttr*) and higher *similarity* both raise the score.
    """
    lexical_weight, semantic_weight = 0.3, 0.7
    repetition = 1 - ttr
    return lexical_weight * repetition + semantic_weight * similarity
def get_risk_level(mad_score: float) -> str:
    """Map a MAD score to a risk label.

    Returns "High" for scores above 0.7, "Medium" for 0.4 through 0.7
    inclusive, and "Low" for everything else (scores below 0.4 and NaN).
    """
    if mad_score > 0.7:
        return "High"
    if mad_score >= 0.4:
        return "Medium"
    return "Low"
# --- Data Processing ---
def _column_text(df: pd.DataFrame, column: str) -> str:
    """Join the non-missing values of *column* into one space-separated string."""
    # Without dropna(), missing cells are stringified to tokens such as "nan"
    # or "None", which inflate repetition (lower TTR) and skew the similarity.
    return " ".join(df[column].dropna().astype(str))


def _graph_or_none(data_source: str):
    """Render the pipeline graph for an error result; return None if that fails too."""
    try:
        return render_pipeline_graph(data_source)
    except Exception:
        # Best effort only: the caller is already reporting an error message.
        return None


def _error_result(message: str, data_source: str, preview: Optional[str] = None):
    """Build process_data's 7-tuple for a failed run.

    The pipeline graph and, when known, the file preview are still returned
    so the UI does not blank them just because the evaluation could not run.
    """
    return message, _graph_or_none(data_source), preview, None, None, None, None


def process_data(file_obj, model_col: str, train_col: str, data_source: str) -> Tuple[
    Optional[str],
    Optional[bytes],
    Optional[str],
    Optional[str],
    Optional[float],
    Optional[float],
    Optional[float],
]:
    """Score the selected columns of the uploaded file for MAD risk.

    Reads the DataFrame cached in ``uploaded_df["data"]`` (filled by
    ``update_dropdowns``), joins each selected column into one text, and
    computes the type-token ratio of both texts, their embedding similarity,
    the combined MAD score and its risk level.

    Args:
        file_obj: Current value of the upload component. It is only checked
            for truthiness; the parsed data comes from ``uploaded_df``.
        model_col: Column holding model outputs.
        train_col: Column holding candidate future training data.
        data_source: Pipeline option passed to ``render_pipeline_graph``.

    Returns:
        ``(error, graph, preview, evaluation_markdown, ttr_output, ttr_train,
        similarity)``. On success ``error`` is ``None``. On failure ``error``
        is a message and the evaluation/metric fields are ``None``, while
        ``graph`` and ``preview`` are filled in whenever they can be produced.
    """
    preview = None
    try:
        if not file_obj:
            return _error_result("Error: No file uploaded.", data_source)
        df = uploaded_df.get("data")
        if df is None:
            return _error_result("Error: File not yet processed.", data_source)
        preview = df.head().to_markdown(index=False, numalign="left", stralign="left")
        # An unselected dropdown is None; report that instead of "not found".
        if model_col is None or train_col is None:
            return _error_result(
                "Error: Please select both a model output column and a "
                "training data column.",
                data_source,
                preview,
            )
        if model_col not in df.columns or train_col not in df.columns:
            return _error_result(
                "Error: Selected columns not found in the file.", data_source, preview
            )
        output_text = _column_text(df, model_col)
        train_text = _column_text(df, train_col)
        ttr_output = calculate_ttr(output_text)
        ttr_train = calculate_ttr(train_text)
        similarity = calculate_similarity(output_text, train_text)
        # Only the model-output TTR feeds the MAD score; the training-set TTR
        # is passed through to the score card and the metric outputs.
        mad_score = calculate_mad_score(ttr_output, similarity)
        risk_level = get_risk_level(mad_score)
        summary, details, explanation = render_score_card(
            ttr_output, ttr_train, similarity, mad_score, risk_level
        )
        return (
            None,
            render_pipeline_graph(data_source),
            preview,
            summary + details + explanation,
            ttr_output,
            ttr_train,
            similarity,
        )
    except Exception as e:
        # UI boundary: surface any failure as a message instead of crashing.
        return _error_result(f"An error occurred: {str(e)}", data_source, preview)
# --- Helpers ---
def _empty_dropdowns(message: str):
    """Return cleared updates for both column dropdowns plus *message*."""
    return (
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
        message,
    )


def update_dropdowns(file_obj) -> Tuple[gr.Dropdown, gr.Dropdown, str]:
    """Parse an uploaded CSV/JSON file and offer its columns in both dropdowns.

    The parsed DataFrame is cached in ``uploaded_df["data"]`` for
    ``process_data``. The previous cache entry is discarded first, so a
    missing, unsupported or unreadable upload never leaves an earlier file's
    data behind to be scored.

    Args:
        file_obj: Value of the upload component. It is either falsy (nothing
            uploaded), an object exposing a ``.name`` path, or a plain path
            string.

    Returns:
        ``(model_col_update, train_col_update, text)``. The two updates come
        from ``gr.update`` (column choices, no selection). ``text`` is a
        Markdown preview of the first rows, or a status/error message.
    """
    # Forget any earlier upload before looking at the new one.
    uploaded_df["data"] = None
    if not file_obj:
        return _empty_dropdowns("No file uploaded.")
    try:
        # Accept an object with a .name path or a bare path string, and match
        # the extension case-insensitively (e.g. "DATA.CSV").
        file_name = getattr(file_obj, "name", None)
        if not file_name and isinstance(file_obj, str):
            file_name = file_obj
        file_name = str(file_name or "").lower()
        if file_name.endswith(".csv"):
            df = pd.read_csv(file_obj)
        elif file_name.endswith(".json"):
            df = pd.read_json(file_obj)
        else:
            return _empty_dropdowns("Invalid file type.")
        preview = df.head().to_markdown(index=False, numalign="left", stralign="left")
        columns = df.columns.tolist()
        # Cache only after parsing and previewing succeeded, so the cache
        # never disagrees with the (cleared) dropdowns on failure.
        uploaded_df["data"] = df
        return (
            gr.update(choices=columns, value=None),
            gr.update(choices=columns, value=None),
            preview,
        )
    except Exception as e:
        # UI boundary: report parse/preview failures in the preview area.
        return _empty_dropdowns(f"Error reading file: {e}")
def clear_all_fields():
    """Reset the upload, column pickers, text outputs and metrics.

    Also empties the module-level ``uploaded_df`` cache and re-renders the
    pipeline graph for "Synthetic Generated Data". The 10 returned values are
    positional, in this order: file, two dropdowns, three empty strings,
    three ``None`` metrics, graph. They must line up with the ``outputs`` list
    wired to the Clear button.
    """
    uploaded_df.clear()

    def blank_dropdown():
        # Build a fresh update for each dropdown rather than sharing one dict.
        return gr.update(choices=[], value=None)

    cleared_file = None
    model_dropdown = blank_dropdown()
    train_dropdown = blank_dropdown()
    empty_texts = ("", "", "")
    empty_metrics = (None, None, None)
    default_graph = render_pipeline_graph("Synthetic Generated Data")
    return (
        cleared_file,
        model_dropdown,
        train_dropdown,
        *empty_texts,
        *empty_metrics,
        default_graph,
    )
# --- Interface ---
def main_interface():
    """Build and return the MADGuard AI Explorer Gradio app.

    Layout, top to bottom: page header, tutorial video, intro text, research
    reference, pipeline simulation (data-source selector, graph, warning),
    file upload with column pickers and preview, evaluation summary,
    strategy-alignment reference, metric boxes, error message, footer.

    Event wiring:
      * changing the data source re-renders the pipeline warning and graph;
      * uploading a file fills the column dropdowns and preview, and only
        then runs the evaluation, so it sees the freshly parsed file;
      * changing either column or the data source re-runs the evaluation;
      * "Clear All" resets the fields via ``clear_all_fields``.

    Returns:
        The ``gr.Blocks`` instance (not launched).
    """
    css = """
.gradio-container {
background: linear-gradient(-45deg, #e0f7fa, #e1f5fe, #f1f8e9, #fff3e0);
background-size: 400% 400%;
animation: oceanWaves 20s ease infinite;
}
@keyframes oceanWaves {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
"""
    with gr.Blocks(css=css, title="MADGuard AI Explorer") as interface:
        gr.HTML(render_page_header())
        gr.HTML(
            """
<div style="text-align:center; margin-bottom: 20px;">
<h3>📽️ How to Use MADGuard AI Explorer</h3>
<iframe width="560" height="315" src="https://www.youtube.com/embed/qjMwvaBXQeY"
title="Tutorial Video" frameborder="0"
allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture"
allowfullscreen></iframe>
</div>
"""
        )
        gr.Markdown(
            """
> 🧠 **MADGuard AI Explorer** helps simulate feedback loops in RAG pipelines and detect **Model Autophagy Disorder (MAD)**.
- Compare **real vs. synthetic input effects**
- Visualize the data flow
- Upload your `.csv` or `.json` data
- Get diagnostics based on lexical diversity and semantic similarity
"""
        )
        with gr.Accordion("📚 Research Reference", open=False):
            gr.HTML(render_core_reference())

        # --- 1. Pipeline simulation ---
        gr.Markdown("## 1. Pipeline Simulation")
        data_source, description = render_pipeline(default="Synthetic Generated Data")
        gr.HTML(description)
        pipeline_output = gr.Image(type="filepath", label="Pipeline Graph")
        warning_output = gr.HTML()
        data_source.change(
            fn=render_pipeline_warning, inputs=data_source, outputs=warning_output
        )
        data_source.change(
            fn=render_pipeline_graph, inputs=data_source, outputs=pipeline_output
        )
        # Draw the graph for the initial data source when the page loads.
        interface.load(
            fn=render_pipeline_graph, inputs=[data_source], outputs=[pipeline_output]
        )

        # --- 2. Upload, column selection and results ---
        gr.Markdown("## 2. Upload CSV or JSON File")
        file_input = gr.File(
            file_types=[".csv", ".json"], label="Upload a CSV or JSON file"
        )
        clear_btn = gr.Button("🧹 Clear All")
        gr.Markdown(
            """
📝 **Note:**
- **Model Output Column**: Model-generated responses/completions.
- **Training Data Column**: Candidate future training input.
"""
        )
        with gr.Row():
            model_col_input = gr.Dropdown(
                choices=[], label="Select column for model output", interactive=True
            )
            train_col_input = gr.Dropdown(
                choices=[],
                label="Select column for future training data",
                interactive=True,
            )
        file_preview = gr.Markdown(label="📄 File Preview")
        output_markdown = gr.Markdown(label="🔍 Evaluation Summary")
        with gr.Accordion("📋 Research-Based Strategy Alignment", open=False):
            gr.HTML(render_strategy_alignment())
        with gr.Row():
            ttr_output_metric = gr.Number(label="Lexical Diversity (Output)")
            ttr_train_metric = gr.Number(label="Lexical Diversity (Training Set)")
            similarity_metric = gr.Number(label="Semantic Similarity (Cosine)")

        def handle_file_upload(file_obj, data_source_val):
            """Fill the dropdowns and preview from the upload, and redraw the graph."""
            dropdowns = update_dropdowns(file_obj)
            graph = render_pipeline_graph(data_source_val)
            return *dropdowns, graph

        def process_and_generate(
            file_obj, model_col_val, train_col_val, data_source_val
        ):
            """Run process_data and map its result onto the 8 evaluation outputs."""
            error, graph, preview, markdown, ttr_out, ttr_tr, sim = process_data(
                file_obj, model_col_val, train_col_val, data_source_val
            )
            if error:
                # Show the error and clear the results. Leave the graph,
                # warning and preview as they are (gr.update() is a no-op) unless
                # process_data supplied fresh values. The original returned the
                # warning component object here, which is not an output value.
                return (
                    error,
                    graph if graph is not None else gr.update(),
                    gr.update(),
                    preview if preview is not None else gr.update(),
                    None,
                    None,
                    None,
                    None,
                )
            return (
                "",
                graph,
                render_pipeline_warning(data_source_val),
                preview,
                markdown,
                ttr_out,
                ttr_tr,
                sim,
            )

        inputs = [file_input, model_col_input, train_col_input, data_source]
        # Created here so it is laid out below the metric boxes.
        error_output = gr.Markdown(label="⚠️ Error Message")
        outputs = [
            error_output,
            pipeline_output,
            warning_output,
            file_preview,
            output_markdown,
            ttr_output_metric,
            ttr_train_metric,
            similarity_metric,
        ]
        clear_btn.click(
            fn=clear_all_fields,
            inputs=[],
            outputs=[
                file_input,
                model_col_input,
                train_col_input,
                file_preview,
                output_markdown,
                warning_output,
                ttr_output_metric,
                ttr_train_metric,
                similarity_metric,
                pipeline_output,
            ],
        )
        # Evaluate only after the upload handler has cached the parsed file.
        # Two separate listeners on the same change event have no guaranteed
        # order, which let process_data see "File not yet processed."
        file_input.change(
            fn=handle_file_upload,
            inputs=[file_input, data_source],
            outputs=[model_col_input, train_col_input, file_preview, pipeline_output],
        ).then(fn=process_and_generate, inputs=inputs, outputs=outputs)
        for input_component in (model_col_input, train_col_input, data_source):
            input_component.change(
                fn=process_and_generate, inputs=inputs, outputs=outputs
            )

        gr.Markdown("---")
        gr.Markdown(
            """
**Pro version coming soon:**
- Bulk CSV uploads
- Trend visualizations
- One-click export of audit reports
[📩 Join the waitlist](https://docs.google.com/forms/d/e/1FAIpQLSfAPPC_Gm7DQElQSWGSnoB6T5hMxb_rXSu48OC8E6TNGZuKgQ/viewform?usp=sharing&ouid=118007615320536574300)
"""
        )
    return interface
if __name__ == "__main__":
    # Build the Blocks app and serve it on all network interfaces, port 7860.
    interface = main_interface()
    interface.launch(server_port=7860, server_name="0.0.0.0")