Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import JSONResponse | |
from transformers import pipeline | |
from typing import Tuple, Optional | |
import io | |
import fitz # PyMuPDF | |
from PIL import Image | |
import pandas as pd | |
import uvicorn | |
from docx import Document | |
from pptx import Presentation | |
import pytesseract | |
import logging | |
import re | |
from slowapi import Limiter | |
from slowapi.util import get_remote_address | |
from slowapi.errors import RateLimitExceeded | |
from slowapi.middleware import SlowAPIMiddleware | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
import tempfile | |
import base64 | |
from io import BytesIO | |
from pydantic import BaseModel | |
import traceback | |
import ast | |
from fastapi.responses import HTMLResponse | |
from fastapi import Request | |
from pathlib import Path | |
from fastapi.staticfiles import StaticFiles | |
# --- Core application objects ---------------------------------------------
# Rate limiter that buckets requests by client address (slowapi's
# get_remote_address helper).
limiter = Limiter(key_func=get_remote_address)

# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Frontend assets are served from the "static" directory, resolved against
# the current working directory.
_static_files = StaticFiles(directory="static")
app.mount("/static", _static_files, name="static")
def home():
    """Return the contents of the frontend page ``static/indexAI.html``.

    The path is resolved against the current working directory.

    NOTE(review): no route decorator is visible for this handler in this
    listing, so it may not be registered on ``app`` - confirm.
    """
    # Explicit UTF-8: the default text encoding is platform-dependent, and the
    # page may contain non-ASCII (e.g. French) text.
    return Path("static/indexAI.html").read_text(encoding="utf-8")
# --- Middleware -----------------------------------------------------------
# slowapi expects the limiter on app.state, so publish it there before
# installing its middleware.
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)
# NOTE(review): this listing shows no default_limits on the Limiter, no
# @limiter.limit decorators and no add_exception_handler(RateLimitExceeded, ...)
# call - confirm that rate limits are actually enforced.

# CORS: accept any origin, HTTP method and request header.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Constants
# Upload size cap in bytes (10 MiB), checked by process_uploaded_file.
MAX_FILE_SIZE = 10 * 1024 ** 2
# Lower-case file extensions accepted for upload.
SUPPORTED_FILE_TYPES = set("docx xlsx pptx pdf jpg jpeg png".split())
# Model caching: pipeline singletons, filled on first use by the matching
# get_* loader functions.
summarizer = qa_model = image_captioner = None
def get_summarizer():
    """Return the shared facebook/bart-large-cnn summarization pipeline.

    The pipeline is built on the first call and cached in the module-level
    ``summarizer`` global afterwards.
    """
    global summarizer
    if summarizer is not None:
        return summarizer
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer
def get_qa_model():
    """Return the shared deepset/roberta-base-squad2 question-answering pipeline.

    The pipeline is built on the first call and cached in the module-level
    ``qa_model`` global afterwards.
    """
    global qa_model
    if qa_model is not None:
        return qa_model
    qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
    return qa_model
def get_image_captioner():
    """Return the shared Salesforce BLIP (large) image-to-text pipeline.

    The pipeline is built on the first call and cached in the module-level
    ``image_captioner`` global afterwards.
    """
    global image_captioner
    if image_captioner is not None:
        return image_captioner
    image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    return image_captioner
