"""FastAPI service that generates synthetic tabular data as CSV.

A user description is expanded with a local GPT-2 pipeline, inserted into a
prompt template, and sent to the hosted Mixtral-8x7B-Instruct model on the
Hugging Face Inference API. The CSV text it returns is parsed, validated
against the requested columns, and concatenated across several calls.
"""
import logging
import os
from io import StringIO

import pandas as pd
import requests
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import HfFolder
from pydantic import BaseModel
from tqdm import tqdm
from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Tokenizer, pipeline

logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # You can specify domains here
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Access the Hugging Face API token from environment variables
hf_token = os.getenv('HF_API_TOKEN')
if not hf_token:
    raise ValueError("Hugging Face API token is not set. Please set the HF_API_TOKEN environment variable.")

# Load GPT-2 model and tokenizer
tokenizer_gpt2 = GPT2Tokenizer.from_pretrained('gpt2')
model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')

# Create a pipeline for text generation using GPT-2
text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)


def preprocess_user_prompt(user_prompt):
    """Expand *user_prompt* with a GPT-2 continuation and return the text.

    The text-generation pipeline returns the prompt followed by the sampled
    continuation (its default ``return_full_text`` behaviour).
    """
    # NOTE(review): max_length=50 counts the prompt tokens too, so long
    # descriptions receive little or no continuation, and the continuation is
    # unconstrained GPT-2 text that may drift off-topic — confirm it is wanted.
    generated_text = text_generator(user_prompt, max_length=50, num_return_sequences=1)[0]["generated_text"]
    return generated_text


# Define prompt template
prompt_template = """\
You are an expert in generating synthetic data for machine learning models.

Your task is to generate a synthetic tabular dataset based on the description provided below.

Description: {description}

The dataset should include the following columns: {columns}

Please provide the data in CSV format.

Example Description:
Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'

Example Output:
Size,Location,Number of Bedrooms,Price
1200,Suburban,3,250000
900,Urban,2,200000
1500,Rural,4,300000
...

Description:
{description}
Columns:
{columns}
Output:
"""

# NOTE(review): this tokenizer is not referenced anywhere in this file, yet
# loading it needs network access to a gated model at import time. Consider
# removing it if nothing else imports it.
tokenizer_mixtral = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)


def format_prompt(description, columns):
    """Build the Mixtral prompt from a description and a list of column names."""
    processed_description = preprocess_user_prompt(description)
    prompt = prompt_template.format(description=processed_description, columns=",".join(columns))
    return prompt


API_URL = "https://api-inference.huggingface.co/models/mistralai/Mixtral-8x7B-Instruct-v0.1"

# Model generation parameters (sent as the request's "parameters" object).
generation_params = {
    "top_p": 0.90,
    "temperature": 0.8,
    "max_new_tokens": 512,
    "return_full_text": False,
}

# Inference API request options. use_cache belongs here, not in "parameters".
# With use_cache=False, repeated identical prompts are not served from the API
# cache. wait_for_model=True waits for a cold model instead of failing at once.
generation_options = {
    "use_cache": False,
    "wait_for_model": True,
}

# Upper bound on a single API call so a stalled request cannot hang the endpoint.
REQUEST_TIMEOUT_SECONDS = 120


