from fastapi import FastAPI, Request
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    MBartForConditionalGeneration,
    MBart50TokenizerFast,
    AutoModelForSeq2SeqLM,
)
import torch

from tokenization_small100 import SMALL100Tokenizer

# import your chunking helpers
from chunking import get_max_word_length, chunk_text
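# Assumed contract for the chunking helpers (chunking.py is not shown here):
# get_max_word_length(langs) returns a safe per-chunk word budget for the given
# target languages, and chunk_text(text, limit) splits the text accordingly.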
app = FastAPI()

# Map target languages to Hugging Face model IDs
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",         # bulgarian
    "cs": "Helsinki-NLP/opus-mt-en-cs",                # czech
    "da": "Helsinki-NLP/opus-mt-en-da",                # danish
    "de": "Helsinki-NLP/opus-mt-en-de",                # german
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",         # greek
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",         # spanish
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",         # estonian
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",         # finnish
    "fr": "Helsinki-NLP/opus-mt-en-fr",                # french
    "hr": "facebook/mbart-large-50-many-to-many-mmt",  # croatian
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",         # hungarian
    "is": "mkorada/opus-mt-en-is-finetuned-v4",        # icelandic (Manas's fine-tuned model)
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",         # italian
    "lb": "alirezamsh/small100",                       # luxembourgish (SMALL100)
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",         # lithuanian
    "lv": "facebook/mbart-large-50-many-to-many-mmt",  # latvian
    "cnr": "Helsinki-NLP/opus-mt-tc-base-en-sh",       # montenegrin
    "mk": "Helsinki-NLP/opus-mt-en-mk",                # macedonian
    # "nb": "facebook/mbart-large-50-many-to-many-mmt",  # norwegian
    "nl": "facebook/mbart-large-50-many-to-many-mmt",  # dutch
    "no": "Confused404/eng-gmq-finetuned_v2-no",       # norwegian (Alex's fine-tuned model)
    "pl": "Helsinki-NLP/opus-mt-en-sla",               # polish
    "pt": "facebook/mbart-large-50-many-to-many-mmt",  # portuguese
    "ro": "facebook/mbart-large-50-many-to-many-mmt",  # romanian
    "sk": "Helsinki-NLP/opus-mt-en-sk",                # slovak
    "sl": "alirezamsh/small100",                       # slovene
    "sq": "alirezamsh/small100",                       # albanian
    "sv": "Helsinki-NLP/opus-mt-en-sv",                # swedish
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr",         # turkish
}
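
# mBART-50 language codes (from the mbart-large-50 model card), one entry per
# mBART-backed language above; used to force the decoder's output language.
MBART_LANG_CODE = {
    "hr": "hr_HR",  # croatian
    "lv": "lv_LV",  # latvian
    "nl": "nl_XX",  # dutch
    "pt": "pt_XX",  # portuguese
    "ro": "ro_RO",  # romanian
}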

# Cache loaded models/tokenizers
MODEL_CACHE = {}


def load_model(model_id: str, target_lang: str):
    """
    Load & cache:
      - facebook/mbart-*    via MBart50TokenizerFast & MBartForConditionalGeneration
      - alirezamsh/small100 via SMALL100Tokenizer & AutoModelForSeq2SeqLM
      - all others          via MarianTokenizer & MarianMTModel
    """
    # SMALL100 bakes the target language into its tokenizer, so cache it per
    # (model_id, target_lang) instead of reloading the model on every request.
    cache_key = (model_id, target_lang) if model_id == "alirezamsh/small100" else model_id
    if cache_key not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            # The source text is always English in this service.
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id, src_lang="en_XX")
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id == "alirezamsh/small100":
            tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)
        # Prefer a GPU when one is available; otherwise stay on CPU.
        model.to("cuda" if torch.cuda.is_available() else "cpu")
        MODEL_CACHE[cache_key] = (tokenizer, model)
    return MODEL_CACHE[cache_key]
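
# Example: load_model(MODEL_MAP["de"], "de") downloads the German Marian model
# on first use and serves it from MODEL_CACHE on every call after that.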

@app.post("/translate")
async def translate(request: Request):
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}
    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}
    try:
        # chunk to a length the model can translate safely
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)
        tokenizer, model = load_model(model_id, target_lang)
        full_translation = []
        for chunk in chunks:
            # Multi-target Marian models expect an explicit target-language token.
            if model_id == "Confused404/eng-gmq-finetuned_v2-no":
                chunk = f">>nob<< {chunk}"
            if model_id == "Helsinki-NLP/opus-mt-tc-base-en-sh":
                chunk = f">>cnr<< {chunk}"
            if model_id == "Helsinki-NLP/opus-mt-en-sla":
                # en-sla covers several Slavic targets; select Polish explicitly.
                chunk = f">>pol<< {chunk}"
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            gen_kwargs = {"num_beams": 5, "length_penalty": 1.2, "early_stopping": True}
            if model_id.startswith("facebook/mbart"):
                # mBART-50 emits the language whose code is forced as the first
                # generated token; without this it tends to fall back to English.
                gen_kwargs["forced_bos_token_id"] = tokenizer.convert_tokens_to_ids(
                    MBART_LANG_CODE[target_lang]
                )
            outputs = model.generate(**inputs, **gen_kwargs)
            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        return {"translation": " ".join(full_translation)}
    except Exception as e:
        return {"error": f"Translation failed: {e}"}
@app.get("/languages")
def list_languages():
return {"supported_languages": list(MODEL_MAP.keys())}
@app.get("/health")
def health():
return {"status": "ok"}

# Uvicorn startup for local testing (assumes this file is named app.py)
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)