seanp03 committed 21a49a0 (verified) · parent: b953f42

app.py small100 fix

Files changed (1):
  1. app.py +4 -3
app.py CHANGED
@@ -8,6 +8,7 @@ from transformers import (
     AutoModelForSeq2SeqLM
 )
 import torch
+from tokenization_small100 import SMALL100Tokenizer
 
 # import your chunking helpers
 from chunking import get_max_word_length, chunk_text
@@ -50,7 +51,7 @@ MODEL_MAP = {
 # Cache loaded models/tokenizers
 MODEL_CACHE = {}
 
-def load_model(model_id: str):
+def load_model(model_id: str, target_lang: str):
     """
     Load & cache:
     - facebook/mbart-* via MBart50TokenizerFast & MBartForConditionalGeneration
@@ -62,7 +63,7 @@ def load_model(model_id: str):
         tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
         model = MBartForConditionalGeneration.from_pretrained(model_id)
     elif model_id == "alirezamsh/small100":
-        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
         model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
     else:
         tokenizer = MarianTokenizer.from_pretrained(model_id)
@@ -91,7 +92,7 @@ async def translate(request: Request):
     safe_limit = get_max_word_length([target_lang])
     chunks = chunk_text(text, safe_limit)
 
-    tokenizer, model = load_model(model_id)
+    tokenizer, model = load_model(model_id, target_lang)
     full_translation = []
 
     for chunk in chunks:
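
Context for the fix: alirezamsh/small100 distributes its tokenizer as a standalone tokenization_small100.py on the model card rather than as a registered transformers class, so a generic AutoTokenizer load is not sufficient; and, per the model card, that tokenizer encodes the target-language code on the source side, so it must know tgt_lang up front. Below is a minimal sketch of the load-and-translate path this commit enables. It assumes tokenization_small100.py sits next to app.py (as the new import implies); "fr" and the sample sentence are illustrative placeholders, not values from the app.

# Sketch only: mirrors the loading pattern from the diff above, outside the cache logic.
import torch
from transformers import AutoModelForSeq2SeqLM
from tokenization_small100 import SMALL100Tokenizer  # helper file from the model card

model_id = "alirezamsh/small100"
tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang="fr")
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# The tokenizer injects the target-language token into the encoder input,
# so generate() needs no forced_bos_token_id here.
encoded = tokenizer("Translation models are fun.", return_tensors="pt")
with torch.no_grad():
    generated = model.generate(**encoded, max_new_tokens=128)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])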