Correct issue with small100
Browse files
app.py
CHANGED
@@ -3,7 +3,9 @@ from transformers import (
|
|
3 |
MarianMTModel,
|
4 |
MarianTokenizer,
|
5 |
MBartForConditionalGeneration,
|
6 |
-
MBart50TokenizerFast
|
|
|
|
|
7 |
)
|
8 |
import torch
|
9 |
|
@@ -25,16 +27,16 @@ MODEL_MAP = {
|
|
25 |
"fr": "Helsinki-NLP/opus-mt-en-fr",
|
26 |
"hr": "facebook/mbart-large-50-many-to-many-mmt",
|
27 |
"hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
|
28 |
-
"is": "mkorada/opus-mt-en-is-finetuned-v4",
|
29 |
"it": "Helsinki-NLP/opus-mt-tc-big-en-it",
|
30 |
-
"lb": "alirezamsh/small100",
|
31 |
"lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
|
32 |
"lv": "facebook/mbart-large-50-many-to-many-mmt",
|
33 |
-
"me": "Helsinki-NLP/opus-mt-tc-base-en-sh"
|
34 |
"mk": "Helsinki-NLP/opus-mt-en-mk",
|
35 |
-
"nb": "facebook/mbart-large-50-many-to-many-mmt",
|
36 |
"nl": "facebook/mbart-large-50-many-to-many-mmt",
|
37 |
-
"no": "Confused404/eng-gmq-finetuned_v2-no", #Alex's fine-tuned model
|
38 |
"pl": "Helsinki-NLP/opus-mt-en-sla",
|
39 |
"pt": "facebook/mbart-large-50-many-to-many-mmt",
|
40 |
"ro": "facebook/mbart-large-50-many-to-many-mmt",
|
@@ -45,29 +47,32 @@ MODEL_MAP = {
|
|
45 |
"tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
|
46 |
}
|
47 |
|
48 |
-
|
49 |
MODEL_CACHE = {}
|
50 |
|
51 |
-
# ✅ Load Hugging Face model (Helsinki or Small100)
|
52 |
def load_model(model_id: str):
|
53 |
"""
|
54 |
-
Load & cache
|
55 |
-
-
|
56 |
-
-
|
|
|
57 |
"""
|
58 |
if model_id not in MODEL_CACHE:
|
59 |
if model_id.startswith("facebook/mbart"):
|
60 |
tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
|
61 |
model = MBartForConditionalGeneration.from_pretrained(model_id)
|
|
|
|
|
|
|
62 |
else:
|
63 |
tokenizer = MarianTokenizer.from_pretrained(model_id)
|
64 |
model = MarianMTModel.from_pretrained(model_id)
|
|
|
65 |
model.to("cpu")
|
66 |
MODEL_CACHE[model_id] = (tokenizer, model)
|
67 |
-
return MODEL_CACHE[model_id]
|
68 |
|
|
|
69 |
|
70 |
-
# ✅ POST /translate
|
71 |
@app.post("/translate")
|
72 |
async def translate(request: Request):
|
73 |
payload = await request.json()
|
@@ -100,18 +105,15 @@ async def translate(request: Request):
|
|
100 |
except Exception as e:
|
101 |
return {"error": f"Translation failed: {e}"}
|
102 |
|
103 |
-
|
104 |
-
# ✅ GET /languages
|
105 |
@app.get("/languages")
|
106 |
def list_languages():
|
107 |
return {"supported_languages": list(MODEL_MAP.keys())}
|
108 |
|
109 |
-
# ✅ GET /health
|
110 |
@app.get("/health")
|
111 |
def health():
|
112 |
return {"status": "ok"}
|
113 |
|
114 |
-
#
|
115 |
-
import uvicorn
|
116 |
if __name__ == "__main__":
|
117 |
-
uvicorn
|
|
|
|
3 |
MarianMTModel,
|
4 |
MarianTokenizer,
|
5 |
MBartForConditionalGeneration,
|
6 |
+
MBart50TokenizerFast,
|
7 |
+
AutoTokenizer,
|
8 |
+
AutoModelForSeq2SeqLM
|
9 |
)
|
10 |
import torch
|
11 |
|
|
|
27 |
"fr": "Helsinki-NLP/opus-mt-en-fr",
|
28 |
"hr": "facebook/mbart-large-50-many-to-many-mmt",
|
29 |
"hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
|
30 |
+
"is": "mkorada/opus-mt-en-is-finetuned-v4", # Manas's fine-tuned model
|
31 |
"it": "Helsinki-NLP/opus-mt-tc-big-en-it",
|
32 |
+
"lb": "alirezamsh/small100", # small100
|
33 |
"lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
|
34 |
"lv": "facebook/mbart-large-50-many-to-many-mmt",
|
35 |
+
"me": "Helsinki-NLP/opus-mt-tc-base-en-sh",
|
36 |
"mk": "Helsinki-NLP/opus-mt-en-mk",
|
37 |
+
"nb": "facebook/mbart-large-50-many-to-many-mmt",
|
38 |
"nl": "facebook/mbart-large-50-many-to-many-mmt",
|
39 |
+
"no": "Confused404/eng-gmq-finetuned_v2-no", # Alex's fine-tuned model
|
40 |
"pl": "Helsinki-NLP/opus-mt-en-sla",
|
41 |
"pt": "facebook/mbart-large-50-many-to-many-mmt",
|
42 |
"ro": "facebook/mbart-large-50-many-to-many-mmt",
|
|
|
47 |
"tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
|
48 |
}
|
49 |
|
50 |
+
# Cache loaded models/tokenizers keyed by model_id so each Hugging Face
# model is downloaded/instantiated at most once per process.
MODEL_CACHE = {}
|
52 |
|
|
|
53 |
def load_model(model_id: str):
    """
    Load a translation model and its tokenizer, caching the pair.

    Dispatch by model family:
    - facebook/mbart-*        -> MBart50TokenizerFast & MBartForConditionalGeneration
    - alirezamsh/small100     -> AutoTokenizer & AutoModelForSeq2SeqLM
    - all others (Helsinki et al.) -> MarianTokenizer & MarianMTModel

    Args:
        model_id: Hugging Face hub id, e.g. "Helsinki-NLP/opus-mt-en-fr".

    Returns:
        (tokenizer, model) tuple from MODEL_CACHE; the model is on CPU.
    """
    if model_id not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id == "alirezamsh/small100":
            # small100 ships a custom tokenizer, so the Auto* classes are
            # required here — MarianTokenizer cannot load it.
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)

        # Force CPU inference (no GPU assumed in the serving environment).
        model.to("cpu")
        MODEL_CACHE[model_id] = (tokenizer, model)

    return MODEL_CACHE[model_id]
|
75 |
|
|
|
76 |
@app.post("/translate")
|
77 |
async def translate(request: Request):
|
78 |
payload = await request.json()
|
|
|
105 |
except Exception as e:
|
106 |
return {"error": f"Translation failed: {e}"}
|
107 |
|
|
|
|
|
108 |
@app.get("/languages")
def list_languages():
    """GET /languages — return the language codes this service can translate to."""
    return {"supported_languages": list(MODEL_MAP.keys())}
|
111 |
|
|
|
112 |
@app.get("/health")
def health():
    """GET /health — liveness probe; always reports ok when the app is up."""
    return {"status": "ok"}
|
115 |
|
116 |
+
# Uvicorn startup for local testing; in production the server is expected
# to be launched externally (e.g. `uvicorn app:app`).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
|