async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
    """Validate an upload and return ``(extension, raw_bytes)``.

    Checks are applied in this order:
    - a filename is present;
    - its extension is in ``SUPPORTED_FILE_TYPES``;
    - the payload is at most ``MAX_FILE_SIZE`` bytes;
    - for PDFs only, the document opens, is not locked by a password, and has
      at most 50 pages.

    Args:
        file: The incoming multipart upload.

    Returns:
        The lower-cased extension (text after the last ``.``) and the file bytes.

    Raises:
        HTTPException: 400 for a missing name or an unsupported type, 413 when
            the file is too large, and 422 when PDF validation fails.
    """
    if not file.filename:
        raise HTTPException(400, "No filename provided")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in SUPPORTED_FILE_TYPES:
        # sorted(): set iteration order is arbitrary; keep the message stable.
        supported = ', '.join(sorted(SUPPORTED_FILE_TYPES))
        raise HTTPException(400, f"Unsupported file type. Supported: {supported}")

    # Read at most one byte past the limit. That is enough to detect an
    # oversized upload without buffering an arbitrarily large body in memory.
    content = await file.read(MAX_FILE_SIZE + 1)
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")

    # Special validation for PDFs
    if file_ext == "pdf":
        try:
            with fitz.open(stream=content, filetype="pdf") as doc:
                # Encrypted PDFs are accepted only if the empty password unlocks them.
                if doc.is_encrypted and not doc.authenticate(""):
                    raise ValueError("Encrypted PDF - cannot extract text")
                if len(doc) > 50:
                    raise ValueError("PDF too large (max 50 pages)")
        except Exception as e:
            logger.error(f"PDF validation failed: {str(e)}")
            raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}") from e

    await file.seek(0)  # Rewind so any later reader of `file` starts at the beginning
    return file_ext, content
def _extract_image_text(content: bytes) -> str:
    """OCR an image, falling back to a generated caption if OCR yields nothing or errors.

    Raises:
        ValueError: If the image cannot be opened or captioning fails too.
    """
    try:
        image = Image.open(io.BytesIO(content))
        try:
            text = pytesseract.image_to_string(image, config='--psm 6')
        except Exception as ocr_e:
            # E.g. the tesseract binary is missing: still try the captioner
            # instead of giving up on the image.
            logger.warning(f"OCR failed, falling back to captioning: {ocr_e}")
            text = ""
        if text.strip():
            return text
        captioner = get_image_captioner()
        result = captioner(image)
        return result[0]['generated_text']
    except Exception as img_e:
        logger.error(f"Image processing failed: {str(img_e)}")
        raise ValueError("Could not extract text or caption from image") from img_e


def extract_text(content: bytes, file_ext: str) -> str:
    """Extract plain text from an uploaded document.

    Supported extensions:
    - ``docx``: non-empty paragraphs.
    - ``xlsx``/``xls``: every non-null cell of every sheet, under a
      ``Sheet: <name>`` header.
    - ``pptx``: text of each shape that has any.
    - ``pdf``: page text.
    - ``jpg``/``jpeg``/``png``: OCR, or an image caption when OCR finds no text.

    Args:
        content: Raw file bytes.
        file_ext: Lower-case extension, as returned by ``process_uploaded_file``.

    Returns:
        The extracted text (may be empty for documents without text).

    Raises:
        HTTPException: 422 if extraction fails or the extension is not handled.
    """
    try:
        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs if para.text.strip())

        if file_ext in {"xlsx", "xls"}:
            sheets = pd.read_excel(io.BytesIO(content), sheet_name=None)
            all_text = []
            for sheet_name, sheet_data in sheets.items():
                sheet_text = []
                for column in sheet_data.columns:
                    sheet_text.extend(sheet_data[column].dropna().astype(str).tolist())
                all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
            return "\n\n".join(all_text)

        if file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            text = []
            for slide in ppt.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text") and shape.text.strip():
                        text.append(shape.text)
            return "\n".join(text)

        if file_ext == "pdf":
            # Context manager closes the document; the join consumes every page
            # before the block exits.
            with fitz.open(stream=content, filetype="pdf") as pdf:
                return "\n".join(page.get_text("text") for page in pdf)

        if file_ext in {"jpg", "jpeg", "png"}:
            return _extract_image_text(content)

        # Falling through used to return None, which crashed callers on .strip().
        raise ValueError(f"Unsupported file extension: {file_ext}")
    except Exception as e:
        logger.error(f"Text extraction failed for {file_ext}: {str(e)}")
        raise HTTPException(422, f"Failed to extract text from {file_ext} file") from e
