mayacou committed on
Commit
b8c0a2d
·
verified ·
1 Parent(s): da88c0f

correct issue with small100

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -3,7 +3,9 @@ from transformers import (
3
  MarianMTModel,
4
  MarianTokenizer,
5
  MBartForConditionalGeneration,
6
- MBart50TokenizerFast
 
 
7
  )
8
  import torch
9
 
@@ -25,16 +27,16 @@ MODEL_MAP = {
25
  "fr": "Helsinki-NLP/opus-mt-en-fr",
26
  "hr": "facebook/mbart-large-50-many-to-many-mmt",
27
  "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
28
- "is": "mkorada/opus-mt-en-is-finetuned-v4", #Manas's fine-tuned model
29
  "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
30
- "lb": "alirezamsh/small100",
31
  "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
32
  "lv": "facebook/mbart-large-50-many-to-many-mmt",
33
- "me": "Helsinki-NLP/opus-mt-tc-base-en-sh"
34
  "mk": "Helsinki-NLP/opus-mt-en-mk",
35
- "nb": "facebook/mbart-large-50-many-to-many-mmt", #place holder!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
36
  "nl": "facebook/mbart-large-50-many-to-many-mmt",
37
- "no": "Confused404/eng-gmq-finetuned_v2-no", #Alex's fine-tuned model
38
  "pl": "Helsinki-NLP/opus-mt-en-sla",
39
  "pt": "facebook/mbart-large-50-many-to-many-mmt",
40
  "ro": "facebook/mbart-large-50-many-to-many-mmt",
@@ -45,29 +47,32 @@ MODEL_MAP = {
45
  "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
46
  }
47
 
48
-
49
  MODEL_CACHE = {}
50
 
51
- # ✅ Load Hugging Face model (Helsinki or Small100)
52
  def load_model(model_id: str):
53
  """
54
- Load & cache either:
55
- - MBart50 (facebook/mbart-*)
56
- - MarianMT otherwise
 
57
  """
58
  if model_id not in MODEL_CACHE:
59
  if model_id.startswith("facebook/mbart"):
60
  tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
61
  model = MBartForConditionalGeneration.from_pretrained(model_id)
 
 
 
62
  else:
63
  tokenizer = MarianTokenizer.from_pretrained(model_id)
64
  model = MarianMTModel.from_pretrained(model_id)
 
65
  model.to("cpu")
66
  MODEL_CACHE[model_id] = (tokenizer, model)
67
- return MODEL_CACHE[model_id]
68
 
 
69
 
70
- # ✅ POST /translate
71
  @app.post("/translate")
72
  async def translate(request: Request):
73
  payload = await request.json()
@@ -100,18 +105,15 @@ async def translate(request: Request):
100
  except Exception as e:
101
  return {"error": f"Translation failed: {e}"}
102
 
103
-
104
- # ✅ GET /languages
105
  @app.get("/languages")
106
  def list_languages():
107
  return {"supported_languages": list(MODEL_MAP.keys())}
108
 
109
- # ✅ GET /health
110
  @app.get("/health")
111
  def health():
112
  return {"status": "ok"}
113
 
114
- # Uvicorn startup (required by Hugging Face)
115
- import uvicorn
116
  if __name__ == "__main__":
117
- uvicorn.run("app:app", host="0.0.0.0", port=7860)
 
 
3
  MarianMTModel,
4
  MarianTokenizer,
5
  MBartForConditionalGeneration,
6
+ MBart50TokenizerFast,
7
+ AutoTokenizer,
8
+ AutoModelForSeq2SeqLM
9
  )
10
  import torch
11
 
 
27
  "fr": "Helsinki-NLP/opus-mt-en-fr",
28
  "hr": "facebook/mbart-large-50-many-to-many-mmt",
29
  "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
30
+ "is": "mkorada/opus-mt-en-is-finetuned-v4", # Manas's fine-tuned model
31
  "it": "Helsinki-NLP/opus-mt-tc-big-en-it",
32
+ "lb": "alirezamsh/small100", # small100
33
  "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
34
  "lv": "facebook/mbart-large-50-many-to-many-mmt",
35
+ "me": "Helsinki-NLP/opus-mt-tc-base-en-sh",
36
  "mk": "Helsinki-NLP/opus-mt-en-mk",
37
+ "nb": "facebook/mbart-large-50-many-to-many-mmt",
38
  "nl": "facebook/mbart-large-50-many-to-many-mmt",
39
+ "no": "Confused404/eng-gmq-finetuned_v2-no", # Alex's fine-tuned model
40
  "pl": "Helsinki-NLP/opus-mt-en-sla",
41
  "pt": "facebook/mbart-large-50-many-to-many-mmt",
42
  "ro": "facebook/mbart-large-50-many-to-many-mmt",
 
47
  "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
48
  }
49
 
50
+ # Cache loaded models/tokenizers
51
  MODEL_CACHE = {}
52
 
 
53
def load_model(model_id: str):
    """
    Return a cached ``(tokenizer, model)`` pair for *model_id*, loading it on
    first use and serving subsequent calls from ``MODEL_CACHE``.

    Dispatch by checkpoint family:
    - ``facebook/mbart-*``    -> MBart50TokenizerFast + MBartForConditionalGeneration
    - ``alirezamsh/small100`` -> AutoTokenizer + AutoModelForSeq2SeqLM
    - anything else           -> MarianTokenizer + MarianMTModel

    NOTE(review): small100 is loaded through AutoTokenizer here; its
    target-language selection (``tgt_lang``) is presumably handled at call
    time in the translate endpoint — verify against that handler.
    """
    if model_id not in MODEL_CACHE:
        # Pick the tokenizer/model classes matching the checkpoint family.
        if model_id.startswith("facebook/mbart"):
            tokenizer_cls, model_cls = MBart50TokenizerFast, MBartForConditionalGeneration
        elif model_id == "alirezamsh/small100":
            tokenizer_cls, model_cls = AutoTokenizer, AutoModelForSeq2SeqLM
        else:
            tokenizer_cls, model_cls = MarianTokenizer, MarianMTModel

        tokenizer = tokenizer_cls.from_pretrained(model_id)
        model = model_cls.from_pretrained(model_id)
        model.to("cpu")  # CPU-only inference
        MODEL_CACHE[model_id] = (tokenizer, model)

    return MODEL_CACHE[model_id]
75
 
 
76
  @app.post("/translate")
77
  async def translate(request: Request):
78
  payload = await request.json()
 
105
  except Exception as e:
106
  return {"error": f"Translation failed: {e}"}
107
 
 
 
108
@app.get("/languages")
def list_languages():
    """Return every language code the service can translate into."""
    codes = [*MODEL_MAP]  # dict iteration yields keys in insertion order
    return {"supported_languages": codes}
111
 
 
112
@app.get("/health")
def health():
    """Liveness probe: report that the service is up."""
    return dict(status="ok")
115
 
116
# Launch Uvicorn only when executed as a script (e.g. local testing);
# when served by an external runner, this block is skipped.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)