Update app.py
added chunking service references
app.py
CHANGED
@@ -2,6 +2,9 @@ from fastapi import FastAPI, Request
 from transformers import MarianMTModel, MarianTokenizer
 import torch

+# import chunking service
+from chunking import get_max_word_length, chunk_text
+
 app = FastAPI()

 # Map target languages to Hugging Face model IDs
@@ -60,14 +63,29 @@ async def translate(request: Request):
     if not model_id:
         return {"error": f"No model found for target language '{target_lang}'"}

+    # Facebook/mbart placeholder check
     if model_id.startswith("facebook/"):
         return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}

     try:
+        # 1. figure out your safe word limit for this language
+        safe_limit = get_max_word_length([target_lang])
+
+        # 2. break the input up into chunks
+        chunks = chunk_text(text, safe_limit)
+
+        # 3. translate each chunk and collect results
         tokenizer, model = load_model(model_id)
-
-
-
+        full_translation = []
+        for chunk in chunks:
+            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True).to(model.device)
+            outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
+            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
+
+        # 4. re-join the translated pieces
+        joined = " ".join(full_translation)
+        return {"translation": joined}
+
     except Exception as e:
         return {"error": f"Translation failed: {str(e)}"}
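The chunking module referenced by the new import is not part of this commit. For context, here is a minimal sketch of what chunking.py could look like, assuming get_max_word_length takes a list of language codes and returns a word budget and chunk_text splits on whitespace; the limits and the splitting heuristic below are illustrative assumptions, not the Space's actual implementation.

# chunking.py -- illustrative sketch only; the real module is not shown in this commit,
# and these limits are assumptions, not the values the Space actually uses.

_DEFAULT_LIMIT = 400                          # assumed fallback word budget per chunk
_LANGUAGE_LIMITS = {"de": 350, "fi": 300}     # assumed per-language overrides

def get_max_word_length(target_langs):
    # Return the smallest word limit across the requested language codes.
    limits = [_LANGUAGE_LIMITS.get(lang, _DEFAULT_LIMIT) for lang in target_langs]
    return min(limits) if limits else _DEFAULT_LIMIT

def chunk_text(text, max_words):
    # Split the input into chunks of at most max_words whitespace-separated words.
    words = text.split()
    chunks, current = [], []
    for word in words:
        current.append(word)
        if len(current) >= max_words:
            chunks.append(" ".join(current))
            current = []
    if current:
        chunks.append(" ".join(current))
    return chunks

Joining the per-chunk outputs with a single space, as app.py now does, reads cleanly when chunks end at sentence boundaries; a plain whitespace splitter like the sketch above can cut mid-sentence, so a sentence-aware splitter would preserve punctuation and context better.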