# Visualization Models
class VisualizationRequest(BaseModel):
    """Chart specification consumed by ``generate_visualization_code``.

    ``chart_type`` is matched against: line, bar, scatter, histogram,
    boxplot, heatmap. Code generation raises ValueError for anything else.
    """
    chart_type: str
    # Columns to plot; which of them a chart uses depends on chart_type.
    x_column: Optional[str] = None
    y_column: Optional[str] = None
    hue_column: Optional[str] = None  # grouping column (seaborn ``hue=``)
    # Optional chart text.
    title: Optional[str] = None
    x_label: Optional[str] = None
    y_label: Optional[str] = None
    # Matplotlib style sheet name; validated before it is used.
    style: str = "seaborn-v0_8"
    # column -> {"min": a, "max": b} | {"values": [...]} | scalar (equality test)
    filters: Optional[dict] = None
class NaturalLanguageRequest(BaseModel):
    """Free-text chart request plus the matplotlib style to render it with."""
    prompt: str  # e.g. "bar chart of sales by region"
    style: str = "seaborn-v0_8"
def validate_matplotlib_style(style: str) -> str:
    """Return a matplotlib style name that ``plt.style.use`` will accept.

    Resolution order:
    1. ``style`` itself, if it is in ``plt.style.available``.
    2. The renamed form of a legacy seaborn style. Matplotlib 3.6 renamed
       every ``seaborn*`` sheet to ``seaborn-v0_8*`` (e.g. ``seaborn-white``
       became ``seaborn-v0_8-white``). The renamed form is used only if it is
       actually available.
    3. ``seaborn-v0_8`` when available, otherwise matplotlib's built-in
       ``default``. The latter keeps older matplotlib versions working.

    Args:
        style: Requested style name.

    Returns:
        A usable style name.
    """
    available_styles = plt.style.available

    if style in available_styles:
        return style

    # Generalised legacy mapping: covers every seaborn-* sheet, not just three.
    if style.startswith("seaborn") and not style.startswith("seaborn-v0_8"):
        renamed = "seaborn-v0_8" + style[len("seaborn"):]
        if renamed in available_styles:
            return renamed

    logger.warning(f"Invalid style '{style}'. Available styles: {available_styles}")
    fallback = "seaborn-v0_8"
    return fallback if fallback in available_styles else "default"
