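# Hugging Face file-viewer header (page residue, not part of the program):
#   author: Confused404
#   commit message: added a prepend >>nob<< for Norwegian right before tokenization
#   commit: a40fe7a (verified)
#   viewer links: raw / history / blame
#   file size: 4.85 kB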
from fastapi import FastAPI, Request
from transformers import (
MarianMTModel,
MarianTokenizer,
MBartForConditionalGeneration,
MBart50TokenizerFast,
AutoTokenizer,
AutoModelForSeq2SeqLM
)
import torch
from tokenization_small100 import SMALL100Tokenizer
# import your chunking helpers
from chunking import get_max_word_length, chunk_text
# ASGI application object; the route decorators in this module register
# their handlers on it, and the __main__ guard serves it as "app:app".
app = FastAPI()
# Checkpoints that serve more than one target language are named once here
# and reused in the table below.
_MBART50_MANY_TO_MANY = "facebook/mbart-large-50-many-to-many-mmt"
_SMALL100 = "alirezamsh/small100"

# Target-language code (as sent by API clients) -> Hugging Face model ID.
# Insertion order matters: it is the order reported by the /languages route.
# NOTE: an "nb" (Norwegian) entry pointing at the mBART-50 checkpoint used to
# sit between "mk" and "nl"; it is disabled and Norwegian is served as "no".
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",    # Bulgarian
    "cs": "Helsinki-NLP/opus-mt-en-cs",           # Czech
    "da": "Helsinki-NLP/opus-mt-en-da",           # Danish
    "de": "Helsinki-NLP/opus-mt-en-de",           # German
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",    # Greek
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",    # Spanish
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",    # Estonian
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",    # Finnish
    "fr": "Helsinki-NLP/opus-mt-en-fr",           # French
    "hr": _MBART50_MANY_TO_MANY,                  # Croatian
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",    # Hungarian
    "is": "mkorada/opus-mt-en-is-finetuned-v4",   # Icelandic (Manas's fine-tuned model)
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",    # Italian
    "lb": _SMALL100,                              # Luxembourgish
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",    # Lithuanian
    "lv": _MBART50_MANY_TO_MANY,                  # Latvian
    "me": "Helsinki-NLP/opus-mt-tc-base-en-sh",   # Montenegrin
    "mk": "Helsinki-NLP/opus-mt-en-mk",           # Macedonian
    "nl": _MBART50_MANY_TO_MANY,                  # Dutch
    "no": "Confused404/eng-gmq-finetuned_v2-no",  # Norwegian (Alex's fine-tuned model)
    "pl": "Helsinki-NLP/opus-mt-en-sla",          # Polish
    "pt": _MBART50_MANY_TO_MANY,                  # Portuguese
    "ro": _MBART50_MANY_TO_MANY,                  # Romanian
    "sk": "Helsinki-NLP/opus-mt-en-sk",           # Slovak
    "sl": _SMALL100,                              # Slovene
    "sq": _SMALL100,                              # Albanian
    "sv": "Helsinki-NLP/opus-mt-en-sv",           # Swedish
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr",    # Turkish
}
# Cache of loaded (tokenizer, model) pairs, keyed by Hugging Face model ID.
MODEL_CACHE = {}

_SMALL100_ID = "alirezamsh/small100"
# SMALL100 tokenizers are constructed for one target language, so they are
# cached per language while the model weights are loaded once and shared.
_SMALL100_TOKENIZERS = {}


def load_model(model_id: str, target_lang: str):
    """Return the ``(tokenizer, model)`` pair for *model_id*, loading it once.

    Loader selection:
      - ``facebook/mbart*``     -> MBart50TokenizerFast / MBartForConditionalGeneration
      - ``alirezamsh/small100`` -> SMALL100Tokenizer / AutoModelForSeq2SeqLM
      - anything else           -> MarianTokenizer / MarianMTModel

    Args:
        model_id: Hugging Face model ID to load.
        target_lang: Target-language code. Only used for SMALL100, whose
            tokenizer is created with ``tgt_lang=target_lang``.

    Returns:
        The ``(tokenizer, model)`` tuple, which is also stored in
        ``MODEL_CACHE[model_id]``. Models are moved to the CPU when loaded.
    """
    cached = MODEL_CACHE.get(model_id)

    if model_id == _SMALL100_ID:
        # Previously every SMALL100 request re-read the whole checkpoint just
        # to get a tokenizer for the requested language. Only the tokenizer
        # depends on the language, so reuse the weights.
        tokenizer = _SMALL100_TOKENIZERS.get(target_lang)
        if tokenizer is None:
            tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
            _SMALL100_TOKENIZERS[target_lang] = tokenizer
        if cached is not None:
            model = cached[1]
        else:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
            model.to("cpu")
        # Keep the cache entry in its usual shape, holding the tokenizer for
        # the most recently requested language.
        MODEL_CACHE[model_id] = (tokenizer, model)
        return MODEL_CACHE[model_id]

    if cached is None:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)
        model.to("cpu")
        MODEL_CACHE[model_id] = (tokenizer, model)

    return MODEL_CACHE[model_id]
# Sentence-initial target-language tokens required by multi-target Marian
# checkpoints, keyed by (model ID, API target language).
_MARIAN_TARGET_TOKENS = {
    ("Confused404/eng-gmq-finetuned_v2-no", "no"): ">>nob<<",
    ("Helsinki-NLP/opus-mt-en-sla", "pl"): ">>pol<<",
}
# NOTE(review): "Helsinki-NLP/opus-mt-tc-base-en-sh" (served as "me") looks
# like a multi-target checkpoint too and presumably needs a >>id<< token;
# confirm the correct label on its model card before adding it here.

# mBART-50 language codes. The many-to-many checkpoint only produces the
# requested language when generation is forced to start with its code.
_MBART50_SOURCE_CODE = "en_XX"  # assumes English input, like the en->X Marian models
_MBART50_TARGET_CODES = {
    "hr": "hr_HR",
    "lv": "lv_LV",
    "nl": "nl_XX",
    "pt": "pt_XX",
    "ro": "ro_RO",
}


@app.post("/translate")
async def translate(request: Request):
    """Translate the posted text into the requested target language.

    Expects a JSON object with string fields ``text`` and ``target_lang``.
    The text is split with ``chunk_text`` (limit from
    ``get_max_word_length``), each chunk is translated separately, and the
    results are joined with single spaces.

    Returns:
        ``{"translation": <str>}`` on success, otherwise ``{"error": <str>}``.
        Errors are reported in the response body; this handler does not set
        an HTTP error status itself.
    """
    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both derive from ValueError.
        return {"error": "Request body must be valid JSON"}
    if not isinstance(payload, dict):
        return {"error": "Request body must be a JSON object"}

    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}
    if not isinstance(text, str) or not isinstance(target_lang, str):
        # A non-string target_lang (e.g. a list) would otherwise raise on the
        # MODEL_MAP lookup below, outside the error-reporting try block.
        return {"error": "'text' and 'target_lang' must be strings"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    is_mbart = model_id.startswith("facebook/mbart")
    mbart_target_code = _MBART50_TARGET_CODES.get(target_lang) if is_mbart else None
    if is_mbart and mbart_target_code is None:
        return {"error": f"No mBART-50 language code configured for '{target_lang}'"}

    generate_kwargs = {"num_beams": 5, "length_penalty": 1.2, "early_stopping": True}
    try:
        # chunk to safe length
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)
        tokenizer, model = load_model(model_id, target_lang)

        if is_mbart:
            # Without forced_bos_token_id mBART-50 is not steered to the
            # requested output language.
            tokenizer.src_lang = _MBART50_SOURCE_CODE
            generate_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(
                mbart_target_code
            )

        target_token = _MARIAN_TARGET_TOKENS.get((model_id, target_lang))

        # NOTE(review): tokenization and generation run synchronously inside
        # this async handler, so the event loop is blocked while a request
        # is being translated.
        translated_chunks = []
        for chunk in chunks:
            if target_token:
                chunk = f"{target_token} {chunk}"
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            outputs = model.generate(**inputs, **generate_kwargs)
            translated_chunks.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        return {"translation": " ".join(translated_chunks)}
    except Exception as e:
        # Top-level boundary: report any failure to the client in the body.
        return {"error": f"Translation failed: {e}"}
@app.get("/languages")
def list_languages():
    """Return the target-language codes that have a model in MODEL_MAP."""
    # Iterating the dict yields its keys in definition order.
    language_codes = list(MODEL_MAP)
    return {"supported_languages": language_codes}
@app.get("/health")
def health():
    """Liveness probe: always answers with a fixed "ok" status."""
    status_report = {"status": "ok"}
    return status_report
# Local testing entry point: serve this module's `app` with Uvicorn when the
# file is executed directly.
if __name__ == "__main__":
    import uvicorn

    bind_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run("app:app", **bind_options)