chenguittiMaroua commited on
Commit
3a67355
·
verified ·
1 Parent(s): a65b358

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +1 -666
main.py CHANGED
@@ -296,496 +296,15 @@ def extract_text(content: bytes, file_ext: str) -> str:
296
 
297
 
298
 
299
- # Visualization Models
300
- class VisualizationRequest(BaseModel):
301
- chart_type: str
302
- x_column: Optional[str] = None
303
- y_column: Optional[str] = None
304
- hue_column: Optional[str] = None
305
- title: Optional[str] = None
306
- x_label: Optional[str] = None
307
- y_label: Optional[str] = None
308
- style: str = "seaborn-v0_8" # Updated default
309
- filters: Optional[dict] = None
310
-
311
- class NaturalLanguageRequest(BaseModel):
312
- prompt: str
313
- style: str = "seaborn-v0_8"
314
-
315
- def validate_matplotlib_style(style: str) -> str:
316
- """Validate and return a valid matplotlib style"""
317
- available_styles = plt.style.available
318
- # Map legacy style names to current ones
319
- style_mapping = {
320
- 'seaborn': 'seaborn-v0_8',
321
- 'seaborn-white': 'seaborn-v0_8-white',
322
- 'seaborn-dark': 'seaborn-v0_8-dark',
323
- # Add other legacy mappings if needed
324
- }
325
-
326
- # Check if it's a legacy name we can map
327
- if style in style_mapping:
328
- return style_mapping[style]
329
-
330
- # Check if it's a valid current style
331
- if style in available_styles:
332
- return style
333
-
334
- logger.warning(f"Invalid style '{style}'. Available styles: {available_styles}")
335
- return "seaborn-v0_8" # Default fallback to current seaborn style
336
 
337
 
338
 