def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
    """Build a standalone matplotlib/seaborn script that reproduces the requested chart.

    The script embeds ``df`` as a literal, applies ``request.filters``, draws
    ``request.chart_type`` and sets the optional title and axis labels.

    Every user-controlled value is embedded with ``repr()`` so that it becomes
    a well-formed Python literal. This covers column names, filter
    values/bounds, title and labels. It matters because the natural-language
    endpoint ``exec()``s the returned script, and the previous ``'{value}'``
    interpolation let a quote in, for example, the title (built from the
    user's prompt) close the string and inject code.

    Args:
        df: Data to embed in the script.
        request: Chart parameters.

    Returns:
        The script source, one statement per line.

    Raises:
        ValueError: If ``request.chart_type`` is not supported.
    """
    valid_style = validate_matplotlib_style(request.style)

    x = repr(request.x_column)
    y = repr(request.y_column)
    hue = repr(request.hue_column)

    code_lines = [
        "import datetime",
        "import matplotlib.pyplot as plt",
        "import seaborn as sns",
        "import pandas as pd",
        "from pandas import NaT, Timestamp",
        "",
        "# Data preparation",
        # repr() of the embedded data may contain nan / inf / NaT /
        # Timestamp(...) / datetime.*(...); the imports above and these aliases
        # make those names resolvable when the script runs.
        "nan = float('nan')",
        "inf = float('inf')",
        f"df = pd.DataFrame({df.to_dict(orient='list')!r})",
    ]

    # Apply filters if specified
    if request.filters:
        filter_conditions = []
        for column, condition in request.filters.items():
            series = f"df[{column!r}]"
            if isinstance(condition, dict):
                if 'min' in condition and 'max' in condition:
                    filter_conditions.append(
                        f"({series} >= {condition['min']!r}) & ({series} <= {condition['max']!r})"
                    )
                elif 'values' in condition:
                    filter_conditions.append(f"{series}.isin({list(condition['values'])!r})")
            else:
                # Parenthesised because '&' binds tighter than '=='. Unwrapped,
                # a combined filter would parse as a chained comparison.
                filter_conditions.append(f"({series} == {condition!r})")
        if filter_conditions:
            code_lines.extend([
                "",
                "# Apply filters",
                f"df = df[{' & '.join(filter_conditions)}]",
            ])

    code_lines.extend([
        "",
        "# Visualization",
        f"plt.style.use({valid_style!r})",
        "plt.figure(figsize=(10, 6))",
    ])

    # Chart type specific code
    chart = request.chart_type
    if chart == "line":
        if request.hue_column:
            code_lines.append(f"sns.lineplot(data=df, x={x}, y={y}, hue={hue})")
        else:
            code_lines.append(f"plt.plot(df[{x}], df[{y}])")
    elif chart == "bar":
        if request.hue_column:
            code_lines.append(f"sns.barplot(data=df, x={x}, y={y}, hue={hue})")
        else:
            code_lines.append(f"plt.bar(df[{x}], df[{y}])")
    elif chart == "scatter":
        if request.hue_column:
            code_lines.append(f"sns.scatterplot(data=df, x={x}, y={y}, hue={hue})")
        else:
            code_lines.append(f"plt.scatter(df[{x}], df[{y}])")
    elif chart == "histogram":
        code_lines.append(f"plt.hist(df[{x}], bins=20)")
    elif chart == "boxplot":
        if request.hue_column:
            code_lines.append(f"sns.boxplot(data=df, x={x}, y={y}, hue={hue})")
        else:
            code_lines.append(f"sns.boxplot(data=df, x={x}, y={y})")
    elif chart == "heatmap":
        # Only numeric columns: DataFrame.corr() raises on text columns in
        # recent pandas versions.
        code_lines.append("corr = df.select_dtypes(include='number').corr()")
        code_lines.append("sns.heatmap(corr, annot=True, cmap='coolwarm')")
    else:
        raise ValueError(f"Unsupported chart type: {request.chart_type}")

    # Add labels and title
    if request.title:
        code_lines.append(f"plt.title({request.title!r})")
    if request.x_label:
        code_lines.append(f"plt.xlabel({request.x_label!r})")
    if request.y_label:
        code_lines.append(f"plt.ylabel({request.y_label!r})")

    code_lines.extend([
        "plt.tight_layout()",
        "plt.show()",
    ])
    return "\n".join(code_lines)
def interpret_natural_language(prompt: str, df_columns: list) -> VisualizationRequest:
    """Map a free-text chart request onto ``VisualizationRequest`` fields.

    Uses simple keyword matching on the lower-cased prompt:

    - chart type: "line", "scatter", "histogram", "box", or
      "heatmap"/"correlation", checked in that order; the default is "bar".
    - columns: every column whose lower-cased label occurs in the prompt, in
      ``df_columns`` order, fills x, then y. Any further matches overwrite hue.

    A missing x defaults to the first column. A missing y defaults to the
    second column, but never to the column already used for x.

    Args:
        prompt: The user's request.
        df_columns: Column labels of the uploaded sheet. They may be
            non-strings, e.g. numeric or date headers.

    Returns:
        A request with ``style="seaborn-v0_8"`` and a title derived from the
        lower-cased prompt.
    """
    prompt = prompt.lower()

    # Determine chart type
    chart_type = "bar"
    if "line" in prompt:
        chart_type = "line"
    elif "scatter" in prompt:
        chart_type = "scatter"
    elif "histogram" in prompt:
        chart_type = "histogram"
    elif "box" in prompt:
        chart_type = "boxplot"
    elif "heatmap" in prompt or "correlation" in prompt:
        chart_type = "heatmap"

    # Try to detect columns mentioned in the prompt
    x_col = y_col = hue_col = None
    for col in df_columns:
        # str(): Excel headers can be numbers or dates, which have no .lower().
        if str(col).lower() in prompt:
            if x_col is None:
                x_col = col
            elif y_col is None:
                y_col = col
            else:
                hue_col = col

    # Default to leading columns if not detected
    if x_col is None and len(df_columns) > 0:
        x_col = df_columns[0]
    if y_col is None:
        # Prefer the second column (as before), but skip whichever column is
        # already on the x axis, e.g. when the prompt names only the second
        # column.
        y_col = next((c for c in [*df_columns[1:], *df_columns[:1]] if c != x_col), None)

    return VisualizationRequest(
        chart_type=chart_type,
        x_column=x_col,
        y_column=y_col,
        hue_column=hue_col,
        title="Generated from: " + prompt[:50] + ("..." if len(prompt) > 50 else ""),
        style="seaborn-v0_8",
    )
