from fastapi import FastAPI, Request
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    AutoModelForSeq2SeqLM,
)

from tokenization_small100 import SMALL100Tokenizer

# import your chunking helpers
from chunking import get_max_word_length, chunk_text

app = FastAPI()

# Map target languages to Hugging Face model IDs
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "mkorada/opus-mt-en-is-finetuned-v4",
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
    "lb": "alirezamsh/small100",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "cnr": "Helsinki-NLP/opus-mt-tc-base-en-sh",
    "mk": "Helsinki-NLP/opus-mt-en-mk",
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "Confused404/eng-gmq-finetuned_v2-no",
    "pl": "Helsinki-NLP/opus-mt-en-sla",
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr",
}

# Cache loaded models/tokenizers, keyed by model ID
MODEL_CACHE = {}


def load_model(model_id: str, target_lang: str):
    """
    Load & cache:
      - facebook/mbart-*    via MBart50TokenizerFast & MBartForConditionalGeneration
      - alirezamsh/small100 via SMALL100Tokenizer & AutoModelForSeq2SeqLM
      - all others          via MarianTokenizer & MarianMTModel
    """
    # Always reload small100 so we can pass a new tgt_lang
    if model_id not in MODEL_CACHE or model_id == "alirezamsh/small100":
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            # mBART: always translate FROM English
            tokenizer.src_lang = "en_XX"
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id == "alirezamsh/small100":
            tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)
        model.to("cpu")
        MODEL_CACHE[model_id] = (tokenizer, model)
    return MODEL_CACHE[model_id]


@app.post("/translate")
async def translate(request: Request):
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    try:
        # Split long input into chunks the model can handle
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)

        tokenizer, model = load_model(model_id, target_lang)
        is_mbart = model_id.startswith("facebook/mbart")

        full_translation = []
        for chunk in chunks:
            # Multi-target OPUS-MT models need a sentence-initial
            # target-language token (no/cnr/pl)
            if model_id == "Confused404/eng-gmq-finetuned_v2-no":
                chunk = f">>nob<< {chunk}"
            if model_id == "Helsinki-NLP/opus-mt-tc-base-en-sh":
                chunk = f">>cnr<< {chunk}"
            if model_id == "Helsinki-NLP/opus-mt-en-sla":
                # en-sla covers several Slavic targets; Polish is >>pol<<
                chunk = f">>pol<< {chunk}"

            # Tokenize and move tensors to the model's device
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            # Generate
            if is_mbart:
                # Build the mBART-50 code, e.g. "de_DE"; nl/pt use the "_XX" suffix
                lang_code = f"{target_lang}_{target_lang.upper()}"
                if target_lang in ("nl", "pt"):
                    lang_code = f"{target_lang}_XX"
                bos_id = tokenizer.lang_code_to_id[lang_code]
                outputs = model.generate(
                    **inputs,
                    forced_bos_token_id=bos_id,
                    num_beams=5,
                    length_penalty=1.2,
                    early_stopping=True,
                )
            else:
                outputs = model.generate(
                    **inputs,
                    num_beams=5,
                    length_penalty=1.2,
                    early_stopping=True,
                )

            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))

        return {"translation": " ".join(full_translation)}
    except Exception as e:
        return {"error": f"Translation failed: {e}"}


@app.get("/languages")
def list_languages():
    return {"supported_languages": list(MODEL_MAP.keys())}


@app.get("/health")
def health():
    return {"status": "ok"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
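
# Example request against the running service (an illustrative sketch: it
# assumes the server was started as above on localhost:7860, and uses the
# third-party `requests` library, which is not a dependency of this app):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/translate",
#       json={"text": "Hello, world!", "target_lang": "de"},
#   )
#   print(resp.json())  # e.g. {"translation": "Hallo, Welt!"} (output may vary)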