dat257 commited on
Commit
091451d
·
verified ·
1 Parent(s): 9212344

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -15
app.py CHANGED
@@ -1,26 +1,78 @@
 
1
  import gradio as gr
2
  from transformers import pipeline
 
3
  import torch
4
- import logging
 
 
 
 
 
5
 
6
- logging.basicConfig(level=logging.INFO)
 
 
 
 
 
 
7
 
8
- qa_pipeline = pipeline(
9
- "question-answering",
10
- model="nguyenvulebinh/vi-mrc-base",
11
- device=0 if torch.cuda.is_available() else -1
12
- )
 
 
 
 
 
 
 
13
 
14
- def answer_fn(question, context):
 
15
  result = qa_pipeline(question=question, context=context)
16
  return result["answer"]
17
 
18
- iface = gr.Interface(
19
- fn=answer_fn,
20
- inputs=["text", "text"],
21
- outputs="text",
22
- title="AgriBot: Hỏi đáp nông nghiệp"
23
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
25
  if __name__ == "__main__":
26
- iface.launch()
 
 
 
 
 
 
1
+ import logging
2
  import gradio as gr
3
  from transformers import pipeline
4
+ import os
5
  import torch
6
+ from huggingface_hub import login
7
+ from fastapi import FastAPI, Request
8
+ from fastapi.responses import JSONResponse
9
+
10
+ # Cấu hình logging
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
 
13
+ # Đăng nhập Hugging Face
14
+ try:
15
+ login(token=os.getenv("HF_TOKEN"))
16
+ logging.info("Logged in to Hugging Face Hub successfully")
17
+ except Exception as e:
18
+ logging.error(f"Failed to login to Hugging Face Hub: {e}")
19
+ raise
20
 
21
+ # Load mô hình
22
+ logging.info("Loading nguyenvulebinh/vi-mrc-base...")
23
+ try:
24
+ qa_pipeline = pipeline(
25
+ "question-answering",
26
+ model="nguyenvulebinh/vi-mrc-base",
27
+ device=0 if torch.cuda.is_available() else -1
28
+ )
29
+ logging.info("Model loaded successfully")
30
+ except Exception as e:
31
+ logging.error(f"Failed to load model: {e}")
32
+ raise
33
 
34
+ # Hàm xử lý cho Gradio và API
35
+ def gradio_answer(question, context):
36
  result = qa_pipeline(question=question, context=context)
37
  return result["answer"]
38
 
39
+ # Tạo FastAPI app để thêm endpoint API
40
+ app = FastAPI()
41
+
42
+ @app.post("/api/answer")
43
+ async def api_answer(request: Request):
44
+ try:
45
+ data = await request.json()
46
+ question = data.get("question")
47
+ context = data.get("context")
48
+ logging.info(f"Received request - Question: {question}, Context: {context[:200]}...")
49
+ if not question or not context:
50
+ logging.error("Missing question or context")
51
+ return JSONResponse({"error": "Missing question or context"}, status_code=400)
52
+ result = qa_pipeline(question=question, context=context)
53
+ logging.info(f"Response - Answer: {result['answer']}")
54
+ return JSONResponse({"answer": result["answer"]})
55
+ except Exception as e:
56
+ logging.error(f"API error: {e}")
57
+ return JSONResponse({"error": str(e)}, status_code=500)
58
+
59
+ # Tạo Gradio Blocks
60
+ with gr.Blocks() as demo:
61
+ gr.Markdown("# AgriBot: Hỏi đáp nông nghiệp")
62
+ gr.Markdown("Nhập câu hỏi và ngữ cảnh để nhận câu trả lời về nông nghiệp.")
63
+
64
+ with gr.Row():
65
+ question_input = gr.Textbox(label="Câu hỏi", placeholder="Nhập câu hỏi của bạn...")
66
+ context_input = gr.Textbox(label="Ngữ cảnh", placeholder="Nhập ngữ cảnh liên quan...")
67
+ output = gr.Textbox(label="Câu trả lời")
68
+ submit_btn = gr.Button("Gửi")
69
+ submit_btn.click(fn=gradio_answer, inputs=[question_input, context_input], outputs=output)
70
 
71
+ # Chạy ứng dụng
72
  if __name__ == "__main__":
73
+ logging.info("Starting Gradio on port 7860...")
74
+ demo.launch(
75
+ server_name="0.0.0.0",
76
+ server_port=7860,
77
+ inline=False,
78
+ )