"""FastAPI service exposing a /translate endpoint backed by the M2M100 model."""

import io
import time
from enum import Enum
from functools import lru_cache
from typing import List, Literal

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

app = FastAPI()

# Inference runs on CPU; change here if a GPU is available.
device = torch.device("cpu")


class TranslationRequest(BaseModel):
    """Body of a /translate request."""

    # user_input: text to translate.
    # source_lang / target_lang: M2M100 language codes (e.g. "en", "fr").
    user_input: str
    source_lang: str
    target_lang: str


@lru_cache(maxsize=1)
def load_model(pretrained_model: str = "facebook/m2m100_1.2B", cache_dir: str = "models/"):
    """Load and return the ``(tokenizer, model)`` pair for *pretrained_model*.

    The result is memoized with ``lru_cache`` so the multi-gigabyte
    checkpoint is read from disk exactly once per process. The original
    code called this from inside the request handler with no caching,
    which made every request pay the full model-load cost.

    Args:
        pretrained_model: Hugging Face model identifier.
        cache_dir: local directory where checkpoint files are cached.

    Returns:
        Tuple of (tokenizer, model); the model is on ``device`` in eval mode.
    """
    tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(
        pretrained_model, cache_dir=cache_dir
    ).to(device)
    model.eval()  # inference only: disable dropout and training behavior
    return tokenizer, model


@app.post("/translate")
async def translate(request: TranslationRequest):
    """Translate ``request.user_input`` from source_lang to target_lang.

    Returns:
        dict with ``translation`` (the decoded text) and
        ``computation_time`` (wall-clock seconds, rounded to 3 decimals).
    """
    time_start = time.time()

    # Cheap after the first request thanks to lru_cache on load_model.
    tokenizer, model = load_model()

    tokenizer.src_lang = request.source_lang
    with torch.no_grad():  # no gradient bookkeeping during generation
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_input,
            forced_bos_token_id=tokenizer.get_lang_id(request.target_lang),
        )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

    time_end = time.time()
    return {
        "translation": translated_text,
        "computation_time": round(time_end - time_start, 3),
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)