SpiceyToad committed on
Commit
184700f
·
1 Parent(s): 24242bc
Files changed (1)
  1. app.py +8 -16
app.py CHANGED
@@ -3,32 +3,24 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import os
 
-# Retrieve the Hugging Face API token from the environment
-HF_API_TOKEN = os.getenv("HF_API_TOKEN")
+HF_API_TOKEN = os.getenv("HF_API_TOKEN")  # Hugging Face API token
 
 app = FastAPI()
 
-# Load the Falcon 7B model and tokenizer
-MODEL_NAME = "SpiceyToad/demo-falc"  # Replace with your Hugging Face repo name
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_API_TOKEN)
+# Load Falcon 7B
+MODEL_NAME = "SpiceyToad/demo-falc"  # Replace with your model
+tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_API_TOKEN)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto", use_auth_token=HF_API_TOKEN
+    MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16, token=HF_API_TOKEN
 )
 
-# Automatically determine if CUDA is available for GPU support
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = model.to(device)
-
 @app.post("/generate")
 async def generate_text(request: Request):
-    # Parse input JSON
     data = await request.json()
     prompt = data.get("prompt", "")
     max_length = data.get("max_length", 50)
 
-    # Tokenize input and generate text
-    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     outputs = model.generate(inputs["input_ids"], max_length=max_length)
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    return {"generated_text": generated_text}
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return {"generated_text": response}