Nadil Karunarathna commited on
Commit
b9e0b01
·
1 Parent(s): 991fe21
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -4,7 +4,7 @@ import re
4
 
5
  model = None
6
  tokenizer = None
7
- device = "cpu"
8
 
9
  def init():
10
  from transformers import MT5ForConditionalGeneration, T5TokenizerFast
@@ -13,8 +13,8 @@ def init():
13
  global model, tokenizer
14
 
15
  hf_token = os.environ.get("HF_TOKEN")
16
- model_path = "lm-spell/mt5-base-ft-ssc"
17
- model = MT5ForConditionalGeneration.from_pretrained(model_path, token=hf_token).to(device)
18
  tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
19
  tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
20
 
@@ -29,7 +29,7 @@ def correct(text):
29
  padding='do_not_pad',
30
  max_length=1024
31
  )
32
- inputs = {k: v.to(device) for k, v in inputs.items()}
33
 
34
  with torch.inference_mode():
35
  outputs = model.generate(
 
4
 
5
  model = None
6
  tokenizer = None
7
+ # device = "cpu"
8
 
9
  def init():
10
  from transformers import MT5ForConditionalGeneration, T5TokenizerFast
 
13
  global model, tokenizer
14
 
15
  hf_token = os.environ.get("HF_TOKEN")
16
+
17
+ model = MT5ForConditionalGeneration.from_pretrained("lm-spell/mt5-base-ft-ssc", token=hf_token)
18
  tokenizer = T5TokenizerFast.from_pretrained("google/mt5-base")
19
  tokenizer.add_special_tokens({'additional_special_tokens': ['<ZWJ>']})
20
 
 
29
  padding='do_not_pad',
30
  max_length=1024
31
  )
32
+ # inputs = {k: v.to(device) for k, v in inputs.items()}
33
 
34
  with torch.inference_mode():
35
  outputs = model.generate(