339
-
340
- def generate_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
341
- """Generate Python code for visualization with enhanced NaN handling and type safety"""
342
- # Validate style
343
- valid_style = validate_matplotlib_style(request.style)
344
-
345
- # Convert DataFrame to dict with proper NaN handling
346
- df_dict = df.where(pd.notnull(df), None).to_dict(orient='list')
347
-
348
- code_lines = [
349
- "import matplotlib.pyplot as plt",
350
- "import seaborn as sns",
351
- "import pandas as pd",
352
- "import numpy as np",
353
- "",
354
- "# Data preparation with NaN handling and type conversion",
355
- f"raw_data = {df_dict}",
356
- "df = pd.DataFrame(raw_data)",
357
- "",
358
- "# Automatic type conversion and cleaning",
359
- "for col in df.columns:",
360
- " # Convert strings that should be numeric",
361
- " if pd.api.types.is_string_dtype(df[col]):",
362
- " try:",
363
- " df[col] = pd.to_numeric(df[col])",
364
- " continue",
365
- " except (ValueError, TypeError):",
366
- " pass",
367
- " ",
368
- " # Convert string dates to datetime",
369
- " try:",
370
- " df[col] = pd.to_datetime(df[col])",
371
- " continue",
372
- " except (ValueError, TypeError):",
373
- " pass",
374
- " ",
375
- " # Clean remaining None/NaN values",
376
- " df[col] = df[col].where(pd.notnull(df[col]), None)",
377
- ]
378
-
379
- # Apply filters if specified (with enhanced safety)
380
- if request.filters:
381
- filter_conditions = []
382
- for column, condition in request.filters.items():
383
- if isinstance(condition, dict):
384
- if 'min' in condition and 'max' in condition:
385
- filter_conditions.append(
386
- f"(pd.notna(df['{column}']) & "
387
- f"(df['{column}'] >= {condition['min']}) & "
388
- f"(df['{column}'] <= {condition['max']})"
389
- )
390
- elif 'values' in condition:
391
- values = ', '.join([f"'{v}'" if isinstance(v, str) else str(v) for v in condition['values']])
392
- filter_conditions.append(
393
- f"(pd.notna(df['{column}'])) & "
394
- f"(df['{column}'].isin([{values}]))"
395
- )
396
- else:
397
- filter_conditions.append(
398
- f"(pd.notna(df['{column}'])) & "
399
- f"(df['{column}'] == {repr(condition)})"
400
- )
401
-
402
- if filter_conditions:
403
- code_lines.extend([
404
- "",
405
- "# Apply filters with NaN checking",
406
- f"df = df[{' & '.join(filter_conditions)}].copy()"
407
- ])
408
-
409
- code_lines.extend([
410
- "",
411
- "# Visualization setup",
412
- f"plt.style.use('{valid_style}')",
413
- f"plt.figure(figsize=(10, 6))"
414
- ])
415
-
416
- # Chart type specific code (unchanged from your original)
417
- if request.chart_type == "line":
418
- if request.hue_column:
419
- code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
420
- else:
421
- code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
422
- elif request.chart_type == "bar":
423
- if request.hue_column:
424
- code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
425
- else:
426
- code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
427
- elif request.chart_type == "scatter":
428
- if request.hue_column:
429
- code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
430
- else:
431
- code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
432
- elif request.chart_type == "histogram":
433
- code_lines.append(f"plt.hist(df['{request.x_column}'].dropna(), bins=20)") # Added dropna()
434
- elif request.chart_type == "boxplot":
435
- if request.hue_column:
436
- code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')") # Added dropna()
437
- else:
438
- code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}')") # Added dropna()
439
- elif request.chart_type == "heatmap":
440
- code_lines.append("numeric_df = df.select_dtypes(include=[np.number])") # Filter numeric only
441
- code_lines.append("corr = numeric_df.corr()")
442
- code_lines.append("sns.heatmap(corr, annot=True, cmap='coolwarm')")
443
- else:
444
- raise ValueError(f"Unsupported chart type: {request.chart_type}")
445
-
446
- # Add labels and title
447
- if request.title:
448
- code_lines.append(f"plt.title('{request.title}')")
449
- if request.x_label:
450
- code_lines.append(f"plt.xlabel('{request.x_label}')")
451
- if request.y_label:
452
- code_lines.append(f"plt.ylabel('{request.y_label}')")
453
-
454
- code_lines.extend([
455
- "plt.tight_layout()",
456
- "plt.show()"
457
- ])
458
-
459
- return "\n".join(code_lines)
460
-
461
-
462
- # Determine chart type
463
- chart_type = "bar"
464
- if "line" in prompt:
465
- chart_type = "line"
466
- elif "scatter" in prompt:
467
- chart_type = "scatter"
468
- elif "histogram" in prompt:
469
- chart_type = "histogram"
470
- elif "box" in prompt:
471
- chart_type = "boxplot"
472
- elif "heatmap" in prompt or "correlation" in prompt:
473
- chart_type = "heatmap"
474
-
475
- # Try to detect columns
476
- x_col = None
477
- y_col = None
478
- hue_col = None
479
-
480
- for col in df_columns:
481
- if col.lower() in prompt:
482
- if not x_col:
483
- x_col = col
484
- elif not y_col:
485
- y_col = col
486
- else:
487
- hue_col = col
488
-
489
- # Default to first columns if not detected
490
- if not x_col and len(df_columns) > 0:
491
- x_col = df_columns[0]
492
- if not y_col and len(df_columns) > 1:
493
- y_col = df_columns[1]
494
-
495
- return VisualizationRequest(
496
- chart_type=chart_type,
497
- x_column=x_col,
498
- y_column=y_col,
499
- hue_column=hue_col,
500
- title="Generated from: " + prompt[:50] + ("..." if len(prompt) > 50 else ""),
501
- style="seaborn-v0_8" # Updated default
502
- )
503
- from typing import Optional
504
 
505
- def interpret_natural_language(prompt: str, df_columns: list) -> Optional[VisualizationRequest]:
506
- """Fully dynamic prompt interpretation that works with any Excel columns"""
507
- if not prompt or not df_columns:
508
- return None
509
 
510
- prompt = prompt.lower().strip()
511
- col_names = [col.lower() for col in df_columns]
512
-
513
- # Initialize with defaults
514
- chart_type = "bar"
515
- x_col = None
516
- y_col = None
517
- hue_col = None
518
-
519
- # Dynamic chart type detection
520
- if any(word in prompt for word in ["line", "trend", "over time"]):
521
- chart_type = "line"
522
- elif any(word in prompt for word in ["scatter", "relationship", "correlat"]):
523
- chart_type = "scatter"
524
- elif any(word in prompt for word in ["histogram", "distribut", "frequenc"]):
525
- chart_type = "histogram"
526
- elif any(word in prompt for word in ["box", "quartile"]):
527
- chart_type = "boxplot"
528
- elif any(word in prompt for word in ["heatmap", "matrix"]):
529
- chart_type = "heatmap"
530
-
531
- # Dynamic column assignment - looks for column names mentioned in prompt
532
- for col, col_lower in zip(df_columns, col_names):
533
- if col_lower in prompt:
534
- # First mentioned column becomes x-axis
535
- if not x_col:
536
- x_col = col
537
- # Second mentioned becomes y-axis (except for histograms)
538
- elif not y_col and chart_type != "histogram":
539
- y_col = col
540
- # Third mentioned could be hue
541
- elif not hue_col and chart_type in ["bar", "scatter", "line"]:
542
- hue_col = col
543
-
544
- # Smart defaults when columns aren't specified
545
- if not x_col and df_columns:
546
- x_col = df_columns[0] # First column as default x-axis
547
-
548
- if not y_col and len(df_columns) > 1 and chart_type != "histogram":
549
- y_col = df_columns[1] # Second column as default y-axis
550
-
551
- # Special handling for specific chart types
552
- if chart_type == "heatmap":
553
- return VisualizationRequest(
554
- chart_type="heatmap",
555
- title=f"Heatmap: {prompt[:30]}...",
556
- style="seaborn-v0_8"
557
- )
558
 