async def summarize_document(request: Request, file: UploadFile = File(...)):
    """Summarize an uploaded document piece by piece.

    Steps:
    1. Collapse whitespace in the extracted text.
    2. Split it into chunks of at most 1000 characters on word boundaries.
    3. Summarize each chunk with the BART pipeline.
    4. Join the chunk summaries with spaces.

    Args:
        request: Incoming request (unused in the body).
        file: Any upload accepted by ``process_uploaded_file``.

    Returns:
        ``{"summary": <text>}``.

    Raises:
        HTTPException: 400 when no text can be extracted, upload/extraction
            errors as raised by the helpers, and 500 for anything else.
    """
    import textwrap  # local: only this endpoint needs it

    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)
        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and chunk text. Breaking on whitespace avoids splitting words
        # at fixed offsets; words longer than 1000 chars are still split.
        text = re.sub(r'\s+', ' ', text).strip()
        chunks = textwrap.wrap(text, 1000)

        summarizer = get_summarizer()
        summaries = []
        for chunk in chunks:
            # A short tail chunk cannot sensibly produce a 50-token summary;
            # scale the floor with the chunk's word count (unchanged for full
            # chunks).
            min_len = min(50, max(1, len(chunk.split()) // 2))
            output = summarizer(
                chunk,
                max_length=150,
                min_length=min_len,
                do_sample=False,
                truncation=True,
            )
            summaries.append(output[0]["summary_text"])

        return {"summary": " ".join(summaries)}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        raise HTTPException(500, "Document summarization failed") from e
async def question_answering(
    request: Request,
    file: UploadFile = File(...),
    question: str = Form(...),
    language: str = Form("fr")
):
    """Answer ``question`` from the text of an uploaded document.

    Processing paths:
    - Questions containing one of ``theme_keywords`` ("what is it about")
      are answered with a short summary of the first 5000 characters.
    - All other questions go through the extractive QA model over the first
      3000 characters of the cleaned text.

    Args:
        request: Incoming request (unused in the body).
        file: Any upload accepted by ``process_uploaded_file``.
        question: The user's question.
        language: Echoed back in the response; it does not affect processing.

    Returns:
        A dict with ``question``, ``answer``, ``confidence`` and ``language``.
        It also contains ``warning`` when the theme summary fell back to a
        text excerpt. On the theme path the confidence values (0.95 / 0.7)
        are fixed constants, not model scores.

    Raises:
        HTTPException: 400 when no text can be extracted, and 500 on any
            other failure.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)
        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and truncate text
        text = re.sub(r'\s+', ' ', text).strip()[:5000]

        # Theme detection
        theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
        if any(kw in question.lower() for kw in theme_keywords):
            try:
                summarizer = get_summarizer()
                # Keep max_length > min_length for short texts. Previously
                # min(100, len(text)//4) could drop below min_length=30, so
                # short documents always failed into the fallback. Long texts
                # (>= 400 chars) still use 100/30.
                max_len = max(10, min(100, len(text) // 4))
                min_len = min(30, max_len // 2)
                summary_output = summarizer(
                    text,
                    max_length=max_len,
                    min_length=min_len,
                    do_sample=False,
                    truncation=True
                )
                theme = summary_output[0].get("summary_text", text[:200] + "...")
                return {
                    "question": question,
                    "answer": f"Le document traite principalement de : {theme}",
                    "confidence": 0.95,
                    "language": language
                }
            except Exception as e:
                # Best-effort path: log why, then answer with a text excerpt.
                logger.warning(f"Theme summary failed, using text excerpt: {e}")
                theme = text[:200] + ("..." if len(text) > 200 else "")
                return {
                    "question": question,
                    "answer": f"D'après le document : {theme}",
                    "confidence": 0.7,
                    "language": language,
                    "warning": "theme_summary_fallback"
                }

        # Standard extractive QA
        qa = get_qa_model()
        result = qa(question=question, context=text[:3000])
        return {
            "question": question,
            "answer": result["answer"],
            "confidence": result["score"],
            "language": language
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"QA processing failed: {str(e)}")
        raise HTTPException(500, detail=f"Analysis failed: {str(e)}") from e
async def visualize_with_code(
    request: Request,
    file: UploadFile = File(...),
    chart_type: str = Form(...),
    x_column: Optional[str] = Form(None),
    y_column: Optional[str] = Form(None),
    hue_column: Optional[str] = Form(None),
    title: Optional[str] = Form(None),
    x_label: Optional[str] = Form(None),
    y_label: Optional[str] = Form(None),
    style: str = Form("seaborn-v0_8"),
    filters: Optional[str] = Form(None)
):
    """Return a standalone plotting script for the first sheet of an Excel upload.

    Nothing is executed; the generated code is returned as text.

    Args:
        request: Incoming request (unused in the body).
        file: Excel upload.
        chart_type: line, bar, scatter, histogram, boxplot or heatmap.
        x_column, y_column, hue_column: Column names to plot.
        title, x_label, y_label: Optional chart text.
        style: Matplotlib style name (validated; unknown names fall back).
        filters: A Python-literal dict string, e.g.
            ``"{'year': {'min': 2020, 'max': 2023}}"``.

    Returns:
        ``{"code": <script>}``.

    Raises:
        HTTPException: 400 for non-Excel or empty files, malformed filters or
            an unsupported chart type, and 500 for anything unexpected.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Visualization is only supported for Excel files")

        df = pd.read_excel(io.BytesIO(content))
        if df.empty:
            raise HTTPException(400, "The uploaded Excel file is empty")

        # Parse filters with literal_eval: it accepts Python literals only,
        # so no code is executed here.
        filters_dict = None
        if filters:
            try:
                filters_dict = ast.literal_eval(filters)
            except Exception:
                filters_dict = None
            if not isinstance(filters_dict, dict):
                raise HTTPException(400, "Invalid format for filters. Must be a valid dictionary string.")

        viz_request = VisualizationRequest(
            chart_type=chart_type,
            x_column=x_column,
            y_column=y_column,
            hue_column=hue_column,
            title=title,
            x_label=x_label,
            y_label=y_label,
            style=style,
            filters=filters_dict
        )

        try:
            code = generate_visualization_code(df, viz_request)
        except ValueError as exc:
            # Unsupported chart type is a client mistake: report 400, not 500.
            raise HTTPException(400, str(exc)) from exc
        return {"code": code}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Visualization code generation failed: {str(e)}")
        raise HTTPException(500, f"Visualization code generation failed: {str(e)}") from e
from fastapi.responses import FileResponse # Add this import at the top | |
async def visualize_with_natural_language(
    request: Request,
    file: UploadFile = File(...),
    prompt: str = Form(...),
    style: str = Form("seaborn-v0_8"),
    return_type: str = Form("base64")  # "base64" or "file"
):
    """Render a chart described in plain language from an Excel upload.

    Steps:
    1. Interpret the prompt into chart parameters (``interpret_natural_language``).
    2. Generate plotting code for them.
    3. Execute that code and capture the resulting figure as PNG.

    Args:
        request: Incoming request (unused in the body).
        file: Excel upload.
        prompt: Free-text description of the chart.
        style: Matplotlib style name; validated during code generation.
        return_type: ``"file"`` returns the PNG as a download. Any other
            value returns JSON with a base64 data URL.

    Returns:
        Either a PNG download response, or a dict with ``status``, ``image``,
        ``code`` and ``interpreted_parameters``.

    Raises:
        HTTPException: 400 for non-Excel or empty files, and 500 when
            interpretation or rendering fails.
    """
    from fastapi.responses import Response  # local: only this endpoint needs it

    try:
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported for visualization")
        df = pd.read_excel(io.BytesIO(content))
        if df.empty:
            raise HTTPException(400, "The uploaded Excel file is empty")

        nl_request = NaturalLanguageRequest(prompt=prompt, style=style)
        vis_request = interpret_natural_language(nl_request.prompt, df.columns.tolist())
        # interpret_natural_language always returns its default style. Honour
        # the caller's choice instead; generate_visualization_code validates it.
        vis_request.style = nl_request.style
        visualization_code = generate_visualization_code(df, vis_request)

        # Render straight into memory. The old temp file was never deleted:
        # `os` was not imported, and the bare except hid the NameError.
        buffer = BytesIO()
        try:
            # SECURITY: exec() runs the generated script in-process with this
            # module's globals. This is only acceptable because
            # generate_visualization_code embeds user input as literals. Never
            # exec caller-supplied source here.
            exec(visualization_code, globals(), {})
            plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
        finally:
            # Close every figure, including any the script opened, so they do
            # not accumulate across requests.
            plt.close("all")
        image_bytes = buffer.getvalue()

        if return_type == "file":
            # Same download semantics as FileResponse(filename=...), without a temp file.
            return Response(
                content=image_bytes,
                media_type="image/png",
                headers={"Content-Disposition": 'attachment; filename="visualization.png"'},
            )

        image_base64 = base64.b64encode(image_bytes).decode('utf-8')
        return {
            "status": "success",
            "image": f"data:image/png;base64,{image_base64}",
            "code": visualization_code,
            "interpreted_parameters": vis_request.dict()
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Natural language visualization failed: {str(e)}\n{traceback.format_exc()}")
        raise HTTPException(500, detail=f"Visualization failed: {str(e)}") from e
async def list_available_styles(request: Request):
    """Report the matplotlib style sheet names installed on this server."""
    styles = plt.style.available
    return dict(available_styles=styles)
def _json_safe_frame(frame: pd.DataFrame) -> pd.DataFrame:
    """Return *frame* as object dtype with NaN/NaT replaced by None (JSON null)."""
    return frame.astype(object).where(frame.notna(), None)


async def get_excel_columns(
    request: Request,
    file: UploadFile = File(...)
):
    """Describe the first sheet of an Excel upload.

    Args:
        request: Incoming request (unused in the body).
        file: Excel upload.

    Returns:
        A dict with:
        - ``columns``: the column labels.
        - ``sample_data``: the first 5 rows as records.
        - ``statistics``: ``DataFrame.describe()`` output, or None when the
          sheet has no numeric column.
        Missing values are returned as null.

    Raises:
        HTTPException: 400 for non-Excel uploads (previously converted into a
            500), upload errors from ``process_uploaded_file``, and 500 when
            the workbook cannot be read.
    """
    try:
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported")
        df = pd.read_excel(io.BytesIO(content))

        has_numeric = len(df.select_dtypes(include=['number']).columns) > 0
        # NaN (empty cells, std of a single row, ...) is not valid JSON and made
        # response serialization fail, so convert it to None first.
        return {
            "columns": list(df.columns),
            "sample_data": _json_safe_frame(df.head()).to_dict(orient='records'),
            "statistics": _json_safe_frame(df.describe()).to_dict() if has_numeric else None
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Column extraction failed: {str(e)}")
        raise HTTPException(500, detail="Failed to extract columns from Excel file") from e
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    """Turn a slowapi RateLimitExceeded into a JSON 429 response.

    NOTE(review): no app.add_exception_handler(RateLimitExceeded, ...) call is
    visible in this listing - confirm this handler is registered.
    """
    body = {"detail": "Too many requests. Please try again later."}
    return JSONResponse(content=body, status_code=429)
import gradio as gr | |
# Gradio interface for visualization | |
def gradio_visualize(file, prompt, style="seaborn-v0_8"):
    """Gradio callback: send the upload to ``/visualize/natural`` and unpack the result.

    Args:
        file: The uploaded workbook. ``gr.File(type="filepath")`` passes a
            path string; objects exposing ``.name`` are also accepted.
        prompt: Free-text chart description.
        style: Matplotlib style name.

    Returns:
        ``(PIL image, markdown code block)`` for the two Gradio outputs.

    Raises:
        gr.Error: When no file is given or the API reports an error. Gradio
            shows the message instead of failing with a KeyError.
    """
    if file is None:
        raise gr.Error("Please upload an Excel file.")
    # gr.File(type="filepath") yields a str, which has no .name attribute.
    path = file if isinstance(file, str) else file.name

    # `client` is only created in the __main__ section. When the module is
    # served another way (e.g. `uvicorn main:app`), build one on demand.
    api = globals().get("client")
    if api is None:
        from fastapi.testclient import TestClient
        api = TestClient(app)

    with open(path, "rb") as f:
        response = api.post(
            "/visualize/natural",
            files={"file": f},
            data={"prompt": prompt, "style": style}
        )

    if response.status_code != 200:
        try:
            detail = response.json().get("detail", response.text)
        except ValueError:  # non-JSON error body
            detail = response.text
        raise gr.Error(f"Visualization failed: {detail}")

    result = response.json()
    # Decode the data URL into a PIL image, a value gr.Image accepts directly.
    encoded = result["image"].split(",", 1)[1]
    image = Image.open(BytesIO(base64.b64decode(encoded)))
    return image, f"```python\n{result['code']}\n```"
# Create Gradio interface: Excel file + prompt + style in, rendered chart + code out.
_visualizer_inputs = [
    gr.File(label="Upload Excel File", type="filepath"),
    gr.Textbox(label="Visualization Prompt", placeholder="e.g., 'Show sales by region'"),
    gr.Dropdown(label="Style", choices=plt.style.available, value="seaborn-v0_8"),
]
_visualizer_outputs = [
    gr.Image(label="Generated Visualization"),
    gr.Markdown(label="Generated Code"),  # displays the fenced ```python block
]
visualization_interface = gr.Interface(
    fn=gradio_visualize,
    inputs=_visualizer_inputs,
    outputs=_visualizer_outputs,
    title="📊 Data Visualizer",
    description="Upload an Excel file and describe the visualization you want",
)

# Serve the Gradio UI from the FastAPI app under /gradio.
app = gr.mount_gradio_app(app, visualization_interface, path="/gradio")
# ===== Script entry point: optional smoke test, then serve =====
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    # Module-level name on purpose: gradio_visualize uses this client when present.
    client = TestClient(app)

    # Smoke-test /visualize/natural against a local workbook, if one exists.
    test_file = "test.xlsx"  # replace with your own test workbook
    test_prompt = "Show me a bar chart of sales by region"
    xlsx_mime = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"

    if Path(test_file).is_file():
        with open(test_file, "rb") as f:
            response = client.post(
                "/visualize/natural",
                files={"file": (test_file, f, xlsx_mime)},
                data={"prompt": test_prompt}
            )
        if response.status_code == 200:
            result = response.json()
            print("Visualization generated successfully!")
            # Decode and display the image
            image_data = result["image"].split(",")[1]  # drop the data-URL header
            image = Image.open(BytesIO(base64.b64decode(image_data)))
            plt.imshow(image)
            plt.axis("off")
            plt.show()
        else:
            print(f"Error: {response.status_code}\n{response.text}")
    else:
        # A missing test file used to raise FileNotFoundError here, so the
        # server below never started.
        print(f"Skipping smoke test: {test_file} not found")

    uvicorn.run(app, host="0.0.0.0", port=7860)