mayacou committed
Commit 447b422 · verified · 1 Parent(s): 7bc869d

add fixes for mbart

Files changed (1): app.py (+32 -22)
app.py CHANGED
@@ -1,8 +1,13 @@
 from fastapi import FastAPI, Request
-from transformers import MarianMTModel, MarianTokenizer
+from transformers import (
+    MarianMTModel,
+    MarianTokenizer,
+    MBartForConditionalGeneration,
+    MBart50TokenizerFast
+)
 import torch
 
-# import chunking service
+# import your chunking helpers
 from chunking import get_max_word_length, chunk_text
 
 app = FastAPI()
@@ -43,19 +48,30 @@
 MODEL_CACHE = {}
 
 # ✅ Load Hugging Face model (Helsinki or Small100)
-def load_model(model_id):
+def load_model(model_id: str):
+    """
+    Load & cache either:
+      - MBart50 (facebook/mbart-*)
+      - MarianMT otherwise
+    """
     if model_id not in MODEL_CACHE:
-        tokenizer = MarianTokenizer.from_pretrained(model_id)
-        model = MarianMTModel.from_pretrained(model_id).to("cpu")
+        if model_id.startswith("facebook/mbart"):
+            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
+            model = MBartForConditionalGeneration.from_pretrained(model_id)
+        else:
+            tokenizer = MarianTokenizer.from_pretrained(model_id)
+            model = MarianMTModel.from_pretrained(model_id)
+        model.to("cpu")
         MODEL_CACHE[model_id] = (tokenizer, model)
     return MODEL_CACHE[model_id]
 
+
 # ✅ POST /translate
 @app.post("/translate")
 async def translate(request: Request):
-    data = await request.json()
-    text = data.get("text")
-    target_lang = data.get("target_lang")
+    payload = await request.json()
+    text = payload.get("text")
+    target_lang = payload.get("target_lang")
 
     if not text or not target_lang:
         return {"error": "Missing 'text' or 'target_lang'"}
@@ -64,31 +80,25 @@ async def translate(request: Request):
     if not model_id:
         return {"error": f"No model found for target language '{target_lang}'"}
 
-    # Facebook/mbart placeholder check
-    if model_id.startswith("facebook/"):
-        return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}
-
     try:
-        # 1. figure out your safe word limit for this language
+        # chunk to safe length
         safe_limit = get_max_word_length([target_lang])
-
-        # 2. break the input up into chunks
-        chunks = chunk_text(text, safe_limit)
-
-        # 3. translate each chunk and collect results
+        chunks = chunk_text(text, safe_limit)
+
         tokenizer, model = load_model(model_id)
         full_translation = []
+
         for chunk in chunks:
-            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to(model.device)
+            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
+            inputs = {k: v.to(model.device) for k, v in inputs.items()}
             outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
             full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
 
-        # 4. re-join the translated pieces
-        joined = " ".join(full_translation)
-        return {"translation": joined}
+        return {"translation": " ".join(full_translation)}
 
     except Exception as e:
-        return {"error": f"Translation failed: {str(e)}"}
+        return {"error": f"Translation failed: {e}"}
+
 
 # ✅ GET /languages
 @app.get("/languages")
 
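One gap worth noting in the new code path: mBART-50 is a multilingual model, and it only emits the requested target language when the tokenizer's src_lang is set and generate() receives forced_bos_token_id for the target language code; the Marian-style call in the loop above sets neither. A minimal sketch of a chunk translation that does, assuming English source text and a hypothetical MBART_LANG_CODES map from this app's target_lang codes to mBART's:

# Hypothetical mapping from the app's target_lang codes to mBART-50 codes.
MBART_LANG_CODES = {"de": "de_DE", "fr": "fr_XX", "es": "es_XX"}

def translate_chunk_mbart(tokenizer, model, chunk, target_lang):
    tokenizer.src_lang = "en_XX"  # assumes the source text is English
    inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    outputs = model.generate(
        **inputs,
        # Force decoding to start with the target-language token.
        forced_bos_token_id=tokenizer.lang_code_to_id[MBART_LANG_CODES[target_lang]],
        num_beams=5,
        length_penalty=1.2,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)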
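For a quick smoke test of the updated endpoint, a request might look like this (the URL is an assumption for local testing; the handler returns either a translation or an error key, as in the code above):

import requests

resp = requests.post(
    "http://localhost:7860/translate",
    json={"text": "Hello, world!", "target_lang": "de"},
)
print(resp.json())  # e.g. {"translation": "Hallo, Welt!"}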