|
from fastapi import FastAPI, Request |
|
from transformers import MarianMTModel, MarianTokenizer |
|
import torch |
|
|
|
app = FastAPI()



# Supported target languages: maps a language code from the request to the
# Helsinki-NLP English->X MarianMT checkpoint that serves it.
MODEL_MAP = {

    "fr": "Helsinki-NLP/opus-mt-en-fr",

    "de": "Helsinki-NLP/opus-mt-en-de"

}



# Process-wide cache of loaded (tokenizer, model) pairs keyed by model id,
# so each checkpoint is downloaded/initialized at most once per process.
MODEL_CACHE = {}
|
|
|
def load_model(model_id):
    """Return the cached (tokenizer, model) pair for *model_id*.

    On the first request for a given id the pair is built (model pinned to
    CPU) and stored in MODEL_CACHE; later calls hit the cache.
    """
    try:
        return MODEL_CACHE[model_id]
    except KeyError:
        pair = (
            MarianTokenizer.from_pretrained(model_id),
            MarianMTModel.from_pretrained(model_id).to("cpu"),
        )
        MODEL_CACHE[model_id] = pair
        return pair
|
|
|
@app.post("/translate")
async def translate(request: Request):
    """Translate English text into a supported target language.

    Expects a JSON body like ``{"text": "...", "target_lang": "fr"}``.
    Returns ``{"translation": "..."}`` on success, or ``{"error": "..."}``
    when the body is malformed, a field is missing, or the language is
    unsupported.
    """
    # A non-JSON body makes request.json() raise (json.JSONDecodeError is a
    # ValueError); surface that as the endpoint's usual error shape instead
    # of an unhandled 500.
    try:
        data = await request.json()
    except ValueError:
        return {"error": "Request body must be valid JSON"}

    text = data.get("text")
    target_lang = data.get("target_lang")

    if not text or not target_lang:
        return {"error": "Missing text or target_lang"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model for '{target_lang}'"}

    tokenizer, model = load_model(model_id)
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)

    # Inference only: disable autograd so generate() doesn't build a graph /
    # track gradients on every request (saves memory and time).
    with torch.no_grad():
        outputs = model.generate(**inputs)

    return {"translation": tokenizer.decode(outputs[0], skip_special_tokens=True)}
|
|
|
|
|
import uvicorn

# Serve the app directly when this file is executed as a script.
# NOTE(review): "app:app" assumes this module is named app.py — confirm the
# filename; port 7860 is presumably chosen for Hugging Face Spaces.
if __name__ == "__main__":

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
|
|