mayacou committed
Commit 07ea3d5 · verified · 1 Parent(s): be484c1

Update app.py

added chunking service references

Files changed (1): app.py (+21 -3)
app.py CHANGED

@@ -2,6 +2,9 @@ from fastapi import FastAPI, Request
 from transformers import MarianMTModel, MarianTokenizer
 import torch
 
+# import chunking service
+from chunking import get_max_word_length, chunk_text
+
 app = FastAPI()
 
 # Map target languages to Hugging Face model IDs
@@ -60,14 +63,29 @@ async def translate(request: Request):
     if not model_id:
         return {"error": f"No model found for target language '{target_lang}'"}
 
+    # Facebook/mbart placeholder check
     if model_id.startswith("facebook/"):
         return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}
 
     try:
+        # 1. figure out your safe word limit for this language
+        safe_limit = get_max_word_length([target_lang])
+
+        # 2. break the input up into chunks
+        chunks = chunk_text(text, safe_limit)
+
+        # 3. translate each chunk and collect results
         tokenizer, model = load_model(model_id)
-        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
-        outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
-        return {"translation": tokenizer.decode(outputs[0], skip_special_tokens=True)}
+        full_translation = []
+        for chunk in chunks:
+            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to(model.device)
+            outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
+            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+        # 4. re-join the translated pieces
+        joined = " ".join(full_translation)
+        return {"translation": joined}
+
     except Exception as e:
         return {"error": f"Translation failed: {str(e)}"}
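The chunking helpers imported above are not part of this diff. As a rough sketch of what chunking.py could look like, assuming get_max_word_length returns a conservative word budget for the given target languages and chunk_text splits on word boundaries (the two function names come from the import; the bodies below are assumptions, not the Space's actual code):

# chunking.py: hypothetical sketch; the real module is not shown in this commit.

_DEFAULT_MAX_WORDS = 200  # assumed conservative budget, not a value from the repo

def get_max_word_length(target_langs):
    """Return a safe per-request word budget covering all target languages."""
    # A real implementation might keep per-language limits; we assume one default.
    return _DEFAULT_MAX_WORDS

def chunk_text(text, max_words):
    """Split text into chunks of at most max_words words, on word boundaries."""
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]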
 
 
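Once the Space is running, the chunked path can be smoke-tested end to end. The snippet below assumes the handler is mounted at POST /translate and reads text and target_lang from the JSON body; the route decorator and request parsing sit outside this hunk, so both are inferred.

# Hypothetical smoke test; route path, port, and field names are assumptions.
import requests

resp = requests.post(
    "http://localhost:7860/translate",  # 7860 is the usual Hugging Face Space port
    json={
        "text": "This is a long passage. " * 100,  # long enough to trigger chunking
        "target_lang": "de",
    },
)
print(resp.json())  # expected: {"translation": "..."} or {"error": "..."}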