mayacou committed on
Commit
6df8ecd
·
verified ·
1 Parent(s): 4b8f95c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -9
app.py CHANGED
@@ -4,13 +4,41 @@ import torch
4
 
5
  app = FastAPI()
6
 
 
7
  MODEL_MAP = {
 
 
 
 
 
 
 
 
8
  "fr": "Helsinki-NLP/opus-mt-en-fr",
9
- "de": "Helsinki-NLP/opus-mt-en-de"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  }
11
 
 
12
  MODEL_CACHE = {}
13
 
 
14
  def load_model(model_id):
15
  if model_id not in MODEL_CACHE:
16
  tokenizer = MarianTokenizer.from_pretrained(model_id)
@@ -18,6 +46,7 @@ def load_model(model_id):
18
  MODEL_CACHE[model_id] = (tokenizer, model)
19
  return MODEL_CACHE[model_id]
20
 
 
21
  @app.post("/translate")
22
  async def translate(request: Request):
23
  data = await request.json()
@@ -25,18 +54,34 @@ async def translate(request: Request):
25
  target_lang = data.get("target_lang")
26
 
27
  if not text or not target_lang:
28
- return {"error": "Missing text or target_lang"}
29
 
30
  model_id = MODEL_MAP.get(target_lang)
31
  if not model_id:
32
- return {"error": f"No model for '{target_lang}'"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- tokenizer, model = load_model(model_id)
35
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
36
- outputs = model.generate(**inputs)
37
- return {"translation": tokenizer.decode(outputs[0], skip_special_tokens=True)}
38
 
39
- # Required for FastAPI to run on HF Spaces
40
  import uvicorn
41
  if __name__ == "__main__":
42
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
4
 
5
  app = FastAPI()
6
 
7
+ # Map target languages to Hugging Face model IDs
8
  MODEL_MAP = {
9
+ "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
10
+ "cs": "Helsinki-NLP/opus-mt-en-cs",
11
+ "da": "Helsinki-NLP/opus-mt-en-da",
12
+ "de": "Helsinki-NLP/opus-mt-en-de",
13
+ "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
14
+ "es": "facebook/nllb-200-distilled-600M",
15
+ "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
16
+ "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
17
  "fr": "Helsinki-NLP/opus-mt-en-fr",
18
+ "hr": "facebook/mbart-large-50-many-to-many-mmt",
19
+ "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
20
+ "is": "facebook/nllb-200-distilled-600M",
21
+ "it": "facebook/nllb-200-distilled-600M",
22
+ "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
23
+ "lv": "facebook/mbart-large-50-many-to-many-mmt",
24
+ "mk": "facebook/nllb-200-distilled-600M",
25
+ "nb": "facebook/mbart-large-50-many-to-many-mmt", #place holder!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
26
+ "nl": "facebook/mbart-large-50-many-to-many-mmt",
27
+ "no": "facebook/mbart-large-50-many-to-many-mmt", #place holder!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
28
+ "pl": "facebook/nllb-200-distilled-600M",
29
+ "pt": "facebook/mbart-large-50-many-to-many-mmt",
30
+ "ro": "facebook/mbart-large-50-many-to-many-mmt",
31
+ "sk": "Helsinki-NLP/opus-mt-en-sk",
32
+ "sl": "alirezamsh/small100",
33
+ "sq": "alirezamsh/small100",
34
+ "sv": "Helsinki-NLP/opus-mt-en-sv",
35
+ "tr": "facebook/nllb-200-distilled-600M"
36
  }
37
 
38
+
39
  MODEL_CACHE = {}
40
 
41
+ # ✅ Load Hugging Face model (Helsinki or Small100)
42
  def load_model(model_id):
43
  if model_id not in MODEL_CACHE:
44
  tokenizer = MarianTokenizer.from_pretrained(model_id)
 
46
  MODEL_CACHE[model_id] = (tokenizer, model)
47
  return MODEL_CACHE[model_id]
48
 
49
+ # ✅ POST /translate
50
  @app.post("/translate")
51
  async def translate(request: Request):
52
  data = await request.json()
 
54
  target_lang = data.get("target_lang")
55
 
56
  if not text or not target_lang:
57
+ return {"error": "Missing 'text' or 'target_lang'"}
58
 
59
  model_id = MODEL_MAP.get(target_lang)
60
  if not model_id:
61
+ return {"error": f"No model found for target language '{target_lang}'"}
62
+
63
+ if model_id.startswith("facebook/"):
64
+ return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}
65
+
66
+ try:
67
+ tokenizer, model = load_model(model_id)
68
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
69
+ outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
70
+ return {"translation": tokenizer.decode(outputs[0], skip_special_tokens=True)}
71
+ except Exception as e:
72
+ return {"error": f"Translation failed: {str(e)}"}
73
+
74
+ # ✅ GET /languages
75
+ @app.get("/languages")
76
+ def list_languages():
77
+ return {"supported_languages": list(MODEL_MAP.keys())}
78
 
79
+ # GET /health
80
+ @app.get("/health")
81
+ def health():
82
+ return {"status": "ok"}
83
 
84
+ # Uvicorn startup (required by Hugging Face)
85
  import uvicorn
86
  if __name__ == "__main__":
87
+ uvicorn.run("app:app", host="0.0.0.0", port=7860)