from fastapi import FastAPI, Request
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
import torch

# Project-local chunking helpers
from chunking import get_max_word_length, chunk_text
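
# The chunking helpers are project-local; this file assumes roughly the
# following contract (the signatures are an assumption, not a documented API):
#   get_max_word_length(langs: list[str]) -> int    # conservative per-chunk word budget
#   chunk_text(text: str, limit: int) -> list[str]  # split text within that budget
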
app = FastAPI()

# Map target languages to Hugging Face model IDs
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",         # Bulgarian
    "cs": "Helsinki-NLP/opus-mt-en-cs",                # Czech
    "da": "Helsinki-NLP/opus-mt-en-da",                # Danish
    "de": "Helsinki-NLP/opus-mt-en-de",                # German
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",         # Greek
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",         # Spanish
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",         # Estonian
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",         # Finnish
    "fr": "Helsinki-NLP/opus-mt-en-fr",                # French
    "hr": "facebook/mbart-large-50-many-to-many-mmt",  # Croatian
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",         # Hungarian
    "is": "mkorada/opus-mt-en-is-finetuned-v4",        # Icelandic (Manas's fine-tuned model)
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",         # Italian
    "lb": "alirezamsh/small100",                       # Luxembourgish (SMaLL-100)
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",         # Lithuanian
    "lv": "facebook/mbart-large-50-many-to-many-mmt",  # Latvian
    "me": "Helsinki-NLP/opus-mt-tc-base-en-sh",        # Montenegrin (multi-target, needs a >>lang<< token)
    "mk": "Helsinki-NLP/opus-mt-en-mk",                # Macedonian
    # "nb": "facebook/mbart-large-50-many-to-many-mmt",  # Norwegian Bokmål
    "nl": "facebook/mbart-large-50-many-to-many-mmt",  # Dutch
    "no": "Confused404/eng-gmq-finetuned_v2-no",       # Norwegian (Alex's fine-tuned model)
    "pl": "Helsinki-NLP/opus-mt-en-sla",               # Polish (multi-target, needs a >>lang<< token)
    "pt": "facebook/mbart-large-50-many-to-many-mmt",  # Portuguese
    "ro": "facebook/mbart-large-50-many-to-many-mmt",  # Romanian
    "sk": "Helsinki-NLP/opus-mt-en-sk",                # Slovak
    "sl": "alirezamsh/small100",                       # Slovene
    "sq": "alirezamsh/small100",                       # Albanian
    "sv": "Helsinki-NLP/opus-mt-en-sv",                # Swedish
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr",         # Turkish
}
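
# Multilingual checkpoints cannot infer the target language from the model ID
# alone, so the /translate endpoint below passes it explicitly. The mBART-50
# codes follow the model card's language-code list; the Opus >>lang<< prefix
# tokens follow OPUS-MT's ISO 639-3 convention and are assumptions worth
# verifying against each model card.
MBART_TARGET_CODES = {
    "hr": "hr_HR",  # Croatian
    "lv": "lv_LV",  # Latvian
    "nl": "nl_XX",  # Dutch
    "pt": "pt_XX",  # Portuguese
    "ro": "ro_RO",  # Romanian
}
OPUS_TARGET_TOKENS = {
    "pl": ">>pol<<",  # opus-mt-en-sla serves several Slavic target languages
    "me": ">>cnr<<",  # opus-mt-tc-base-en-sh serves Serbo-Croatian variants
}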

# Cache loaded models/tokenizers so each checkpoint is only loaded once
MODEL_CACHE = {}

def load_model(model_id: str):
    """
    Load & cache:
      - facebook/mbart-*     via MBart50TokenizerFast & MBartForConditionalGeneration
      - alirezamsh/small100  via AutoTokenizer & AutoModelForSeq2SeqLM
      - all others           via MarianTokenizer & MarianMTModel
    """
    if model_id not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id == "alirezamsh/small100":
            # NOTE: the SMaLL-100 model card recommends its custom SMALL100Tokenizer;
            # AutoTokenizer may fall back to a generic one, so verify tgt_lang handling.
            tokenizer = AutoTokenizer.from_pretrained(model_id)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)
        model.to("cpu")
        MODEL_CACHE[model_id] = (tokenizer, model)
    return MODEL_CACHE[model_id]
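
# Illustrative use of the cache (hypothetical REPL session; output will vary):
#   >>> tok, mdl = load_model(MODEL_MAP["de"])
#   >>> batch = tok("Hello world", return_tensors="pt")
#   >>> tok.decode(mdl.generate(**batch)[0], skip_special_tokens=True)
#   'Hallo Welt'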

@app.post("/translate")
async def translate(request: Request):
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    try:
        # Chunk to a safe length so each piece fits the model's input window
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)
        tokenizer, model = load_model(model_id)

        # Multilingual checkpoints need the target language set explicitly;
        # single-target Marian models need no extra configuration.
        gen_kwargs = {"num_beams": 5, "length_penalty": 1.2, "early_stopping": True}
        if model_id.startswith("facebook/mbart"):
            tokenizer.src_lang = "en_XX"  # source is always English here
            gen_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[MBART_TARGET_CODES[target_lang]]
        elif model_id == "alirezamsh/small100":
            tokenizer.tgt_lang = target_lang  # SMaLL-100 encodes the target on the source side
        elif target_lang in OPUS_TARGET_TOKENS:
            # Multi-target Opus models select the target via a sentence-initial token
            chunks = [f"{OPUS_TARGET_TOKENS[target_lang]} {c}" for c in chunks]

        full_translation = []
        for chunk in chunks:
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.inference_mode():
                outputs = model.generate(**inputs, **gen_kwargs)
            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        return {"translation": " ".join(full_translation)}
    except Exception as e:
        return {"error": f"Translation failed: {e}"}

@app.get("/languages")
def list_languages():
    return {"supported_languages": list(MODEL_MAP.keys())}

@app.get("/health")
def health():
    return {"status": "ok"}

# Uvicorn startup for local testing
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
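
# Smoke test once the server is up (example requests; the port matches the
# uvicorn call above, and "app:app" assumes this file is named app.py):
#   curl -X POST http://localhost:7860/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello world", "target_lang": "de"}'
#   curl http://localhost:7860/languages
#   curl http://localhost:7860/health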