from fastapi import FastAPI, Request
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    AutoTokenizer,
    AutoModelForSeq2SeqLM
)
import torch

from chunking import get_max_word_length, chunk_text

app = FastAPI()

MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "mkorada/opus-mt-en-is-finetuned-v4",
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
    "lb": "alirezamsh/small100",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "me": "Helsinki-NLP/opus-mt-tc-base-en-sh",
    "mk": "Helsinki-NLP/opus-mt-en-mk",
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "Confused404/eng-gmq-finetuned_v2-no",
    "pl": "Helsinki-NLP/opus-mt-en-sla",
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
}
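
# Added helper maps (not in the original file). mBART-50 is a many-to-many
# model: generation must be told which language to emit via
# forced_bos_token_id, otherwise the output language is effectively undefined.
# The codes below follow the facebook/mbart-large-50-many-to-many-mmt
# model card.
MBART_LANG_MAP = {
    "hr": "hr_HR",
    "lv": "lv_LV",
    "nl": "nl_XX",
    "pt": "pt_XX",
    "ro": "ro_RO",
}

# Multilingual Marian checkpoints (en-sla, en-sh) expect a ">>xxx<<" target
# token at the start of the source text. The tags below are assumptions;
# verify them against each model card. "me" (Montenegrin) has no dedicated
# tag on opus-mt-tc-base-en-sh, so the closest Serbo-Croatian variant is
# used here.
MARIAN_TARGET_PREFIX = {
    "pl": ">>pol<< ",       # Helsinki-NLP/opus-mt-en-sla
    "me": ">>srp_Latn<< ",  # Helsinki-NLP/opus-mt-tc-base-en-sh
}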

MODEL_CACHE = {}


def load_model(model_id: str):
    """
    Load & cache:
    - facebook/mbart-* via MBart50TokenizerFast & MBartForConditionalGeneration
    - alirezamsh/small100 via AutoTokenizer & AutoModelForSeq2SeqLM
    - all others via MarianTokenizer & MarianMTModel
    """
    if model_id not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id == "alirezamsh/small100":
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)

        model.to("cpu")
        model.eval()  # inference only: disables dropout
        MODEL_CACHE[model_id] = (tokenizer, model)

    return MODEL_CACHE[model_id]
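
# Example (hypothetical warm-up, not part of the original file): loading a
# model once at startup avoids the first-request latency hit, e.g.:
#
#     tokenizer, model = load_model(MODEL_MAP["de"])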

@app.post("/translate")
async def translate(request: Request):
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")

    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    try:
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)

        tokenizer, model = load_model(model_id)
        full_translation = []

        # Multilingual checkpoints must be told which language to produce.
        generate_kwargs = {"num_beams": 5, "length_penalty": 1.2, "early_stopping": True}
        if model_id.startswith("facebook/mbart"):
            tokenizer.src_lang = "en_XX"
            generate_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[MBART_LANG_MAP[target_lang]]
        elif model_id == "alirezamsh/small100":
            # Per the SMALL-100 model card the target language is set on the
            # tokenizer; this assumes AutoTokenizer resolves to its tokenizer.
            tokenizer.tgt_lang = target_lang
        # Multi-target Marian models need a ">>xxx<<" prefix; "" for the rest.
        prefix = MARIAN_TARGET_PREFIX.get(target_lang, "")

        for chunk in chunks:
            inputs = tokenizer(prefix + chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():  # inference only: skip gradient tracking
                outputs = model.generate(**inputs, **generate_kwargs)
            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))

        return {"translation": " ".join(full_translation)}

    except Exception as e:
        return {"error": f"Translation failed: {e}"}

@app.get("/languages")
def list_languages():
    return {"supported_languages": list(MODEL_MAP.keys())}


@app.get("/health")
def health():
    return {"status": "ok"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860)