# main.py
# Standard library imports
import io
import re
import ast
import logging
import tempfile
import base64
import traceback
import warnings
from io import BytesIO
from typing import Tuple, Optional
from pathlib import Path

# Third-party imports
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.staticfiles import StaticFiles
from transformers import pipeline
import fitz  # PyMuPDF
from PIL import Image
import pandas as pd
import numpy as np
import uvicorn
from docx import Document
from pptx import Presentation
import pytesseract
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
import matplotlib.pyplot as plt
import seaborn as sns
from pydantic import BaseModel
from openpyxl import Workbook

# Suppress openpyxl warnings
warnings.filterwarnings("ignore", category=UserWarning, module="openpyxl")
# Initialize rate limiter
limiter = Limiter(key_func=get_remote_address)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Serve static files (frontend)
app.mount("/static", StaticFiles(directory="static"), name="static")

# The route decorator was missing in the original; "/" is the assumed path
# for the frontend page.
@app.get("/", response_class=HTMLResponse)
def home():
    with open("static/indexAI.html", "r") as file:
        return file.read()

# Apply rate limiting middleware
app.state.limiter = limiter
app.add_middleware(SlowAPIMiddleware)

# CORS Configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Constants
MAX_FILE_SIZE = 10 * 1024 * 1024  # 10MB
SUPPORTED_FILE_TYPES = {
    "docx", "xlsx", "pptx", "pdf", "jpg", "jpeg", "png"
}
# Model caching
summarizer = None
qa_model = None
image_captioner = None

def get_summarizer():
    global summarizer
    if summarizer is None:
        summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
    return summarizer

def get_qa_model():
    global qa_model
    if qa_model is None:
        qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
    return qa_model

def get_image_captioner():
    global image_captioner
    if image_captioner is None:
        image_captioner = pipeline("image-to-text", model="Salesforce/blip-image-captioning-large")
    return image_captioner
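
# Usage sketch (illustrative, not part of the API): the getters above memoize
# the Hugging Face pipelines, so the first call pays the model-load cost and
# every later call reuses the same instance:
#   summ = get_summarizer()  # loads facebook/bart-large-cnn once per process
#   summ(text, max_length=150, min_length=50, do_sample=False)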
async def process_uploaded_file(file: UploadFile) -> Tuple[str, bytes]:
    """Validate and process uploaded file with special handling for each type"""
    if not file.filename:
        raise HTTPException(400, "No filename provided")

    file_ext = file.filename.split('.')[-1].lower()
    if file_ext not in SUPPORTED_FILE_TYPES:
        raise HTTPException(400, f"Unsupported file type. Supported: {', '.join(SUPPORTED_FILE_TYPES)}")

    content = await file.read()
    if len(content) > MAX_FILE_SIZE:
        raise HTTPException(413, f"File too large. Max size: {MAX_FILE_SIZE//1024//1024}MB")

    # Special validation for PDFs
    if file_ext == "pdf":
        try:
            with fitz.open(stream=content, filetype="pdf") as doc:
                if doc.is_encrypted:
                    if not doc.authenticate(""):
                        raise ValueError("Encrypted PDF - cannot extract text")
                if len(doc) > 50:
                    raise ValueError("PDF too large (max 50 pages)")
        except Exception as e:
            logger.error(f"PDF validation failed: {str(e)}")
            raise HTTPException(422, detail=f"Invalid PDF file: {str(e)}")

    await file.seek(0)  # Reset file pointer for processing
    return file_ext, content
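
# Usage sketch (illustrative): endpoints call this helper first; it raises
# HTTPException (400, 413, or 422) on bad input and rewinds the stream, so
# the UploadFile can still be re-read afterwards:
#   file_ext, content = await process_uploaded_file(file)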
def extract_text(content: bytes, file_ext: str) -> str:
    """Extract text from various file formats with enhanced Excel support"""
    try:
        if file_ext == "docx":
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs if para.text.strip())

        elif file_ext in {"xlsx", "xls"}:
            # Improved Excel handling with better NaN and date support
            df = pd.read_excel(
                io.BytesIO(content),
                sheet_name=None,
                engine='openpyxl',
                na_values=['', 'NA', 'N/A', 'NaN', 'null'],
                keep_default_na=False,
                parse_dates=True
            )
            all_text = []
            for sheet_name, sheet_data in df.items():
                sheet_text = []
                # Convert all data to string and handle special types
                for column in sheet_data.columns:
                    # Handle datetime columns
                    if pd.api.types.is_datetime64_any_dtype(sheet_data[column]):
                        sheet_data[column] = sheet_data[column].dt.strftime('%Y-%m-%d %H:%M:%S')
                    # Convert to string and clean
                    col_text = sheet_data[column].astype(str).replace(['nan', 'None', 'NaT'], '').tolist()
                    sheet_text.extend([x for x in col_text if x.strip()])
                all_text.append(f"Sheet: {sheet_name}\n" + "\n".join(sheet_text))
            return "\n\n".join(all_text)

        elif file_ext == "pptx":
            ppt = Presentation(io.BytesIO(content))
            text = []
            for slide in ppt.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text") and shape.text.strip():
                        text.append(shape.text)
            return "\n".join(text)

        elif file_ext == "pdf":
            pdf = fitz.open(stream=content, filetype="pdf")
            return "\n".join(page.get_text("text") for page in pdf)

        elif file_ext in {"jpg", "jpeg", "png"}:
            # First try OCR
            try:
                image = Image.open(io.BytesIO(content))
                text = pytesseract.image_to_string(image, config='--psm 6')
                if text.strip():
                    return text
                # If OCR yields nothing, fall back to image captioning
                captioner = get_image_captioner()
                result = captioner(image)
                return result[0]['generated_text']
            except Exception as img_e:
                logger.error(f"Image processing failed: {str(img_e)}")
                raise ValueError("Could not extract text or caption from image")

        else:
            # Defensive: should be unreachable after process_uploaded_file validation
            raise ValueError(f"Unsupported file type: {file_ext}")

    except Exception as e:
        logger.error(f"Text extraction failed for {file_ext}: {str(e)}", exc_info=True)
        raise HTTPException(422, f"Failed to extract text from {file_ext} file: {str(e)}")
# Visualization Models
class VisualizationRequest(BaseModel):
    chart_type: str
    x_column: Optional[str] = None
    y_column: Optional[str] = None
    hue_column: Optional[str] = None
    title: Optional[str] = None
    x_label: Optional[str] = None
    y_label: Optional[str] = None
    style: str = "seaborn-v0_8"  # current seaborn style name
    filters: Optional[dict] = None

class NaturalLanguageRequest(BaseModel):
    prompt: str
    style: str = "seaborn-v0_8"

def validate_matplotlib_style(style: str) -> str:
    """Validate and return a valid matplotlib style"""
    available_styles = plt.style.available
    # Map legacy style names to current ones
    style_mapping = {
        'seaborn': 'seaborn-v0_8',
        'seaborn-white': 'seaborn-v0_8-white',
        'seaborn-dark': 'seaborn-v0_8-dark',
        # Add other legacy mappings if needed
    }
    # Check if it's a legacy name we can map
    if style in style_mapping:
        return style_mapping[style]
    # Check if it's a valid current style
    if style in available_styles:
        return style
    logger.warning(f"Invalid style '{style}'. Available styles: {available_styles}")
    return "seaborn-v0_8"  # Default fallback to current seaborn style
def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
    """Generate Python code for visualization with enhanced NaN handling and type safety"""
    # Validate style
    valid_style = validate_matplotlib_style(request.style)

    # Convert DataFrame to dict with proper NaN handling
    df_dict = df.where(pd.notnull(df), None).to_dict(orient='list')

    code_lines = [
        "import matplotlib.pyplot as plt",
        "import seaborn as sns",
        "import pandas as pd",
        "import numpy as np",
        "",
        "# Data preparation with NaN handling and type conversion",
        f"raw_data = {df_dict}",
        "df = pd.DataFrame(raw_data)",
        "",
        "# Automatic type conversion and cleaning",
        "for col in df.columns:",
        "    # Convert strings that should be numeric",
        "    if pd.api.types.is_string_dtype(df[col]):",
        "        try:",
        "            df[col] = pd.to_numeric(df[col])",
        "            continue",
        "        except (ValueError, TypeError):",
        "            pass",
        "",
        "    # Convert string dates to datetime",
        "    try:",
        "        df[col] = pd.to_datetime(df[col])",
        "        continue",
        "    except (ValueError, TypeError):",
        "        pass",
        "",
        "    # Clean remaining None/NaN values",
        "    df[col] = df[col].where(pd.notnull(df[col]), None)",
    ]

    # Apply filters if specified (with enhanced safety)
    if request.filters:
        filter_conditions = []
        for column, condition in request.filters.items():
            if isinstance(condition, dict):
                if 'min' in condition and 'max' in condition:
                    filter_conditions.append(
                        f"(pd.notna(df['{column}']) & "
                        f"(df['{column}'] >= {condition['min']}) & "
                        f"(df['{column}'] <= {condition['max']}))"
                    )
                elif 'values' in condition:
                    values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
                    filter_conditions.append(
                        f"(pd.notna(df['{column}'])) & "
                        f"(df['{column}'].isin([{values}]))"
                    )
            else:
                filter_conditions.append(
                    f"(pd.notna(df['{column}'])) & "
                    f"(df['{column}'] == {repr(condition)})"
                )
        if filter_conditions:
            code_lines.extend([
                "",
                "# Apply filters with NaN checking",
                f"df = df[{' & '.join(filter_conditions)}].copy()"
            ])

    code_lines.extend([
        "",
        "# Visualization setup",
        f"plt.style.use('{valid_style}')",
        "plt.figure(figsize=(10, 6))"
    ])

    # Chart-type specific plotting
    if request.chart_type == "line":
        if request.hue_column:
            code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "bar":
        if request.hue_column:
            code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "scatter":
        if request.hue_column:
            code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "histogram":
        code_lines.append(f"plt.hist(df['{request.x_column}'].dropna(), bins=20)")  # dropna() guards against NaN
    elif request.chart_type == "boxplot":
        if request.hue_column:
            code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}')")
    elif request.chart_type == "heatmap":
        code_lines.append("numeric_df = df.select_dtypes(include=[np.number])")  # numeric columns only
        code_lines.append("corr = numeric_df.corr()")
        code_lines.append("sns.heatmap(corr, annot=True, cmap='coolwarm')")
    else:
        raise ValueError(f"Unsupported chart type: {request.chart_type}")

    # Add labels and title
    if request.title:
        code_lines.append(f"plt.title('{request.title}')")
    if request.x_label:
        code_lines.append(f"plt.xlabel('{request.x_label}')")
    if request.y_label:
        code_lines.append(f"plt.ylabel('{request.y_label}')")

    code_lines.extend([
        "plt.tight_layout()",
        "plt.show()"
    ])
    return "\n".join(code_lines)
def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
    """Fully dynamic prompt interpretation that works with any Excel columns"""
    if not prompt or not df_columns:
        return None

    prompt = prompt.lower().strip()
    col_names = [col.lower() for col in df_columns]

    # Initialize with defaults
    chart_type = "bar"
    x_col = None
    y_col = None
    hue_col = None

    # Dynamic chart type detection
    if any(word in prompt for word in ["line", "trend", "over time"]):
        chart_type = "line"
    elif any(word in prompt for word in ["scatter", "relationship", "correlat"]):
        chart_type = "scatter"
    elif any(word in prompt for word in ["histogram", "distribut", "frequenc"]):
        chart_type = "histogram"
    elif any(word in prompt for word in ["box", "quartile"]):
        chart_type = "boxplot"
    elif any(word in prompt for word in ["heatmap", "matrix"]):
        chart_type = "heatmap"

    # Dynamic column assignment - scans the DataFrame columns (in DataFrame
    # order) for names mentioned anywhere in the prompt
    for col, col_lower in zip(df_columns, col_names):
        if col_lower in prompt:
            # First matching column becomes the x-axis
            if not x_col:
                x_col = col
            # Second match becomes the y-axis (except for histograms)
            elif not y_col and chart_type != "histogram":
                y_col = col
            # Third match can become the hue
            elif not hue_col and chart_type in ["bar", "scatter", "line"]:
                hue_col = col

    # Smart defaults when columns aren't specified
    if not x_col and df_columns:
        x_col = df_columns[0]  # First column as default x-axis
    if not y_col and len(df_columns) > 1 and chart_type != "histogram":
        y_col = df_columns[1]  # Second column as default y-axis

    # Special handling for specific chart types
    if chart_type == "heatmap":
        return VisualizationRequest(
            chart_type="heatmap",
            title=f"Heatmap: {prompt[:30]}...",
            style="seaborn-v0_8"
        )
    if chart_type == "histogram" and y_col:
        # Histograms only need an x-axis
        y_col = None

    return VisualizationRequest(
        chart_type=chart_type,
        x_column=x_col,
        y_column=y_col,
        hue_column=hue_col,
        title=f"{chart_type.title()} of {prompt[:30]}...",
        style="seaborn-v0_8"
    )
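
# Illustrative sketch (never called by the app): "trend" selects a line chart,
# and columns are assigned in DataFrame order among those mentioned.
def _demo_interpret_natural_language():
    req = interpret_natural_language("show the sales trend by region", ["region", "sales"])
    assert req.chart_type == "line"
    assert (req.x_column, req.y_column) == ("region", "sales")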
# ===== DYNAMIC VISUALIZATION FUNCTIONS =====
def read_any_excel(content: bytes) -> pd.DataFrame:
    """Read any Excel file with automatic type detection"""
    try:
        # Read everything as object first so types can be detected per column
        df = pd.read_excel(
            io.BytesIO(content),
            engine='openpyxl',
            dtype=object,
            na_values=['', '#N/A', '#VALUE!', '#REF!', 'NULL', 'NA', 'N/A']
        )

        # Convert each column to the best possible type
        for col in df.columns:
            # First try numeric conversion
            try:
                df[col] = pd.to_numeric(df[col])
                continue
            except (ValueError, TypeError):
                pass
            # Then try datetime (format='mixed' requires pandas >= 2.0)
            try:
                df[col] = pd.to_datetime(df[col], format='mixed')
                continue
            except (ValueError, TypeError):
                pass
            # Finally clean strings
            df[col] = df[col].astype(str).str.strip()
            df[col] = df[col].replace(['nan', 'None', 'NaT', ''], None)

        return df
    except Exception as e:
        logger.error(f"Excel reading failed: {str(e)}")
        raise HTTPException(422, f"Could not process Excel file: {str(e)}")
def clean_and_convert_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Clean and convert data types in a DataFrame with proper error handling
    """
    df_clean = df.copy()
    for col in df_clean.columns:
        # Try numeric conversion with proper error handling
        try:
            df_clean[col] = pd.to_numeric(df_clean[col])
            continue  # Skip to next column if successful
        except (ValueError, TypeError):
            pass

        # Try datetime conversion with format inference
        try:
            # First try ISO format
            df_clean[col] = pd.to_datetime(df_clean[col], format='ISO8601')
            continue
        except (ValueError, TypeError):
            try:
                # Fall back to mixed format
                df_clean[col] = pd.to_datetime(df_clean[col], format='mixed')
                continue
            except (ValueError, TypeError):
                pass

        # Clean string columns
        if df_clean[col].dtype == object:
            df_clean[col] = (
                df_clean[col]
                .astype(str)
                .str.strip()
                .replace(['nan', 'None', 'NaT', ''], pd.NA)
            )
    return df_clean
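
# Illustrative sketch (never called by the app; assumes pandas >= 2.0 for the
# ISO8601/mixed datetime formats used above):
def _demo_clean_and_convert_data():
    raw = pd.DataFrame({"qty": ["1", "2"], "when": ["2024-01-01", "2024-01-02"]})
    clean = clean_and_convert_data(raw)
    # clean["qty"] is now numeric; clean["when"] is datetime64[ns]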
def is_date_like(s: str) -> bool:
    """Helper to detect date-like strings"""
    date_patterns = [
        r'\d{4}-\d{2}-\d{2}',                    # YYYY-MM-DD
        r'\d{2}/\d{2}/\d{4}',                    # MM/DD/YYYY
        r'\d{4}/\d{2}/\d{2}',                    # YYYY/MM/DD
        r'\d{2}-\d{2}-\d{4}',                    # MM-DD-YYYY
        r'\d{1,2}[./-]\d{1,2}[./-]\d{2,4}',      # Various separators
        r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}'   # With time
    ]
    return any(re.match(p, s) for p in date_patterns)
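
# Examples (illustrative; re.match only anchors at the start of the string):
#   is_date_like("2024-01-31")  -> True   (YYYY-MM-DD)
#   is_date_like("31/01/2024")  -> True   (two/two/four digit pattern)
#   is_date_like("not a date")  -> False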
def generate_smart_prompt(df: pd.DataFrame) -> str:
    """Generate a sensible default prompt based on the data"""
    numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
    date_cols = df.select_dtypes(include=['datetime']).columns.tolist()
    cat_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

    if date_cols and numeric_cols:
        return f"Show line chart of {numeric_cols[0]} over time"
    elif len(numeric_cols) >= 2 and cat_cols:
        return f"Compare {numeric_cols[0]} and {numeric_cols[1]} by {cat_cols[0]}"
    elif numeric_cols:
        return f"Show distribution of {numeric_cols[0]}"
    else:
        return "Show data overview"
def generate_dynamic_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
    """Generate visualization code that adapts to any DataFrame structure"""
    # Validate style
    valid_style = validate_matplotlib_style(request.style)

    # Prepare data with type preservation
    data_dict = {}
    type_hints = {}
    for col in df.columns:
        if pd.api.types.is_datetime64_any_dtype(df[col]):
            data_dict[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S').tolist()
            type_hints[col] = 'datetime'
        elif pd.api.types.is_numeric_dtype(df[col]):
            data_dict[col] = df[col].tolist()
            type_hints[col] = 'numeric'
        else:
            data_dict[col] = df[col].astype(str).tolist()
            type_hints[col] = 'string'

    code_lines = [
        "import matplotlib.pyplot as plt",
        "import seaborn as sns",
        "import pandas as pd",
        "import numpy as np",
        "from datetime import datetime",
        "",
        "# Data reconstruction with type handling",
        f"raw_data = {data_dict}",
        "df = pd.DataFrame(raw_data)",
        "",
        "# Type conversion based on detected types"
    ]

    # Add type conversion for each column
    # (errors='ignore' is deprecated in pandas 2.x but kept so the emitted
    # script matches the original behaviour)
    for col, col_type in type_hints.items():
        if col_type == 'datetime':
            code_lines.append(
                f"df['{col}'] = pd.to_datetime(df['{col}'], format='%Y-%m-%d %H:%M:%S', errors='ignore')"
            )
        elif col_type == 'numeric':
            code_lines.append(
                f"df['{col}'] = pd.to_numeric(df['{col}'], errors='ignore')"
            )

    code_lines.extend([
        "",
        "# Clean missing values",
        "df = df.replace([None, np.nan, 'nan', 'None', 'NaT', ''], None)",
        "df = df.where(pd.notnull(df), None)",
        "",
        "# Visualization setup",
        f"plt.style.use('{valid_style}')",
        "plt.figure(figsize=(10, 6))"
    ])

    # Chart-type specific plotting (same dispatch as generate_visualization_code)
    if request.chart_type == "line":
        if request.hue_column:
            code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "bar":
        if request.hue_column:
            code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "scatter":
        if request.hue_column:
            code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
    elif request.chart_type == "histogram":
        code_lines.append(f"plt.hist(df['{request.x_column}'].dropna(), bins=20)")
    elif request.chart_type == "boxplot":
        if request.hue_column:
            code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
        else:
            code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}')")
    elif request.chart_type == "heatmap":
        code_lines.append("numeric_df = df.select_dtypes(include=[np.number])")
        code_lines.append("corr = numeric_df.corr()")
        code_lines.append("sns.heatmap(corr, annot=True, cmap='coolwarm')")
    else:
        raise ValueError(f"Unsupported chart type: {request.chart_type}")

    # Add labels and title
    if request.title:
        code_lines.append(f"plt.title('{request.title}')")
    if request.x_label:
        code_lines.append(f"plt.xlabel('{request.x_label}')")
    if request.y_label:
        code_lines.append(f"plt.ylabel('{request.y_label}')")

    code_lines.extend([
        "plt.tight_layout()",
        "plt.show()"
    ])
    return "\n".join(code_lines)
# NOTE: the route decorators on the endpoints below were missing from the
# original file. /visualize/natural is the path the Gradio client and the
# self-test actually call; the other paths are assumptions.
@app.post("/summarize")  # assumed path
async def summarize_document(request: Request, file: UploadFile = File(...)):
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)
        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and chunk text
        text = re.sub(r'\s+', ' ', text).strip()
        chunks = [text[i:i+1000] for i in range(0, len(text), 1000)]

        # Summarize each chunk
        summarizer = get_summarizer()
        summaries = []
        for chunk in chunks:
            summary = summarizer(chunk, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
            summaries.append(summary)

        return {"summary": " ".join(summaries)}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Summarization failed: {str(e)}")
        raise HTTPException(500, "Document summarization failed")
@app.post("/qa")  # assumed path
async def question_answering(
    request: Request,
    file: UploadFile = File(...),
    question: str = Form(...),
    language: str = Form("fr")
):
    try:
        file_ext, content = await process_uploaded_file(file)
        text = extract_text(content, file_ext)
        if not text.strip():
            raise HTTPException(400, "No extractable text found")

        # Clean and truncate text
        text = re.sub(r'\s+', ' ', text).strip()[:5000]

        # Theme detection (French and English keywords)
        theme_keywords = ["thème", "sujet principal", "quoi le sujet", "theme", "main topic"]
        if any(kw in question.lower() for kw in theme_keywords):
            try:
                summarizer = get_summarizer()
                summary_output = summarizer(
                    text,
                    max_length=min(100, len(text)//4),
                    min_length=30,
                    do_sample=False,
                    truncation=True
                )
                theme = summary_output[0].get("summary_text", text[:200] + "...")
                return {
                    "question": question,
                    "answer": f"Le document traite principalement de : {theme}",
                    "confidence": 0.95,
                    "language": language
                }
            except Exception:
                theme = text[:200] + ("..." if len(text) > 200 else "")
                return {
                    "question": question,
                    "answer": f"D'après le document : {theme}",
                    "confidence": 0.7,
                    "language": language,
                    "warning": "theme_summary_fallback"
                }

        # Standard QA
        qa = get_qa_model()
        result = qa(question=question, context=text[:3000])
        return {
            "question": question,
            "answer": result["answer"],
            "confidence": result["score"],
            "language": language
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"QA processing failed: {str(e)}")
        raise HTTPException(500, detail=f"Analysis failed: {str(e)}")
@app.post("/visualize/natural")
async def natural_language_visualization(
    file: UploadFile = File(...),
    prompt: str = Form(""),
    style: str = Form("seaborn-v0_8")
):
    try:
        # Read and validate file
        content = await file.read()
        try:
            df = pd.read_excel(BytesIO(content))
        except Exception as e:
            raise HTTPException(400, detail=f"Invalid Excel file: {str(e)}")

        if df.empty:
            raise HTTPException(400, detail="The uploaded Excel file is empty")

        # Clean and convert data types. Each branch uses `continue` so a
        # successfully converted column is not clobbered by the string cleanup.
        for col in df.columns:
            # Try numeric conversion first
            try:
                df[col] = pd.to_numeric(df[col])
                continue
            except (ValueError, TypeError):
                pass
            # Then try datetime
            try:
                df[col] = pd.to_datetime(df[col])
                continue
            except (ValueError, TypeError):
                pass
            # Finally clean strings
            df[col] = df[col].astype(str).str.strip().replace('nan', np.nan)

        # Generate visualization request
        vis_request = interpret_natural_language(prompt, df.columns.tolist())
        if not vis_request:
            raise HTTPException(400, "Could not interpret visualization request")

        # Create visualization (validate the style before applying it)
        plt.style.use(validate_matplotlib_style(style))
        fig, ax = plt.subplots(figsize=(10, 6))

        try:
            if vis_request.chart_type == "heatmap":
                numeric_df = df.select_dtypes(include=['number'])
                if numeric_df.empty:
                    raise ValueError("No numeric columns for heatmap")
                sns.heatmap(numeric_df.corr(), annot=True, ax=ax)
            else:
                # Ensure numeric data for plotting
                plot_data = df.copy()
                if vis_request.x_column:
                    plot_data[vis_request.x_column] = pd.to_numeric(
                        plot_data[vis_request.x_column],
                        errors='coerce'
                    )
                if vis_request.y_column:
                    plot_data[vis_request.y_column] = pd.to_numeric(
                        plot_data[vis_request.y_column],
                        errors='coerce'
                    )
                # Remove rows with missing data in the plotted columns
                plot_cols = [c for c in (vis_request.x_column, vis_request.y_column) if c]
                plot_data = plot_data.dropna(subset=plot_cols) if plot_cols else plot_data.dropna()

                if vis_request.chart_type == "line":
                    sns.lineplot(
                        data=plot_data,
                        x=vis_request.x_column,
                        y=vis_request.y_column,
                        hue=vis_request.hue_column,
                        ax=ax
                    )
                elif vis_request.chart_type == "bar":
                    sns.barplot(
                        data=plot_data,
                        x=vis_request.x_column,
                        y=vis_request.y_column,
                        hue=vis_request.hue_column,
                        ax=ax
                    )
                elif vis_request.chart_type == "scatter":
                    sns.scatterplot(
                        data=plot_data,
                        x=vis_request.x_column,
                        y=vis_request.y_column,
                        hue=vis_request.hue_column,
                        ax=ax
                    )
                # Other chart types can be added here as needed

            ax.set_title(vis_request.title)
            buf = BytesIO()
            plt.savefig(buf, format='png', bbox_inches='tight')
            plt.close(fig)
            buf.seek(0)

            # Generate reproducible code for the chart
            generated_code = generate_visualization_code(df, vis_request)

            return {
                "status": "success",
                "image": base64.b64encode(buf.read()).decode('utf-8'),
                "chart_type": vis_request.chart_type,
                "columns": list(df.columns),
                "x_column": vis_request.x_column,
                "y_column": vis_request.y_column,
                "hue_column": vis_request.hue_column,
                "code": generated_code
            }
        except Exception as e:
            raise HTTPException(400, detail=f"Plotting error: {str(e)}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Visualization error: {str(e)}", exc_info=True)
        raise HTTPException(500, detail=f"Server error: {str(e)}")
@app.get("/styles")  # assumed path
async def list_available_styles(request: Request):
    """List all available matplotlib styles"""
    return {"available_styles": plt.style.available}
@app.post("/columns")  # assumed path
async def get_excel_columns(
    request: Request,
    file: UploadFile = File(...)
):
    try:
        file_ext, content = await process_uploaded_file(file)
        if file_ext not in {"xlsx", "xls"}:
            raise HTTPException(400, "Only Excel files are supported")

        df = pd.read_excel(io.BytesIO(content))
        return {
            "columns": list(df.columns),
            "sample_data": df.head().to_dict(orient='records'),
            "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
        }
    except Exception as e:
        logger.error(f"Column extraction failed: {str(e)}")
        raise HTTPException(500, detail="Failed to extract columns from Excel file")
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
    return JSONResponse(
        status_code=429,
        content={"detail": "Too many requests. Please try again later."}
    )

# Register the handler (this registration was missing from the original file)
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
import gradio as gr
from fastapi.testclient import TestClient

# In-process client so the Gradio wrapper (and the self-test below) can call
# the FastAPI route without a separately running server
client = TestClient(app)

# Gradio interface for visualization
def gradio_visualize(file, prompt, style="seaborn-v0_8"):
    # Call the existing FastAPI endpoint. With type="filepath", Gradio passes
    # the uploaded file's path as a plain string.
    with open(file, "rb") as f:
        response = client.post(
            "/visualize/natural",
            files={"file": f},
            data={"prompt": prompt, "style": style}
        )
    result = response.json()
    # The endpoint returns raw base64 (no data-URI header); decode it into a
    # PIL image for gr.Image
    image = Image.open(BytesIO(base64.b64decode(result["image"])))
    # Return both image and code
    return (
        image,
        f"```python\n{result['code']}\n```"  # Code with Markdown formatting
    )

# Create Gradio interface
visualization_interface = gr.Interface(
    fn=gradio_visualize,
    inputs=[
        gr.File(label="Upload Excel File", type="filepath"),
        gr.Textbox(label="Visualization Prompt", placeholder="e.g., 'Show sales by region'"),
        gr.Dropdown(label="Style", choices=plt.style.available, value="seaborn-v0_8")
    ],
    outputs=[
        gr.Image(label="Generated Visualization"),
        gr.Markdown(label="Generated Code")  # Renders code with syntax highlighting
    ],
    title="📊 Data Visualizer",
    description="Upload an Excel file and describe the visualization you want"
)

# Mount Gradio onto the FastAPI app
app = gr.mount_gradio_app(app, visualization_interface, path="/gradio")
# ===== Self-test and entry point =====
if __name__ == "__main__":
    # 1. Reuse the in-process TestClient defined above
    # 2. Test the visualization endpoint
    test_file = "test.xlsx"  # Replace with your test file
    test_prompt = "Show me a bar chart of sales by region"

    # 3. Send a request to the API
    with open(test_file, "rb") as f:
        response = client.post(
            "/visualize/natural",
            files={"file": ("test.xlsx", f, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")},
            data={"prompt": test_prompt}
        )

    # 4. Check if successful
    if response.status_code == 200:
        result = response.json()
        print("Visualization generated successfully!")
        # 5. Decode and display the image; the endpoint returns raw base64
        #    with no "data:image/png;base64," header, so no splitting is needed
        image_bytes = base64.b64decode(result["image"])
        image = Image.open(BytesIO(image_bytes))
        plt.imshow(image)
        plt.axis("off")
        plt.show()
    else:
        print(f"Error: {response.status_code}\n{response.text}")

    # 6. Run the server
    uvicorn.run(app, host="0.0.0.0", port=7860)