Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
@@ -432,78 +432,70 @@ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest)
|
|
432 |
)
|
433 |
from typing import Optional
|
434 |
|
435 |
-
|
436 |
-
"""
|
437 |
if not prompt or not df_columns:
|
438 |
return None
|
439 |
|
440 |
-
# Normalize the prompt and columns
|
441 |
prompt = prompt.lower().strip()
|
442 |
-
|
443 |
|
444 |
-
# Initialize
|
445 |
chart_type = "bar"
|
446 |
x_col = None
|
447 |
y_col = None
|
448 |
hue_col = None
|
449 |
-
title = f"Visualization of {prompt[:50]}" # Default title
|
450 |
|
451 |
-
#
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
for chart, keywords in chart_keywords.items():
|
463 |
-
if any(keyword in prompt for keyword in keywords):
|
464 |
-
chart_type = chart
|
465 |
-
break
|
466 |
|
467 |
-
#
|
468 |
-
for col in df_columns:
|
469 |
-
col_lower = col.lower()
|
470 |
-
|
471 |
-
# Check if column name appears in prompt
|
472 |
if col_lower in prompt:
|
473 |
-
#
|
474 |
-
if not x_col
|
475 |
-
"for " + col_lower in prompt or
|
476 |
-
"across " + col_lower in prompt):
|
477 |
x_col = col
|
478 |
-
|
479 |
-
|
480 |
-
"plot " + col_lower in prompt):
|
481 |
y_col = col
|
482 |
-
|
483 |
-
|
484 |
hue_col = col
|
485 |
|
486 |
-
#
|
487 |
-
if not x_col and
|
488 |
x_col = df_columns[0] # First column as default x-axis
|
489 |
|
490 |
-
if not y_col and len(df_columns) > 1:
|
491 |
-
|
492 |
-
numeric_cols = [col for col in df_columns if pd.api.types.is_numeric_dtype(df[col])]
|
493 |
-
y_col = numeric_cols[0] if numeric_cols else df_columns[1]
|
494 |
|
495 |
-
# Special handling for
|
496 |
if chart_type == "heatmap":
|
497 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
498 |
y_col = None
|
499 |
-
hue_col = None
|
500 |
|
501 |
return VisualizationRequest(
|
502 |
chart_type=chart_type,
|
503 |
x_column=x_col,
|
504 |
y_column=y_col,
|
505 |
hue_column=hue_col,
|
506 |
-
title=title,
|
507 |
style="seaborn-v0_8"
|
508 |
)
|
509 |
|
@@ -839,142 +831,82 @@ from fastapi.responses import FileResponse # Add this import at the top
|
|
839 |
|
840 |
|
841 |
# [Previous imports remain exactly the same...]
|
842 |
-
|
843 |
@app.post("/visualize/natural")
|
844 |
-
@limiter.limit("5/minute")
|
845 |
async def visualize_with_natural_language(
|
846 |
-
request: Request,
|
847 |
file: UploadFile = File(...),
|
848 |
prompt: str = Form(""),
|
849 |
style: str = Form("seaborn-v0_8")
|
850 |
):
|
851 |
try:
|
852 |
-
#
|
853 |
-
logger.info(f"Incoming request with file: {file.filename if file else 'None'}")
|
854 |
-
|
855 |
-
# Verify file exists and has content
|
856 |
-
if not file or not file.filename:
|
857 |
-
logger.error("No file uploaded")
|
858 |
-
raise HTTPException(400, "Please upload an Excel file")
|
859 |
-
|
860 |
-
# Read file content
|
861 |
content = await file.read()
|
862 |
-
|
863 |
-
|
864 |
-
raise HTTPException(400, "The uploaded file is empty")
|
865 |
-
|
866 |
-
# Verify Excel file extension
|
867 |
-
file_ext = file.filename.split('.')[-1].lower()
|
868 |
-
if file_ext not in {"xlsx", "xls"}:
|
869 |
-
logger.error(f"Unsupported file type: {file_ext}")
|
870 |
-
raise HTTPException(400, "Only Excel files (.xlsx, .xls) are supported")
|
871 |
-
|
872 |
-
# Read Excel file with multiple engine fallbacks
|
873 |
-
try:
|
874 |
-
df = pd.read_excel(BytesIO(content), engine='openpyxl')
|
875 |
-
except Exception as e:
|
876 |
-
logger.warning(f"Openpyxl failed, trying xlrd: {str(e)}")
|
877 |
-
try:
|
878 |
-
df = pd.read_excel(BytesIO(content), engine='xlrd')
|
879 |
-
except Exception as e:
|
880 |
-
logger.error(f"Excel read failed: {str(e)}")
|
881 |
-
raise HTTPException(400, "Failed to read Excel file - may be corrupt or password protected")
|
882 |
-
|
883 |
if df.empty:
|
884 |
-
logger.error("Empty DataFrame after reading Excel")
|
885 |
raise HTTPException(400, "Excel file contains no data")
|
886 |
-
|
|
|
|
|
|
|
887 |
# Generate prompt if empty
|
888 |
if not prompt.strip():
|
889 |
-
prompt =
|
890 |
-
logger.info(f"Auto-generated prompt: {prompt}")
|
891 |
-
|
892 |
-
# Create visualization request
|
893 |
-
try:
|
894 |
-
vis_request = interpret_natural_language(prompt, df.columns.tolist())
|
895 |
-
if not vis_request:
|
896 |
-
raise ValueError("Could not interpret visualization request")
|
897 |
|
898 |
-
|
899 |
-
|
900 |
-
|
901 |
-
|
902 |
-
if vis_request.y_column and vis_request.y_column not in df.columns:
|
903 |
-
raise ValueError(f"Y-axis column '{vis_request.y_column}' not found")
|
904 |
-
|
905 |
-
if vis_request.hue_column and vis_request.hue_column not in df.columns:
|
906 |
-
raise ValueError(f"Hue column '{vis_request.hue_column}' not found")
|
907 |
-
|
908 |
-
except ValueError as e:
|
909 |
-
logger.error(f"Visualization interpretation failed: {str(e)}")
|
910 |
-
raise HTTPException(
|
911 |
-
status_code=400,
|
912 |
-
detail={
|
913 |
-
"error": "Could not create visualization request",
|
914 |
-
"message": str(e),
|
915 |
-
"available_columns": list(df.columns),
|
916 |
-
"your_prompt": prompt
|
917 |
-
}
|
918 |
-
)
|
919 |
|
920 |
vis_request.style = style
|
921 |
|
922 |
# Generate visualization
|
923 |
-
|
924 |
-
|
925 |
-
|
926 |
-
|
927 |
-
|
928 |
-
|
929 |
-
|
930 |
-
|
931 |
-
|
932 |
-
|
933 |
-
|
934 |
-
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
plt.close()
|
942 |
-
buffer.seek(0)
|
943 |
|
944 |
-
|
945 |
-
"
|
946 |
-
"
|
947 |
-
"
|
948 |
-
"
|
949 |
-
"prompt": prompt,
|
950 |
-
"chart_type": vis_request.chart_type,
|
951 |
-
"x_column": vis_request.x_column,
|
952 |
-
"y_column": vis_request.y_column,
|
953 |
-
"hue_column": vis_request.hue_column
|
954 |
}
|
955 |
-
|
956 |
-
|
957 |
-
|
958 |
-
|
959 |
-
|
960 |
-
|
961 |
-
|
962 |
-
|
963 |
-
|
964 |
-
|
965 |
-
|
966 |
-
|
967 |
-
|
968 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
969 |
except Exception as e:
|
970 |
-
|
971 |
-
raise HTTPException(
|
972 |
-
status_code=500,
|
973 |
-
detail={
|
974 |
-
"error": "Internal server error",
|
975 |
-
"message": str(e)
|
976 |
-
}
|
977 |
-
)
|
978 |
|
979 |
|
980 |
|
|
|
432 |
)
|
433 |
from typing import Optional
|
434 |
|
435 |
+
def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
    """Turn a free-text prompt into a VisualizationRequest for the given columns.

    The chart type is chosen by substring keywords in the lower-cased prompt
    (first matching group wins, "bar" when nothing matches). Columns whose
    lower-cased name occurs in the prompt are assigned roles in the order they
    are *mentioned in the prompt*: first -> x-axis, second -> y-axis (not for
    histograms), third -> hue (bar/scatter/line only). Missing x/y roles fall
    back to the leading columns of ``df_columns``.

    Args:
        prompt: The user's request text.
        df_columns: Column labels of the uploaded sheet. Non-string labels are
            matched via ``str(label)``; the original label is what gets stored
            in the request.

    Returns:
        A VisualizationRequest, or None when ``prompt`` or ``df_columns`` is
        empty.
    """
    if not prompt or not df_columns:
        return None

    prompt = prompt.lower().strip()

    # Ordered on purpose: the first group with a keyword hit decides the chart.
    chart_keywords = (
        ("line", ("line", "trend", "over time")),
        ("scatter", ("scatter", "relationship", "correlat")),
        ("histogram", ("histogram", "distribut", "frequenc")),
        ("boxplot", ("box", "quartile")),
        ("heatmap", ("heatmap", "matrix")),
    )
    chart_type = next(
        (
            chart
            for chart, keywords in chart_keywords
            if any(keyword in prompt for keyword in keywords)
        ),
        "bar",
    )

    # A heatmap request carries no column roles, so column matching is skipped.
    # NOTE(review): assumes VisualizationRequest's column fields are optional,
    # exactly as the previous implementation of this branch did.
    if chart_type == "heatmap":
        return VisualizationRequest(
            chart_type="heatmap",
            title=f"Heatmap: {prompt[:30]}...",
            style="seaborn-v0_8"
        )

    # Collect (position in prompt, position in frame, label) for every column
    # named in the prompt, then order by where the name first appears so that
    # "first mentioned" really means first in the user's sentence.
    mentions = []
    for frame_index, col in enumerate(df_columns):
        name = str(col).lower().strip()
        if not name:
            # An empty name is a substring of every prompt; never match it.
            continue
        found_at = prompt.find(name)
        if found_at != -1:
            mentions.append((found_at, frame_index, col))
    # Sort on the two integers only, so labels themselves are never compared.
    mentions.sort(key=lambda mention: mention[:2])

    wants_y = chart_type != "histogram"  # histograms only need an x-axis
    allows_hue = chart_type in ("bar", "scatter", "line")

    x_col = None
    y_col = None
    hue_col = None
    for _, _, col in mentions:
        if x_col is None:
            x_col = col
        elif y_col is None and wants_y:
            y_col = col
        elif hue_col is None and allows_hue:
            hue_col = col

    # Defaults when the prompt did not name enough columns.
    if x_col is None:
        x_col = df_columns[0]

    if y_col is None and wants_y:
        # First column that is not already the x-axis, so x and y never
        # collapse onto the same column; stays None for single-column data.
        y_col = next((col for col in df_columns if col != x_col), None)

    return VisualizationRequest(
        chart_type=chart_type,
        x_column=x_col,
        y_column=y_col,
        hue_column=hue_col,
        title=f"{chart_type.title()} of {prompt[:30]}...",
        style="seaborn-v0_8"
    )
|
501 |
|
|
|
831 |
|
832 |
|
833 |
# [Previous imports remain exactly the same...]
|
|
|
834 |
@app.post("/visualize/natural")
async def visualize_with_natural_language(
    file: UploadFile = File(...),
    prompt: str = Form(""),
    style: str = Form("seaborn-v0_8")
):
    """Render a chart from an uploaded Excel file and a natural-language prompt.

    The upload is read with pandas, column labels are converted to stripped
    strings, and the prompt (or an auto-generated one naming the first two
    columns) is passed to ``interpret_natural_language``. The resulting chart
    is drawn on a fresh figure and returned as a base64-encoded PNG together
    with the chosen chart type, column roles and the list of columns.

    Raises:
        HTTPException: 400 for an empty/unreadable/data-less file, an
            uninterpretable prompt, an unknown style, an unsupported chart
            type, or a heatmap request without numeric columns; 500 for any
            other failure while generating the chart.
    """
    try:
        # Read and validate Excel file
        content = await file.read()
        if not content:
            raise HTTPException(400, "The uploaded file is empty")

        try:
            df = pd.read_excel(BytesIO(content))
        except Exception as e:
            # A file pandas cannot parse is a client problem, not a server one.
            raise HTTPException(400, f"Could not read Excel file: {str(e)}") from e

        if df.empty:
            raise HTTPException(400, "Excel file contains no data")

        # Normalise column labels to stripped strings so they can be matched
        # against the prompt and joined into the default prompt below.
        df.columns = [str(col).strip() for col in df.columns]

        # Generate prompt if empty
        if not prompt.strip():
            prompt = f"Visualize {', '.join(df.columns[:2])}"  # Default to first two columns

        # Create visualization request
        vis_request = interpret_natural_language(prompt, list(df.columns))
        if not vis_request:
            raise HTTPException(400, "Couldn't understand your request. Try mentioning column names.")

        # Only accept registered style names: the value comes from the client,
        # and matplotlib would otherwise also treat it as a file path or URL.
        if style != "default" and style not in plt.style.available:
            raise HTTPException(400, f"Unknown plot style: {style}")

        vis_request.style = style
        chart_type = vis_request.chart_type

        # Apply the style only for this request instead of changing the
        # process-wide matplotlib configuration.
        with plt.style.context(style):
            fig, ax = plt.subplots(figsize=(10, 6))
            try:
                if chart_type == "heatmap":
                    numeric_df = df.select_dtypes(include=['number'])
                    if numeric_df.empty:
                        raise HTTPException(400, "No numeric columns found for heatmap")
                    sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', ax=ax)
                elif chart_type == "histogram":
                    # Histograms only use the x column; missing values are dropped.
                    ax.hist(df[vis_request.x_column].dropna())
                else:
                    plot_funcs = {
                        "line": sns.lineplot,
                        "bar": sns.barplot,
                        "scatter": sns.scatterplot,
                        "boxplot": sns.boxplot,
                    }
                    plot_func = plot_funcs.get(chart_type)
                    if plot_func is None:
                        raise HTTPException(400, f"Unsupported chart type: {chart_type}")

                    plot_kwargs = {
                        "x": vis_request.x_column,
                        "y": vis_request.y_column,
                        "hue": vis_request.hue_column,
                    }
                    # Unset roles are omitted rather than passed as None.
                    plot_func(
                        data=df,
                        ax=ax,
                        **{k: v for k, v in plot_kwargs.items() if v is not None}
                    )

                ax.set_title(vis_request.title)
                fig.tight_layout()

                # Save to buffer
                buffer = BytesIO()
                fig.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
            finally:
                # Always release the figure, even when plotting or saving fails.
                plt.close(fig)

        return {
            "image": base64.b64encode(buffer.getvalue()).decode('utf-8'),
            "chart_type": vis_request.chart_type,
            "x_column": vis_request.x_column,
            "y_column": vis_request.y_column,
            "hue_column": vis_request.hue_column,
            "columns": list(df.columns)
        }

    except HTTPException:
        # Deliberate 4xx responses raised above must not be rewrapped as 500.
        raise
    except Exception as e:
        raise HTTPException(500, f"Error generating visualization: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
910 |
|
911 |
|
912 |
|