yakine commited on
Commit
c90ec2a
·
verified ·
1 Parent(s): 24605c7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -32
app.py CHANGED
@@ -6,9 +6,7 @@ import os
6
  import torch
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from io import StringIO
9
- from tqdm import tqdm
10
- import accelerate
11
- from accelerate import init_empty_weights, disk_offload
12
  from fastapi.middleware.cors import CORSMiddleware
13
  import re
14
 
@@ -17,7 +15,7 @@ app = FastAPI()
17
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
- allow_origins=["*"], # You can specify domains here
21
  allow_credentials=True,
22
  allow_methods=["*"],
23
  allow_headers=["*"],
@@ -36,36 +34,22 @@ model_gpt2 = GPT2LMHeadModel.from_pretrained('gpt2')
36
  # Create a pipeline for text generation using GPT-2
37
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
38
 
39
- # Load the Llama-3 model and tokenizer once during startup
 
 
 
40
  tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
41
  model_llama = AutoModelForCausalLM.from_pretrained(
42
  "meta-llama/Meta-Llama-3-8B",
43
  torch_dtype='auto',
44
  device_map='auto',
 
 
45
  token=hf_token
46
- )
47
 
48
  # Define your prompt template
49
- prompt_template = """\
50
- You are an expert in generating synthetic data for machine learning models.
51
- Your task is to generate a synthetic tabular dataset based on the description provided below.
52
- Description: {description}
53
- The dataset should include the following columns: {columns}
54
- Please provide the data in CSV format with a minimum of 100 rows per generation.
55
- Ensure that the data is realistic, does not contain any duplicate rows, and follows any specific conditions mentioned.
56
- Example Description:
57
- Generate a dataset for predicting house prices with columns: 'Size', 'Location', 'Number of Bedrooms', 'Price'
58
- Example Output:
59
- Size,Location,Number of Bedrooms,Price
60
- 1200,Suburban,3,250000
61
- 900,Urban,2,200000
62
- 1500,Rural,4,300000
63
- ...
64
- Description:
65
- {description}
66
- Columns:
67
- {columns}
68
- Output: """
69
 
70
  class DataGenerationRequest(BaseModel):
71
  description: str
@@ -113,7 +97,6 @@ def generate_synthetic_data(description, columns):
113
  return f"Error: {e}"
114
 
115
  def clean_generated_text(generated_text):
116
- # Extract CSV part using a regular expression
117
  csv_match = re.search(r'(\n?([A-Za-z0-9_]+,)*[A-Za-z0-9_]+\n([^\n,]*,)*[^\n,]*\n*)+', generated_text)
118
 
119
  if csv_match:
@@ -124,10 +107,8 @@ def clean_generated_text(generated_text):
124
  return csv_text
125
 
126
  def process_generated_data(csv_data):
127
- # Clean the generated data
128
  cleaned_data = clean_generated_text(csv_data)
129
 
130
- # Convert to DataFrame
131
  data = StringIO(cleaned_data)
132
  df = pd.read_csv(data)
133
 
@@ -142,12 +123,9 @@ def generate_data(request: DataGenerationRequest):
142
  if "Error" in generated_data:
143
  return JSONResponse(content={"error": generated_data}, status_code=500)
144
 
145
- # Process the generated CSV data into a DataFrame
146
  df_synthetic = process_generated_data(generated_data)
147
  return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
148
 
149
-
150
-
151
  @app.get("/")
152
  def greet_json():
153
  return {"Hello": "World!"}
 
6
  import torch
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer, AutoTokenizer, AutoModelForCausalLM, pipeline
8
  from io import StringIO
9
+ from accelerate import Accelerator
 
 
10
  from fastapi.middleware.cors import CORSMiddleware
11
  import re
12
 
 
15
 
16
  app.add_middleware(
17
  CORSMiddleware,
18
+ allow_origins=["*"],
19
  allow_credentials=True,
20
  allow_methods=["*"],
21
  allow_headers=["*"],
 
34
  # Create a pipeline for text generation using GPT-2
35
  text_generator = pipeline("text-generation", model=model_gpt2, tokenizer=tokenizer_gpt2)
36
 
37
+ # Initialize accelerator with disk offload
38
+ accelerator = Accelerator(cpu=False, disk_offload=True)
39
+
40
+ # Load the Llama-3 model and tokenizer with disk offload
41
  tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B", token=hf_token)
42
  model_llama = AutoModelForCausalLM.from_pretrained(
43
  "meta-llama/Meta-Llama-3-8B",
44
  torch_dtype='auto',
45
  device_map='auto',
46
+ offload_folder="offload", # Folder to offload weights to disk
47
+ offload_state_dict=True, # Offload state_dict to disk
48
  token=hf_token
49
+ ).to(accelerator.device)
50
 
51
  # Define your prompt template
52
+ prompt_template = """...""" # Your existing prompt template here
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  class DataGenerationRequest(BaseModel):
55
  description: str
 
97
  return f"Error: {e}"
98
 
99
  def clean_generated_text(generated_text):
 
100
  csv_match = re.search(r'(\n?([A-Za-z0-9_]+,)*[A-Za-z0-9_]+\n([^\n,]*,)*[^\n,]*\n*)+', generated_text)
101
 
102
  if csv_match:
 
107
  return csv_text
108
 
109
  def process_generated_data(csv_data):
 
110
  cleaned_data = clean_generated_text(csv_data)
111
 
 
112
  data = StringIO(cleaned_data)
113
  df = pd.read_csv(data)
114
 
 
123
  if "Error" in generated_data:
124
  return JSONResponse(content={"error": generated_data}, status_code=500)
125
 
 
126
  df_synthetic = process_generated_data(generated_data)
127
  return JSONResponse(content={"data": df_synthetic.to_dict(orient="records")})
128
 
 
 
129
  @app.get("/")
130
  def greet_json():
131
  return {"Hello": "World!"}