VietCat committed on
Commit
6a61342
·
1 Parent(s): 3f4ce15

fix runtime error

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -1,20 +1,22 @@
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
 
5
- # Khởi tạo app FastAPI
6
- app = FastAPI()
7
 
8
- # Load model và tokenizer
9
- model_name = "bmd1905/BARTpho2-ViT5-text2text"
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
12
 
13
- # Định nghĩa schema cho request
 
 
14
  class InputText(BaseModel):
15
  text: str
16
 
17
- # Endpoint POST /generate
18
  @app.post("/generate")
19
  def generate_text(item: InputText):
20
  inputs = tokenizer(item.text, return_tensors="pt", truncation=True, max_length=512)
 
1
+ import os
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ # Fix lỗi ghi cache
7
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
8
 
9
+ # Load model
10
+ model_name = "bmd1905/BARTpho2-ViT5-question-answering" # model này public
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
13
 
14
+ # FastAPI app
15
+ app = FastAPI()
16
+
17
  class InputText(BaseModel):
18
  text: str
19
 
 
20
  @app.post("/generate")
21
  def generate_text(item: InputText):
22
  inputs = tokenizer(item.text, return_tensors="pt", truncation=True, max_length=512)