chenguittiMaroua committed on
Commit
5b32457
·
verified ·
1 Parent(s): 9333c98

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +93 -161
main.py CHANGED
@@ -432,78 +432,70 @@ def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest)
432
  )
433
  from typing import Optional
434
 
435
- def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
436
- """Convert natural language prompt to visualization parameters with enhanced parsing"""
437
  if not prompt or not df_columns:
438
  return None
439
 
440
- # Normalize the prompt and columns
441
  prompt = prompt.lower().strip()
442
- normalized_columns = [col.lower().strip() for col in df_columns]
443
 
444
- # Initialize default values
445
  chart_type = "bar"
446
  x_col = None
447
  y_col = None
448
  hue_col = None
449
- title = f"Visualization of {prompt[:50]}" # Default title
450
 
451
- # Common chart type detection
452
- chart_keywords = {
453
- "line": ["line", "trend", "over time"],
454
- "bar": ["bar", "compare", "comparison"],
455
- "scatter": ["scatter", "correlation", "relationship"],
456
- "histogram": ["histogram", "distribution", "frequency"],
457
- "boxplot": ["box", "quartile", "distribution"],
458
- "heatmap": ["heatmap", "correlation", "matrix"]
459
- }
460
-
461
- # Detect chart type
462
- for chart, keywords in chart_keywords.items():
463
- if any(keyword in prompt for keyword in keywords):
464
- chart_type = chart
465
- break
466
 
467
- # Column detection with improved matching
468
- for col in df_columns:
469
- col_lower = col.lower()
470
-
471
- # Check if column name appears in prompt
472
  if col_lower in prompt:
473
- # Look for context clues about the column's role
474
- if not x_col and ("by " + col_lower in prompt or
475
- "for " + col_lower in prompt or
476
- "across " + col_lower in prompt):
477
  x_col = col
478
- elif not y_col and ("of " + col_lower in prompt or
479
- "show " + col_lower in prompt or
480
- "plot " + col_lower in prompt):
481
  y_col = col
482
- elif not hue_col and ("color by " + col_lower in prompt or
483
- "group by " + col_lower in prompt):
484
  hue_col = col
485
 
486
- # Fallback logic if columns not detected
487
- if not x_col and len(df_columns) > 0:
488
  x_col = df_columns[0] # First column as default x-axis
489
 
490
- if not y_col and len(df_columns) > 1:
491
- # Try to find a numeric column for y-axis
492
- numeric_cols = [col for col in df_columns if pd.api.types.is_numeric_dtype(df[col])]
493
- y_col = numeric_cols[0] if numeric_cols else df_columns[1]
494
 
495
- # Special handling for certain chart types
496
  if chart_type == "heatmap":
497
- x_col = None
 
 
 
 
 
 
 
498
  y_col = None
499
- hue_col = None
500
 
501
  return VisualizationRequest(
502
  chart_type=chart_type,
503
  x_column=x_col,
504
  y_column=y_col,
505
  hue_column=hue_col,
506
- title=title,
507
  style="seaborn-v0_8"
508
  )
509
 
@@ -839,142 +831,82 @@ from fastapi.responses import FileResponse # Add this import at the top
839
 
840
 
841
  # [Previous imports remain exactly the same...]
842
-
843
  @app.post("/visualize/natural")
844
- @limiter.limit("5/minute")
845
  async def visualize_with_natural_language(
846
- request: Request,
847
  file: UploadFile = File(...),
848
  prompt: str = Form(""),
849
  style: str = Form("seaborn-v0_8")
850
  ):
851
  try:
852
- # Debugging: Log incoming request
853
- logger.info(f"Incoming request with file: {file.filename if file else 'None'}")
854
-
855
- # Verify file exists and has content
856
- if not file or not file.filename:
857
- logger.error("No file uploaded")
858
- raise HTTPException(400, "Please upload an Excel file")
859
-
860
- # Read file content
861
  content = await file.read()
862
- if not content:
863
- logger.error("Empty file uploaded")
864
- raise HTTPException(400, "The uploaded file is empty")
865
-
866
- # Verify Excel file extension
867
- file_ext = file.filename.split('.')[-1].lower()
868
- if file_ext not in {"xlsx", "xls"}:
869
- logger.error(f"Unsupported file type: {file_ext}")
870
- raise HTTPException(400, "Only Excel files (.xlsx, .xls) are supported")
871
-
872
- # Read Excel file with multiple engine fallbacks
873
- try:
874
- df = pd.read_excel(BytesIO(content), engine='openpyxl')
875
- except Exception as e:
876
- logger.warning(f"Openpyxl failed, trying xlrd: {str(e)}")
877
- try:
878
- df = pd.read_excel(BytesIO(content), engine='xlrd')
879
- except Exception as e:
880
- logger.error(f"Excel read failed: {str(e)}")
881
- raise HTTPException(400, "Failed to read Excel file - may be corrupt or password protected")
882
-
883
  if df.empty:
884
- logger.error("Empty DataFrame after reading Excel")
885
  raise HTTPException(400, "Excel file contains no data")
886
-
 
 
 
887
  # Generate prompt if empty
888
  if not prompt.strip():
889
- prompt = generate_smart_prompt(df)
890
- logger.info(f"Auto-generated prompt: {prompt}")
891
-
892
- # Create visualization request
893
- try:
894
- vis_request = interpret_natural_language(prompt, df.columns.tolist())
895
- if not vis_request:
896
- raise ValueError("Could not interpret visualization request")
897
 
898
- # Validate columns exist in DataFrame
899
- if vis_request.x_column and vis_request.x_column not in df.columns:
900
- raise ValueError(f"X-axis column '{vis_request.x_column}' not found")
901
-
902
- if vis_request.y_column and vis_request.y_column not in df.columns:
903
- raise ValueError(f"Y-axis column '{vis_request.y_column}' not found")
904
-
905
- if vis_request.hue_column and vis_request.hue_column not in df.columns:
906
- raise ValueError(f"Hue column '{vis_request.hue_column}' not found")
907
-
908
- except ValueError as e:
909
- logger.error(f"Visualization interpretation failed: {str(e)}")
910
- raise HTTPException(
911
- status_code=400,
912
- detail={
913
- "error": "Could not create visualization request",
914
- "message": str(e),
915
- "available_columns": list(df.columns),
916
- "your_prompt": prompt
917
- }
918
- )
919
 
920
  vis_request.style = style
921
 
922
  # Generate visualization
923
- try:
924
- visualization_code = generate_dynamic_visualization_code(df, vis_request)
925
-
926
- plt.style.use(vis_request.style)
927
- fig, ax = plt.subplots(figsize=(10, 6))
928
-
929
- # Safe execution with limited globals
930
- exec_globals = {
931
- 'plt': plt,
932
- 'sns': sns,
933
- 'df': df,
934
- 'np': np,
935
- 'pd': pd
936
- }
937
- exec(visualization_code, exec_globals)
938
-
939
- buffer = BytesIO()
940
- plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
941
- plt.close()
942
- buffer.seek(0)
943
 
944
- return {
945
- "status": "success",
946
- "image_data": base64.b64encode(buffer.getvalue()).decode('utf-8'),
947
- "code": visualization_code,
948
- "columns": list(df.columns),
949
- "prompt": prompt,
950
- "chart_type": vis_request.chart_type,
951
- "x_column": vis_request.x_column,
952
- "y_column": vis_request.y_column,
953
- "hue_column": vis_request.hue_column
954
  }
955
-
956
- except Exception as e:
957
- logger.error(f"Visualization generation failed: {str(e)}")
958
- raise HTTPException(
959
- status_code=400,
960
- detail={
961
- "error": "Visualization generation failed",
962
- "message": str(e),
963
- "suggestion": "Try modifying your prompt or using different columns"
964
- }
965
- )
966
-
967
- except HTTPException as he:
968
- raise
 
 
 
 
 
 
969
  except Exception as e:
970
- logger.error(f"Unexpected error: {traceback.format_exc()}")
971
- raise HTTPException(
972
- status_code=500,
973
- detail={
974
- "error": "Internal server error",
975
- "message": str(e)
976
- }
977
- )
978
 
979
 
980
 
 
432
  )
433
  from typing import Optional
434
 
435
+ def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
436
+ """Fully dynamic prompt interpretation that works with any Excel columns"""
437
  if not prompt or not df_columns:
438
  return None
439
 
 
440
  prompt = prompt.lower().strip()
441
+ col_names = [col.lower() for col in df_columns]
442
 
443
+ # Initialize with defaults
444
  chart_type = "bar"
445
  x_col = None
446
  y_col = None
447
  hue_col = None
 
448
 
449
+ # Dynamic chart type detection
450
+ if any(word in prompt for word in ["line", "trend", "over time"]):
451
+ chart_type = "line"
452
+ elif any(word in prompt for word in ["scatter", "relationship", "correlat"]):
453
+ chart_type = "scatter"
454
+ elif any(word in prompt for word in ["histogram", "distribut", "frequenc"]):
455
+ chart_type = "histogram"
456
+ elif any(word in prompt for word in ["box", "quartile"]):
457
+ chart_type = "boxplot"
458
+ elif any(word in prompt for word in ["heatmap", "matrix"]):
459
+ chart_type = "heatmap"
 
 
 
 
460
 
461
+ # Dynamic column assignment - looks for column names mentioned in prompt
462
+ for col, col_lower in zip(df_columns, col_names):
 
 
 
463
  if col_lower in prompt:
464
+ # First matching column (in DataFrame column order) becomes x-axis
465
+ if not x_col:
 
 
466
  x_col = col
467
+ # Next matching column becomes y-axis (except for histograms)
468
+ elif not y_col and chart_type != "histogram":
 
469
  y_col = col
470
+ # A further matching column may become hue (bar, scatter and line charts only)
471
+ elif not hue_col and chart_type in ["bar", "scatter", "line"]:
472
  hue_col = col
473
 
474
+ # Smart defaults when columns aren't specified
475
+ if not x_col and df_columns:
476
  x_col = df_columns[0] # First column as default x-axis
477
 
478
+ if not y_col and len(df_columns) > 1 and chart_type != "histogram":
479
+ y_col = df_columns[1] # Second column as default y-axis
 
 
480
 
481
+ # Special handling for specific chart types
482
  if chart_type == "heatmap":
483
+ return VisualizationRequest(
484
+ chart_type="heatmap",
485
+ title=f"Heatmap: {prompt[:30]}...",
486
+ style="seaborn-v0_8"
487
+ )
488
+
489
+ if chart_type == "histogram" and y_col:
490
+ # Histograms only need x-axis
491
  y_col = None
 
492
 
493
  return VisualizationRequest(
494
  chart_type=chart_type,
495
  x_column=x_col,
496
  y_column=y_col,
497
  hue_column=hue_col,
498
+ title=f"{chart_type.title()} of {prompt[:30]}...",
499
  style="seaborn-v0_8"
500
  )
501
 
 
831
 
832
 
833
  # [Previous imports remain exactly the same...]
 
834
  @app.post("/visualize/natural")
 
835
  async def visualize_with_natural_language(
 
836
  file: UploadFile = File(...),
837
  prompt: str = Form(""),
838
  style: str = Form("seaborn-v0_8")
839
  ):
840
  try:
841
+ # Read and validate Excel file
 
 
 
 
 
 
 
 
842
  content = await file.read()
843
+ df = pd.read_excel(BytesIO(content))
844
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
845
  if df.empty:
 
846
  raise HTTPException(400, "Excel file contains no data")
847
+
848
+ # Normalize column names (convert to str and strip surrounding whitespace)
849
+ df.columns = [str(col).strip() for col in df.columns]
850
+
851
  # Generate prompt if empty
852
  if not prompt.strip():
853
+ prompt = f"Visualize {', '.join(df.columns[:2])}" # Default to first two columns
 
 
 
 
 
 
 
854
 
855
+ # Create visualization request
856
+ vis_request = interpret_natural_language(prompt, list(df.columns))
857
+ if not vis_request:
858
+ raise HTTPException(400, "Couldn't understand your request. Try mentioning column names.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
859
 
860
  vis_request.style = style
861
 
862
  # Generate visualization
863
+ plt.style.use(style)
864
+ fig, ax = plt.subplots(figsize=(10, 6))
865
+
866
+ # Dynamic visualization based on chart type
867
+ if vis_request.chart_type == "heatmap":
868
+ numeric_df = df.select_dtypes(include=['number'])
869
+ if numeric_df.empty:
870
+ raise HTTPException(400, "No numeric columns found for heatmap")
871
+ sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm')
872
+ else:
873
+ # For other chart types
874
+ plot_func = {
875
+ "line": sns.lineplot,
876
+ "bar": sns.barplot,
877
+ "scatter": sns.scatterplot,
878
+ "histogram": lambda data, x, **kwargs: plt.hist(data[x].dropna()),
879
+ "boxplot": sns.boxplot
880
+ }[vis_request.chart_type]
 
 
881
 
882
+ plot_kwargs = {
883
+ "data": df,
884
+ "x": vis_request.x_column,
885
+ "y": vis_request.y_column if vis_request.chart_type != "histogram" else None,
886
+ "hue": vis_request.hue_column
 
 
 
 
 
887
  }
888
+ plot_func(**{k: v for k, v in plot_kwargs.items() if v is not None})
889
+
890
+ plt.title(vis_request.title)
891
+ plt.tight_layout()
892
+
893
+ # Save to buffer
894
+ buffer = BytesIO()
895
+ plt.savefig(buffer, format='png', dpi=150, bbox_inches='tight')
896
+ plt.close()
897
+ buffer.seek(0)
898
+
899
+ return {
900
+ "image": base64.b64encode(buffer.getvalue()).decode('utf-8'),
901
+ "chart_type": vis_request.chart_type,
902
+ "x_column": vis_request.x_column,
903
+ "y_column": vis_request.y_column,
904
+ "hue_column": vis_request.hue_column,
905
+ "columns": list(df.columns)
906
+ }
907
+
908
  except Exception as e:
909
+ raise HTTPException(500, f"Error generating visualization: {str(e)}")
 
 
 
 
 
 
 
910
 
911
 
912