559
- if chart_type == "histogram" and y_col:
560
- # Histograms only need x-axis
561
- y_col = None
562
-
563
- return VisualizationRequest(
564
- chart_type=chart_type,
565
- x_column=x_col,
566
- y_column=y_col,
567
- hue_column=hue_col,
568
- title=f"{chart_type.title()} of {prompt[:30]}...",
569
- style="seaborn-v0_8"
570
- )
571
 
572
- # ===== DYNAMIC VISUALIZATION FUNCTIONS =====
573
- def read_any_excel(content: bytes) -> pd.DataFrame:
574
- """Read any Excel file with automatic type detection"""
575
- try:
576
- # First read without parsing dates to detect datetime columns
577
- df = pd.read_excel(
578
- io.BytesIO(content),
579
- engine='openpyxl',
580
- dtype=object, # Read everything as object initially
581
- na_values=['', '#N/A', '#VALUE!', '#REF!', 'NULL', 'NA', 'N/A']
582
- )
583
-
584
- # Convert each column to best possible type
585
- for col in df.columns:
586
- # First try numeric conversion
587
- try:
588
- df[col] = pd.to_numeric(df[col])
589
- continue
590
- except (ValueError, TypeError):
591
- pass
592
-
593
- # Then try datetime with explicit format
594
- try:
595
- df[col] = pd.to_datetime(df[col], format='mixed')
596
- continue
597
- except (ValueError, TypeError):
598
- pass
599
-
600
- # Finally clean strings
601
- df[col] = df[col].astype(str).str.strip()
602
- df[col] = df[col].replace(['nan', 'None', 'NaT', ''], None)
603
-
604
- return df
605
-
606
- except Exception as e:
607
- logger.error(f"Excel reading failed: {str(e)}")
608
- raise HTTPException(422, f"Could not process Excel file: {str(e)}")
609
 
610
 
