VietCat committed
Commit 4bf4a35 · 1 parent: a0024c9

fix runtime error

Files changed (3)
  1. Dockerfile +17 -12
  2. app.py +56 -18
  3. requirements.txt +4 -3
Dockerfile CHANGED
@@ -1,22 +1,27 @@
+# Use lightweight Python image
 FROM python:3.10-slim
 
-# Install the required system libraries
+# Prevent interactive prompts during package install
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Set working directory
+WORKDIR /app
+
+# Install basic dependencies
 RUN apt-get update && apt-get install -y \
     git \
     && rm -rf /var/lib/apt/lists/*
 
-# Create the working directory
-WORKDIR /app
-
-# Copy the source code and install requirements
+# Copy requirements and install Python packages
 COPY requirements.txt .
-RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir --upgrade pip \
+    && pip install --no-cache-dir -r requirements.txt
 
-COPY . .
+# Copy application code
+COPY app.py .
 
-# Create a cache directory with write permissions
-RUN mkdir -p /tmp/hf-cache && chmod -R 777 /tmp/hf-cache
-ENV TRANSFORMERS_CACHE=/tmp/hf-cache
+# Expose Gradio default port
+EXPOSE 7860
 
-# Run the FastAPI server
-CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]
+# Run the app
+CMD ["python", "app.py"]
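Note: the new image starts app.py directly. The Dockerfile EXPOSEs only the Gradio port 7860, while app.py (below) also binds uvicorn to port 8000. A minimal smoke-test sketch for a locally built image; the image name and docker invocation in the comment are illustrative, not part of this commit:

import requests

# Hypothetical local run of the rebuilt image, e.g.:
#   docker build -t vietcat-demo . && docker run -p 7860:7860 -p 8000:8000 vietcat-demo
# (EXPOSE is documentation only, so the unexposed port 8000 can still be published)

# The Gradio UI should answer once the 1.3B model has finished loading
resp = requests.get("http://localhost:7860/")
print("Gradio UI status:", resp.status_code)  # expect 200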
app.py CHANGED
@@ -1,26 +1,64 @@
-import os
-from fastapi import FastAPI, Request
-from pydantic import BaseModel
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import gradio as gr
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
+from fastapi import FastAPI, Request
+import uvicorn
 
-# Set a writable cache directory
-os.makedirs("/tmp/hf-cache", exist_ok=True)
-os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf-cache"
+from threading import Thread
 
-# Use a public model
-model_name = "VietAI/vit5-base"
+# -------- Load model --------
+model_name = "VietAI/gpt-neo-1.3B-vietnamese-news"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
+model.eval()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+# -------- Inference function --------
+def generate_text(prompt, max_tokens=100, temperature=0.9, top_p=0.95):
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=max_tokens,
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+    return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+# -------- Gradio UI --------
+def launch_gradio():
+    with gr.Blocks() as demo:
+        gr.Markdown("## 🇻🇳 VietAI GPT-Neo 1.3B - Vietnamese text generation")
+        prompt = gr.Textbox(label="Prompt", placeholder="Enter the opening of a text...")
+        max_tokens = gr.Slider(10, 200, value=100, label="Number of tokens to generate")
+        temperature = gr.Slider(0.1, 1.5, value=0.9, label="Temperature")
+        top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p sampling")
+        output = gr.Textbox(label="Result", lines=10)
+        btn = gr.Button("Generate text")
+        btn.click(fn=generate_text, inputs=[prompt, max_tokens, temperature, top_p], outputs=output)
+
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
+
+# -------- FastAPI for REST API --------
 app = FastAPI()
 
-class InputData(BaseModel):
-    input: str
+@app.post("/generate")
+async def generate(request: Request):
+    body = await request.json()
+    prompt = body.get("prompt", "")
+    max_tokens = body.get("max_tokens", 100)
+    temperature = body.get("temperature", 0.9)
+    top_p = body.get("top_p", 0.95)
+    output = generate_text(prompt, max_tokens, temperature, top_p)
+    return {"response": output}
 
-@app.post("/predict")
-async def predict(request: Request, data: InputData):
-    input_ids = tokenizer.encode(data.input, return_tensors="pt", max_length=512, truncation=True)
-    output_ids = model.generate(input_ids, max_length=128, num_beams=4, early_stopping=True)
-    output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
-    return {"output": output}
+# -------- Start Gradio in background --------
+if __name__ == "__main__":
+    # Run Gradio in another thread
+    thread = Thread(target=launch_gradio)
+    thread.start()
+
+    # Start FastAPI
+    uvicorn.run(app, host="0.0.0.0", port=8000)
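The new /generate route reads its JSON body with body.get(...), so every field is optional and the keys mirror the parameters of generate_text. A minimal client sketch against a locally running instance, with the host and port taken from the uvicorn.run call above; the prompt text is illustrative:

import requests

# Keys match the body.get(...) lookups in the /generate handler
payload = {
    "prompt": "Hà Nội hôm nay",  # illustrative Vietnamese prompt
    "max_tokens": 50,
    "temperature": 0.9,
    "top_p": 0.95,
}

resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
print(resp.json()["response"])  # handler returns {"response": <generated text>}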
requirements.txt CHANGED
@@ -1,4 +1,5 @@
-fastapi
+transformers>=4.40.0
+torch
+gradio>=4.26.0
 uvicorn
-transformers==4.40.0
-torch>=1.13.1
+fastapi