Spaces:
Running
Running
Rename main.py to app.py
Browse files- main.py → app.py +39 -8
main.py → app.py
RENAMED
@@ -4,6 +4,7 @@ from transformers import pipeline
|
|
4 |
import torch
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from typing import Dict, Any
|
|
|
7 |
|
8 |
# Inisialisasi aplikasi FastAPI
|
9 |
app = FastAPI(
|
@@ -53,13 +54,18 @@ def get_task(model_id: str) -> str:
|
|
53 |
for task, models in TASK_MAP.items():
|
54 |
if model_id in models:
|
55 |
return task
|
56 |
-
|
|
|
57 |
|
58 |
# Event startup untuk inisialisasi model
|
59 |
@app.on_event("startup")
async def load_models():
    """Startup hook: create the empty pipeline registry and log readiness.

    No model is loaded here; the inference endpoint fills the registry the
    first time each model is requested.
    """
    # Pipelines live on the application state, keyed by model id.
    app.state.pipelines = dict()
    print("🟢 Semua model siap digunakan!")
|
|
|
|
|
|
|
|
|
63 |
|
64 |
# Endpoint utama
|
65 |
@app.get("/")
|
@@ -96,11 +102,14 @@ async def health_check():
|
|
96 |
@app.post("/inference/{model_id}")
|
97 |
async def model_inference(model_id: str, request: InferenceRequest):
|
98 |
try:
|
|
|
|
|
|
|
99 |
# Validasi model ID
|
100 |
if model_id not in MODEL_MAP:
|
101 |
raise HTTPException(
|
102 |
status_code=404,
|
103 |
-
detail=f"Model {model_id} tidak ditemukan. Cek /models untuk list model yang tersedia."
|
104 |
)
|
105 |
|
106 |
# Dapatkan task yang sesuai
|
@@ -108,11 +117,18 @@ async def model_inference(model_id: str, request: InferenceRequest):
|
|
108 |
|
109 |
# Load model jika belum ada di memory
|
110 |
if model_id not in app.state.pipelines:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
app.state.pipelines[model_id] = pipeline(
|
112 |
task=task,
|
113 |
model=MODEL_MAP[model_id],
|
114 |
-
device=
|
115 |
-
torch_dtype=
|
116 |
)
|
117 |
print(f"✅ Model {model_id} berhasil dimuat!")
|
118 |
|
@@ -128,6 +144,7 @@ async def model_inference(model_id: str, request: InferenceRequest):
|
|
128 |
)[0]['generated_text']
|
129 |
|
130 |
elif task == "text-classification":
|
|
|
131 |
output = pipe(request.text)[0]
|
132 |
result = {
|
133 |
"label": output['label'],
|
@@ -135,19 +152,33 @@ async def model_inference(model_id: str, request: InferenceRequest):
|
|
135 |
}
|
136 |
|
137 |
elif task == "text2text-generation":
|
|
|
138 |
result = pipe(
|
139 |
request.text,
|
140 |
max_length=request.max_length
|
141 |
)[0]['generated_text']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
return {"result": result}
|
144 |
|
145 |
except Exception as e:
|
|
|
|
|
|
|
|
|
|
|
146 |
raise HTTPException(
|
147 |
status_code=500,
|
148 |
-
detail=f"Error processing request: {str(e)}"
|
149 |
)
|
150 |
|
151 |
-
|
152 |
-
|
153 |
-
|
|
|
|
4 |
import torch
|
5 |
from fastapi.middleware.cors import CORSMiddleware
|
6 |
from typing import Dict, Any
|
7 |
+
import os # Import os module
|
8 |
|
9 |
# Inisialisasi aplikasi FastAPI
|
10 |
app = FastAPI(
|
|
|
54 |
for task, models in TASK_MAP.items():
|
55 |
if model_id in models:
|
56 |
return task
|
57 |
+
# Default to text-generation if not found (or raise an error)
|
58 |
+
return "text-generation"
|
59 |
|
60 |
# Startup event: initialize the (initially empty) model pipeline registry
|
61 |
@app.on_event("startup")
async def load_models():
    """Prepare the Hugging Face cache directory and application state at startup.

    No model is loaded here: the pipeline registry starts empty and the
    inference endpoint fills it the first time each model is requested.

    Raises:
        OSError: if the cache directory cannot be created.
    """
    # Point the Hugging Face cache at a writable location to work around
    # cache permission problems.
    # NOTE(review): `transformers` is imported at module import time, before
    # this hook runs, so the library may already have resolved its cache path.
    # Confirm this override takes effect, or set HF_HOME before that import.
    hf_home = '/tmp/.cache/huggingface'
    os.environ['HF_HOME'] = hf_home
    os.makedirs(hf_home, exist_ok=True)

    # Pipelines live on the application state, keyed by model id.
    app.state.pipelines = {}

    # Announce readiness only after every setup step above has succeeded.
    print("🟢 Semua model siap digunakan!")
|
68 |
+
|
69 |
|
70 |
# Endpoint utama
|
71 |
@app.get("/")
|
|
|
102 |
@app.post("/inference/{model_id}")
|
103 |
async def model_inference(model_id: str, request: InferenceRequest):
|
104 |
try:
|
105 |
+
# Pastikan model_id dalam lowercase agar sesuai dengan MODEL_MAP
|
106 |
+
model_id = model_id.lower()
|
107 |
+
|
108 |
# Validasi model ID
|
109 |
if model_id not in MODEL_MAP:
|
110 |
raise HTTPException(
|
111 |
status_code=404,
|
112 |
+
detail=f"Model '{model_id}' tidak ditemukan. Cek /models untuk list model yang tersedia."
|
113 |
)
|
114 |
|
115 |
# Dapatkan task yang sesuai
|
|
|
117 |
|
118 |
# Load model jika belum ada di memory
|
119 |
if model_id not in app.state.pipelines:
|
120 |
+
print(f"⏳ Memuat model {model_id} untuk task {task}...")
|
121 |
+
# Menggunakan device=-1 (CPU) sebagai default yang aman
|
122 |
+
# Jika Anda yakin Hugging Face Space Anda memiliki GPU, gunakan device=0
|
123 |
+
device_to_use = 0 if torch.cuda.is_available() else -1
|
124 |
+
# Menyesuaikan dtype berdasarkan device
|
125 |
+
dtype_to_use = torch.float16 if torch.cuda.is_available() else torch.float32
|
126 |
+
|
127 |
app.state.pipelines[model_id] = pipeline(
|
128 |
task=task,
|
129 |
model=MODEL_MAP[model_id],
|
130 |
+
device=device_to_use,
|
131 |
+
torch_dtype=dtype_to_use
|
132 |
)
|
133 |
print(f"✅ Model {model_id} berhasil dimuat!")
|
134 |
|
|
|
144 |
)[0]['generated_text']
|
145 |
|
146 |
elif task == "text-classification":
|
147 |
+
# Untuk text-classification, output adalah list of dict, kita ambil yang pertama
|
148 |
output = pipe(request.text)[0]
|
149 |
result = {
|
150 |
"label": output['label'],
|
|
|
152 |
}
|
153 |
|
154 |
elif task == "text2text-generation":
|
155 |
+
# Untuk text2text-generation, output juga list of dict
|
156 |
result = pipe(
|
157 |
request.text,
|
158 |
max_length=request.max_length
|
159 |
)[0]['generated_text']
|
160 |
+
|
161 |
+
else:
|
162 |
+
# Fallback untuk task yang tidak terduga, meski harusnya terhandle oleh get_task
|
163 |
+
raise HTTPException(
|
164 |
+
status_code=500,
|
165 |
+
detail=f"Tugas ({task}) untuk model {model_id} tidak didukung atau tidak dikenali."
|
166 |
+
)
|
167 |
|
168 |
return {"result": result}
|
169 |
|
170 |
except Exception as e:
|
171 |
+
# Log error lebih detail untuk debugging
|
172 |
+
print(f"‼️ Error saat memproses model {model_id}: {e}")
|
173 |
+
import traceback
|
174 |
+
traceback.print_exc() # Mencetak full traceback ke log
|
175 |
+
|
176 |
raise HTTPException(
|
177 |
status_code=500,
|
178 |
+
detail=f"Error processing request: {str(e)}. Cek log server untuk detail."
|
179 |
)
|
180 |
|
181 |
+
# Starting the server directly is unnecessary on Hugging Face Spaces, because Uvicorn launches the app there
|
182 |
+
# if __name__ == "__main__":
|
183 |
+
# import uvicorn
|
184 |
+
# uvicorn.run(app, host="0.0.0.0", port=7860)
|