Update app.py
app.py CHANGED
@@ -4,13 +4,41 @@ import torch
 
 app = FastAPI()
 
+# Map target languages to Hugging Face model IDs
 MODEL_MAP = {
+    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
+    "cs": "Helsinki-NLP/opus-mt-en-cs",
+    "da": "Helsinki-NLP/opus-mt-en-da",
+    "de": "Helsinki-NLP/opus-mt-en-de",
+    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
+    "es": "facebook/nllb-200-distilled-600M",
+    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
+    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
     "fr": "Helsinki-NLP/opus-mt-en-fr",
-    "
+    "hr": "facebook/mbart-large-50-many-to-many-mmt",
+    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
+    "is": "facebook/nllb-200-distilled-600M",
+    "it": "facebook/nllb-200-distilled-600M",
+    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
+    "lv": "facebook/mbart-large-50-many-to-many-mmt",
+    "mk": "facebook/nllb-200-distilled-600M",
+    "nb": "facebook/mbart-large-50-many-to-many-mmt",  # placeholder
+    "nl": "facebook/mbart-large-50-many-to-many-mmt",
+    "no": "facebook/mbart-large-50-many-to-many-mmt",  # placeholder
+    "pl": "facebook/nllb-200-distilled-600M",
+    "pt": "facebook/mbart-large-50-many-to-many-mmt",
+    "ro": "facebook/mbart-large-50-many-to-many-mmt",
+    "sk": "Helsinki-NLP/opus-mt-en-sk",
+    "sl": "alirezamsh/small100",
+    "sq": "alirezamsh/small100",
+    "sv": "Helsinki-NLP/opus-mt-en-sv",
+    "tr": "facebook/nllb-200-distilled-600M"
 }
 
+
 MODEL_CACHE = {}
 
+# ✅ Load Hugging Face model (Helsinki or Small100)
 def load_model(model_id):
     if model_id not in MODEL_CACHE:
         tokenizer = MarianTokenizer.from_pretrained(model_id)
@@ -18,6 +46,7 @@ def load_model(model_id):
         MODEL_CACHE[model_id] = (tokenizer, model)
     return MODEL_CACHE[model_id]
 
+# ✅ POST /translate
 @app.post("/translate")
 async def translate(request: Request):
     data = await request.json()
@@ -25,18 +54,34 @@ async def translate(request: Request):
     target_lang = data.get("target_lang")
 
     if not text or not target_lang:
-        return {"error": "Missing text or target_lang"}
+        return {"error": "Missing 'text' or 'target_lang'"}
 
     model_id = MODEL_MAP.get(target_lang)
     if not model_id:
-        return {"error": f"No model for '{target_lang}'"}
+        return {"error": f"No model found for target language '{target_lang}'"}
+
+    if model_id.startswith("facebook/"):
+        return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}
+
+    try:
+        tokenizer, model = load_model(model_id)
+        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
+        outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
+        return {"translation": tokenizer.decode(outputs[0], skip_special_tokens=True)}
+    except Exception as e:
+        return {"error": f"Translation failed: {str(e)}"}
+
+# ✅ GET /languages
+@app.get("/languages")
+def list_languages():
+    return {"supported_languages": list(MODEL_MAP.keys())}
 
-
-
-
-    return {"
+# ✅ GET /health
+@app.get("/health")
+def health():
+    return {"status": "ok"}
 
-    #
+# ✅ Uvicorn startup (required by Hugging Face)
 import uvicorn
 if __name__ == "__main__":
-    uvicorn.run("app:app", host="0.0.0.0", port=7860)
+    uvicorn.run("app:app", host="0.0.0.0", port=7860)
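
For reference, a minimal client sketch against the endpoints added above, assuming the app is running locally on the port passed to uvicorn.run (7860). The endpoint paths and the "text", "target_lang", "translation", and "error" fields come from the diff; the base URL and the example sentence are illustrative only.

# Minimal client sketch (illustrative, not part of this commit).
# Assumes the app above is running locally on port 7860.
import requests

BASE_URL = "http://localhost:7860"  # assumption: local uvicorn run

# GET /languages lists the codes currently present in MODEL_MAP
langs = requests.get(f"{BASE_URL}/languages").json()
print(langs["supported_languages"])

# POST /translate with the JSON fields the handler reads: "text" and "target_lang"
resp = requests.post(
    f"{BASE_URL}/translate",
    json={"text": "The weather is nice today.", "target_lang": "fr"},
).json()

# The handler returns either {"translation": ...} or {"error": ...}
print(resp.get("translation") or resp.get("error"))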
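A note on the loader: load_model is Marian-specific (MarianTokenizer, with the model class on the line elided from the hunk), while MODEL_MAP also lists non-Marian checkpoints such as alirezamsh/small100 and the facebook/* entries that the handler currently short-circuits. A hypothetical, more generic loader using the transformers Auto classes is sketched below; it is not what this commit does, it keeps the same MODEL_CACHE scheme, and multilingual checkpoints would additionally need source/target language codes configured on the tokenizer, which this sketch does not handle.

# Hypothetical generic loader (not part of this commit): the Auto classes
# resolve the tokenizer/model classes from the checkpoint's config, so
# non-Marian entries in MODEL_MAP could reuse the same cache.
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def load_model_generic(model_id):
    if model_id not in MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        MODEL_CACHE[model_id] = (tokenizer, model)
    return MODEL_CACHE[model_id]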