SpiceyToad committed
Commit 24242bc · 1 Parent(s): 156851a

Fix API token and device handling

Files changed (1)
  1. app.py +16 -6
app.py CHANGED
@@ -1,13 +1,23 @@
  from fastapi import FastAPI, Request
  from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch
+ import os
+
+ # Retrieve the Hugging Face API token from the environment
+ HF_API_TOKEN = os.getenv("HF_API_TOKEN")

  app = FastAPI()

  # Load the Falcon 7B model and tokenizer
- MODEL_NAME = "SpiceyToad/demo-falc"  # Replace with your Hub repo name
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto")
+ MODEL_NAME = "SpiceyToad/demo-falc"  # Replace with your Hugging Face repo name
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_API_TOKEN)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_NAME, torch_dtype=torch.bfloat16, device_map="auto", use_auth_token=HF_API_TOKEN
+ )
+
+ # Automatically determine if CUDA is available for GPU support
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)

  @app.post("/generate")
  async def generate_text(request: Request):
@@ -17,8 +27,8 @@ async def generate_text(request: Request):
      max_length = data.get("max_length", 50)

      # Tokenize input and generate text
-     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
      outputs = model.generate(inputs["input_ids"], max_length=max_length)
-     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

-     return {"generated_text": response}
+     return {"generated_text": generated_text}
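For reference, a minimal sketch of exercising the updated endpoint: export the token, start the server, and POST a JSON body. The "max_length" field matches the handler's data.get() call shown in the diff; the "prompt" key, the port, and the uvicorn launch command are assumptions, since the line that reads the prompt and the deployment setup fall outside the visible hunks.

    # Assumes the server was started with HF_API_TOKEN set, e.g.:
    #   HF_API_TOKEN=hf_xxx uvicorn app:app --port 8000
    import requests  # third-party HTTP client, not part of app.py

    resp = requests.post(
        "http://localhost:8000/generate",       # hypothetical local address
        json={"prompt": "Once upon a time",     # assumed request key
              "max_length": 50},                # matches data.get("max_length", 50)
    )
    print(resp.json()["generated_text"])

Note that recent transformers releases deprecate the use_auth_token argument in favor of token, so newer environments may need token=HF_API_TOKEN in the from_pretrained() calls instead.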