app.py small100 fix (#6)
Commit: 21a49a060dde151d2a200777c5699440cb29cf6a
app.py
CHANGED
```diff
@@ -8,6 +8,7 @@ from transformers import (
     AutoModelForSeq2SeqLM
 )
 import torch
+from tokenization_small100 import SMALL100Tokenizer
 
 # import your chunking helpers
 from chunking import get_max_word_length, chunk_text
@@ -50,7 +51,7 @@ MODEL_MAP = {
 # Cache loaded models/tokenizers
 MODEL_CACHE = {}
 
-def load_model(model_id: str):
+def load_model(model_id: str, target_lang: str):
     """
     Load & cache:
     - facebook/mbart-* via MBart50TokenizerFast & MBartForConditionalGeneration
@@ -62,7 +63,7 @@ def load_model(model_id: str):
         tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
         model = MBartForConditionalGeneration.from_pretrained(model_id)
     elif model_id == "alirezamsh/small100":
-        tokenizer =
+        tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
         model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
     else:
         tokenizer = MarianTokenizer.from_pretrained(model_id)
@@ -91,7 +92,7 @@ async def translate(request: Request):
     safe_limit = get_max_word_length([target_lang])
     chunks = chunk_text(text, safe_limit)
 
-    tokenizer, model = load_model(model_id)
+    tokenizer, model = load_model(model_id, target_lang)
     full_translation = []
 
     for chunk in chunks:
```
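For reference, here is how the patched path would be exercised end to end. This is a minimal sketch, assuming the alirezamsh/small100 checkpoint is reachable and its bundled tokenization_small100.py sits on the import path; the sample sentence and the "fr" target code are illustrative placeholders, not part of the commit.

```python
import torch
from transformers import AutoModelForSeq2SeqLM
from tokenization_small100 import SMALL100Tokenizer  # shipped alongside the small100 checkpoint

model_id = "alirezamsh/small100"

# tgt_lang sets the target-language code that the tokenizer attaches to the
# source tokens; SMALL-100 selects its output language this way, so no
# forced_bos_token_id is needed at generation time.
tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang="fr")
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

encoded = tokenizer("Hello, how are you?", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**encoded)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```

This is also why the commit threads `target_lang` into `load_model`: unlike the MBart and Marian paths, the SMALL-100 tokenizer needs to know the output language when it is constructed, not just at generation time.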