611
-
612
-
613
- except Exception as e:
614
- logger.error(f"Excel reading failed: {str(e)}")
615
- raise HTTPException(422, f"Could not process Excel file: {str(e)}")
616
-
617
-
618
- def clean_and_convert_data(df: pd.DataFrame) -> pd.DataFrame:
619
- """
620
- Clean and convert data types in a DataFrame with proper error handling
621
- """
622
- df_clean = df.copy()
623
-
624
- for col in df_clean.columns:
625
- # Try numeric conversion with proper error handling
626
- try:
627
- numeric_vals = pd.to_numeric(df_clean[col])
628
- df_clean[col] = numeric_vals
629
- continue # Skip to next column if successful
630
- except (ValueError, TypeError):
631
- pass
632
-
633
- # Try datetime conversion with format inference
634
- try:
635
- # First try ISO format
636
- datetime_vals = pd.to_datetime(df_clean[col], format='ISO8601')
637
- df_clean[col] = datetime_vals
638
- continue
639
- except (ValueError, TypeError):
640
- try:
641
- # Fallback to mixed format
642
- datetime_vals = pd.to_datetime(df_clean[col], format='mixed')
643
- df_clean[col] = datetime_vals
644
- continue
645
- except (ValueError, TypeError):
646
- pass
647
-
648
- # Clean string columns
649
- if df_clean[col].dtype == object:
650
- df_clean[col] = (
651
- df_clean[col]
652
- .astype(str)
653
- .str.strip()
654
- .replace(['nan', 'None', 'NaT', ''], pd.NA)
655
- )
656
-
657
- return df_clean
658
-
659
-
660
-
661
- def is_date_like(s: str) -> bool:
662
- """Helper to detect date-like strings"""
663
- date_patterns = [
664
- r'\d{4}-\d{2}-\d{2}', # YYYY-MM-DD
665
- r'\d{2}/\d{2}/\d{4}', # MM/DD/YYYY
666
- r'\d{4}/\d{2}/\d{2}', # YYYY/MM/DD
667
- r'\d{2}-\d{2}-\d{4}', # MM-DD-YYYY
668
- r'\d{1,2}[./-]\d{1,2}[./-]\d{2,4}', # Various separators
669
- r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}' # With time
670
- ]
671
- return any(re.match(p, s) for p in date_patterns)
672
-
673
- def generate_smart_prompt(df: pd.DataFrame) -> str:
674
- """Generate a sensible default prompt based on data"""
675
- numeric_cols = df.select_dtypes(include=np.number).columns.tolist()
676
- date_cols = df.select_dtypes(include=['datetime']).columns.tolist()
677
- cat_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
678
-
679
- if date_cols and numeric_cols:
680
- return f"Show line chart of {numeric_cols[0]} over time"
681
- elif len(numeric_cols) >= 2 and cat_cols:
682
- return f"Compare {numeric_cols[0]} and {numeric_cols[1]} by {cat_cols[0]}"
683
- elif numeric_cols:
684
- return f"Show distribution of {numeric_cols[0]}"
685
- else:
686
- return "Show data overview"
687
-
688
- def generate_dynamic_visualization_code(df: pd.DataFrame, request: VisualizationRequest) -> str:
689
- """Generate visualization code that adapts to any DataFrame structure"""
690
- # Validate style
691
- valid_style = validate_matplotlib_style(request.style)
692
-
693
- # Prepare data with type preservation
694
- data_dict = {}
695
- type_hints = {}
696
-
697
- for col in df.columns:
698
- if pd.api.types.is_datetime64_any_dtype(df[col]):
699
- data_dict[col] = df[col].dt.strftime('%Y-%m-%d %H:%M:%S').tolist()
700
- type_hints[col] = 'datetime'
701
- elif pd.api.types.is_numeric_dtype(df[col]):
702
- data_dict[col] = df[col].tolist()
703
- type_hints[col] = 'numeric'
704
- else:
705
- data_dict[col] = df[col].astype(str).tolist()
706
- type_hints[col] = 'string'
707
-
708
- code_lines = [
709
- "import matplotlib.pyplot as plt",
710
- "import seaborn as sns",
711
- "import pandas as pd",
712
- "import numpy as np",
713
- "from datetime import datetime",
714
- "",
715
- "# Data reconstruction with type handling",
716
- f"raw_data = {data_dict}",
717
- "df = pd.DataFrame(raw_data)",
718
- "",
719
- "# Type conversion based on detected types"
720
- ]
721
-
722
- # Add type conversion for each column
723
- for col, col_type in type_hints.items():
724
- if col_type == 'datetime':
725
- code_lines.append(
726
- f"df['{col}'] = pd.to_datetime(df['{col}'], format='%Y-%m-%d %H:%M:%S', errors='ignore')"
727
- )
728
- elif col_type == 'numeric':
729
- code_lines.append(
730
- f"df['{col}'] = pd.to_numeric(df['{col}'], errors='ignore')"
731
- )
732
-
733
- code_lines.extend([
734
- "",
735
- "# Clean missing values",
736
- "df = df.replace([None, np.nan, 'nan', 'None', 'NaT', ''], None)",
737
- "df = df.where(pd.notnull(df), None)",
738
- "",
739
- "# Visualization setup",
740
- f"plt.style.use('{valid_style}')",
741
- f"plt.figure(figsize=(10, 6))"
742
- ])
743
-
744
- # Chart type specific code (from your existing function)
745
- if request.chart_type == "line":
746
- if request.hue_column:
747
- code_lines.append(f"sns.lineplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
748
- else:
749
- code_lines.append(f"plt.plot(df['{request.x_column}'], df['{request.y_column}'])")
750
- elif request.chart_type == "bar":
751
- if request.hue_column:
752
- code_lines.append(f"sns.barplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
753
- else:
754
- code_lines.append(f"plt.bar(df['{request.x_column}'], df['{request.y_column}'])")
755
- elif request.chart_type == "scatter":
756
- if request.hue_column:
757
- code_lines.append(f"sns.scatterplot(data=df, x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
758
- else:
759
- code_lines.append(f"plt.scatter(df['{request.x_column}'], df['{request.y_column}'])")
760
- elif request.chart_type == "histogram":
761
- code_lines.append(f"plt.hist(df['{request.x_column}'].dropna(), bins=20)")
762
- elif request.chart_type == "boxplot":
763
- if request.hue_column:
764
- code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}', hue='{request.hue_column}')")
765
- else:
766
- code_lines.append(f"sns.boxplot(data=df.dropna(), x='{request.x_column}', y='{request.y_column}')")
767
- elif request.chart_type == "heatmap":
768
- code_lines.append("numeric_df = df.select_dtypes(include=[np.number])")
769
- code_lines.append("corr = numeric_df.corr()")
770
- code_lines.append("sns.heatmap(corr, annot=True, cmap='coolwarm')")
771
- else:
772
- raise ValueError(f"Unsupported chart type: {request.chart_type}")
773
-
774
- # Add labels and title
775
- if request.title:
776
- code_lines.append(f"plt.title('{request.title}')")
777
- if request.x_label:
778
- code_lines.append(f"plt.xlabel('{request.x_label}')")
779
- if request.y_label:
780
- code_lines.append(f"plt.ylabel('{request.y_label}')")
781
-
782
- code_lines.extend([
783
- "plt.tight_layout()",
784
- "plt.show()"
785
- ])
786
-
787
- return "\n".join(code_lines)
788
-
789
 
790
 
791
 
@@ -990,191 +509,7 @@ def validate_french_response(text: str) -> str:
990
 
991
 
992
 
993
-
994
- @app.post("/visualize/natural")
995
- async def natural_language_visualization(
996
- file: UploadFile = File(...),
997
- prompt: str = Form(""),
998
- style: str = Form("seaborn-v0_8")
999
- ):
1000
- try:
1001
- # Read and validate file
1002
- content = await file.read()
1003
- try:
1004
- df = pd.read_excel(BytesIO(content))
1005
- except Exception as e:
1006
- raise HTTPException(400, detail=f"Invalid Excel file: {str(e)}")
1007
-
1008
- if df.empty:
1009
- raise HTTPException(400, detail="The uploaded Excel file is empty")
1010
-
1011
- # Clean and convert data types
1012
- for col in df.columns:
1013
- # Try numeric conversion first
1014
- df[col] = pd.to_numeric(df[col], errors='ignore')
1015
-
1016
- # Then try datetime
1017
- try:
1018
- df[col] = pd.to_datetime(df[col], errors='ignore')
1019
- except:
1020
- pass
1021
-
1022
- # Finally clean strings
1023
- df[col] = df[col].astype(str).str.strip().replace('nan', np.nan)
1024
-
1025
- # Generate visualization request
1026
- vis_request = interpret_natural_language(prompt, df.columns.tolist())
1027
- if not vis_request:
1028
- raise HTTPException(400, "Could not interpret visualization request")
1029
-
1030
- # Create visualization
1031
- plt.style.use(style)
1032
- fig, ax = plt.subplots(figsize=(10, 6))
1033
-
1034
- try:
1035
- if vis_request.chart_type == "heatmap":
1036
- numeric_df = df.select_dtypes(include=['number'])
1037
- if numeric_df.empty:
1038
- raise ValueError("No numeric columns for heatmap")
1039
- sns.heatmap(numeric_df.corr(), annot=True, ax=ax)
1040
- else:
1041
- # Ensure numeric data for plotting
1042
- plot_data = df.copy()
1043
- if vis_request.x_column:
1044
- plot_data[vis_request.x_column] = pd.to_numeric(
1045
- plot_data[vis_request.x_column],
1046
- errors='coerce'
1047
- )
1048
- if vis_request.y_column:
1049
- plot_data[vis_request.y_column] = pd.to_numeric(
1050
- plot_data[vis_request.y_column],
1051
- errors='coerce'
1052
- )
1053
-
1054
- # Remove rows with missing numeric data
1055
- plot_data = plot_data.dropna()
1056
-
1057
- if vis_request.chart_type == "line":
1058
- sns.lineplot(
1059
- data=plot_data,
1060
- x=vis_request.x_column,
1061
- y=vis_request.y_column,
1062
- hue=vis_request.hue_column,
1063
- ax=ax
1064
- )
1065
- elif vis_request.chart_type == "bar":
1066
- sns.barplot(
1067
- data=plot_data,
1068
- x=vis_request.x_column,
1069
- y=vis_request.y_column,
1070
- hue=vis_request.hue_column,
1071
- ax=ax
1072
- )
1073
- elif vis_request.chart_type == "scatter":
1074
- sns.scatterplot(
1075
- data=plot_data,
1076
- x=vis_request.x_column,
1077
- y=vis_request.y_column,
1078
- hue=vis_request.hue_column,
1079
- ax=ax
1080
- )
1081
- # Add other chart types as needed...
1082
-
1083
- ax.set_title(vis_request.title)
1084
- buf = BytesIO()
1085
- plt.savefig(buf, format='png', bbox_inches='tight')
1086
- plt.close(fig)
1087
- buf.seek(0)
1088
-
1089
- return {
1090
- "status": "success",
1091
- "image": base64.b64encode(buf.read()).decode('utf-8'),
1092
- "chart_type": vis_request.chart_type,
1093
- "columns": list(df.columns),
1094
- "x_column": vis_request.x_column,
1095
- "y_column": vis_request.y_column,
1096
- "hue_column": vis_request.hue_column
1097
- }
1098
-
1099
- except Exception as e:
1100
- raise HTTPException(400, detail=f"Plotting error: {str(e)}")
1101
-
1102
- except HTTPException:
1103
- raise
1104
- except Exception as e:
1105
- logger.error(f"Visualization error: {str(e)}", exc_info=True)
1106
- raise HTTPException(500, detail=f"Server error: {str(e)}")
1107
-
1108
-
1109
- @app.get("/visualize/styles")
1110
- @limiter.limit("10/minute")
1111
- async def list_available_styles(request: Request):
1112
- """List all available matplotlib styles"""
1113
- return {"available_styles": plt.style.available}
1114
-
1115
- @app.post("/get_columns")
1116
- @limiter.limit("10/minute")
1117
- async def get_excel_columns(
1118
- request: Request,
1119
- file: UploadFile = File(...)
1120
- ):
1121
- try:
1122
- file_ext, content = await process_uploaded_file(file)
1123
- if file_ext not in {"xlsx", "xls"}:
1124
- raise HTTPException(400, "Only Excel files are supported")
1125
-
1126
- df = pd.read_excel(io.BytesIO(content))
1127
- return {
1128
- "columns": list(df.columns),
1129
- "sample_data": df.head().to_dict(orient='records'),
1130
- "statistics": df.describe().to_dict() if len(df.select_dtypes(include=['number']).columns) > 0 else None
1131
- }
1132
- except Exception as e:
1133
- logger.error(f"Column extraction failed: {str(e)}")
1134
- raise HTTPException(500, detail="Failed to extract columns from Excel file")
1135
-
1136
- @app.exception_handler(RateLimitExceeded)
1137
- async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
1138
- return JSONResponse(
1139
- status_code=429,
1140
- content={"detail": "Too many requests. Please try again later."}
1141
- )
1142
- import gradio as gr
1143
-
1144
- # Gradio interface for visualization
1145
- def gradio_visualize(file, prompt, style="seaborn-v0_8"):
1146
- # Call your existing FastAPI endpoint
1147
- with open(file.name, "rb") as f:
1148
- response = client.post(
1149
- "/visualize/natural",
1150
- files={"file": f},
1151
- data={"prompt": prompt, "style": style}
1152
- )
1153
- result = response.json()
1154
-
1155
- # Return both image and code
1156
- return (
1157
- result["image"], # Base64 image
1158
- f"```python\n{result['code']}\n```" # Code with Markdown formatting
1159
- )
1160
-
1161
- # Create Gradio interface
1162
- visualization_interface = gr.Interface(
1163
- fn=gradio_visualize,
1164
- inputs=[
1165
- gr.File(label="Upload Excel File", type="filepath"),
1166
- gr.Textbox(label="Visualization Prompt", placeholder="e.g., 'Show sales by region'"),
1167
- gr.Dropdown(label="Style", choices=plt.style.available, value="seaborn-v0_8")
1168
- ],
1169
- outputs=gr.Image(label="Generated Visualization"),
1170
- title="📊 Data Visualizer",
1171
- description="Upload an Excel file and describe the visualization you want"
1172
- )
1173
-
1174
-
1175
- # Mount Gradio to your FastAPI app
1176
- app = gr.mount_gradio_app(app, visualization_interface, path="/gradio")
1177
-
1178
 
1179
 
1180
  # ===== ADD THIS AT THE BOTTOM OF main.py =====
 
296
 
297
 
298
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
 
300
 
301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
 
 
 
 
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
 
310
 
 
509
 
510
 
511
 
512
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
 
515
  # ===== ADD THIS AT THE BOTTOM OF main.py =====