def generate_synthetic_data(description, columns):
    """Request one batch of CSV text from the hosted model.

    Returns the model's generated text, or "" when the HTTP request fails,
    the body is not JSON, or the body is not a non-empty list whose first
    element has a string ``generated_text`` field.
    """
    formatted_prompt = format_prompt(description, columns)
    payload = {
        "inputs": formatted_prompt,
        "parameters": generation_params,
        "options": generation_options,
    }
    try:
        response = requests.post(
            API_URL,
            headers={"Authorization": f"Bearer {hf_token}"},
            json=payload,
            timeout=REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        data = response.json()  # raises a ValueError subclass on non-JSON bodies
    except (requests.RequestException, ValueError) as e:
        logger.error("Error during API request or response processing: %s", e)
        return ""

    # Success bodies are a list of generations. Error bodies are dicts such as
    # {"error": ...}, so check the shape before indexing.
    if (
        isinstance(data, list)
        and data
        and isinstance(data[0], dict)
        and isinstance(data[0].get("generated_text"), str)
    ):
        return data[0]["generated_text"]
    logger.error("Invalid response format from Hugging Face API: %.200r", data)
    return ""


def _clean_csv_text(csv_data):
    """Normalise line endings and drop lines that are not CSV content.

    Removes blank lines, Markdown code-fence lines (```...), and a bare "..."
    line like the one the prompt's example ends with.
    """
    kept_lines = []
    for line in csv_data.replace('\r\n', '\n').replace('\r', '\n').split('\n'):
        stripped = line.strip()
        if not stripped or stripped.startswith("```") or stripped == "...":
            continue
        kept_lines.append(line)
    return "\n".join(kept_lines)


def process_generated_data(csv_data, expected_columns):
    """Parse generated CSV text into a DataFrame with the expected columns.

    Returns a DataFrame whose columns are ``expected_columns`` in the requested
    order, with fully empty rows dropped. Returns None when the text cannot be
    parsed or its header does not match ``expected_columns`` as a set.
    """
    cleaned_data = _clean_csv_text(csv_data)
    try:
        df = pd.read_csv(StringIO(cleaned_data), delimiter=',')
    except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
        # EmptyDataError: nothing left to parse (e.g. whitespace-only output).
        logger.warning("Failed to parse CSV data: %s", e)
        return None

    # Model output often has spaces after commas in the header ("a, b");
    # strip them so the header comparison is not whitespace-sensitive.
    df.columns = [str(col).strip() for col in df.columns]

    # Check if the DataFrame has the expected columns
    if set(df.columns) != set(expected_columns):
        logger.warning("Unexpected columns in the generated data: %s", list(df.columns))
        return None

    ordered_columns = list(dict.fromkeys(expected_columns))  # de-duplicated, order kept
    return df.dropna(how="all")[ordered_columns]


def generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100):
    """Call the model repeatedly and concatenate the valid CSV batches.

    ``rows_per_generation`` only sets the number of API calls
    (ceil(num_rows / rows_per_generation)). The prompt does not ask for a row
    count, so each call returns however many rows the model emits. The result
    is capped at ``num_rows`` rows and may contain fewer.

    Raises:
        ValueError: if ``rows_per_generation`` is not positive.

    Returns an empty DataFrame with ``columns`` when no batch was valid.
    """
    if rows_per_generation <= 0:
        raise ValueError("rows_per_generation must be a positive integer.")

    # Ceiling division, so num_rows < rows_per_generation still makes one call.
    num_generations = max(0, -(-num_rows // rows_per_generation))

    data_frames = []
    for _ in tqdm(range(num_generations), desc="Generating Data"):
        generated_data = generate_synthetic_data(description, columns)
        if not generated_data:
            logger.warning("Skipping empty or invalid generation.")
            continue
        df_synthetic = process_generated_data(generated_data, columns)
        if df_synthetic is None or df_synthetic.empty:
            logger.warning("Skipping invalid generation.")
            continue
        data_frames.append(df_synthetic)

    if not data_frames:
        logger.warning("No valid data frames to concatenate.")
        return pd.DataFrame(columns=columns)
    return pd.concat(data_frames, ignore_index=True).head(max(num_rows, 0))


class DataGenerationRequest(BaseModel):
    """Request body for ``POST /generate/``."""

    description: str  # free-text description of the desired dataset
    columns: list[str]  # column names the generated CSV must contain


@app.post("/generate/")
def generate_data(request: DataGenerationRequest):
    """Generate a synthetic dataset and return it as a downloadable CSV file.

    Responds with 422 when the description is blank or any column name is
    blank, and with 500 ``{"error": ...}`` when no valid data was generated.
    """
    description = request.description.strip()
    columns = [col.strip() for col in request.columns]
    if not description:
        raise HTTPException(status_code=422, detail="Description must not be empty.")
    if not columns or not all(columns):
        raise HTTPException(status_code=422, detail="Columns must be a non-empty list of non-empty names.")

    csv_data = generate_large_synthetic_data(description, columns, num_rows=1000, rows_per_generation=100)

    if csv_data.empty:
        return JSONResponse(content={"error": "No valid data generated"}, status_code=500)

    # Convert the DataFrame to CSV format
    csv_buffer = StringIO()
    csv_data.to_csv(csv_buffer, index=False)
    csv_buffer.seek(0)

    # Return the CSV data as a downloadable file
    return StreamingResponse(
        csv_buffer,
        media_type="text/csv",
        headers={"Content-Disposition": "attachment; filename=generated_data.csv"}
    )


@app.get("/")
def greet_json():
    """Simple liveness endpoint."""
    return {"Hello": "World!"}