nico-x commited on
Commit
ad79d22
·
1 Parent(s): 3dc14f9

clean model loaded from hf

Browse files
Files changed (3) hide show
  1. .gitattributes +0 -1
  2. .gitignore +2 -2
  3. gradio_ui/gradio_app.py +4 -1
.gitattributes DELETED
@@ -1 +0,0 @@
1
- *.pt filter=lfs diff=lfs merge=lfs -text
 
 
.gitignore CHANGED
@@ -13,8 +13,8 @@ venv/
13
  .DS_Store
14
 
15
  # PyTorch checkpoints
16
- # *.pt
17
- #!checkpoints/transformer_mnist.pt # allow the model to be pushed to HF Space
18
 
19
  # Gradio session files
20
  gradio_cached_examples/
 
13
  .DS_Store
14
 
15
  # PyTorch checkpoints
16
+ checkpoints/
17
+ *.pt
18
 
19
  # Gradio session files
20
  gradio_cached_examples/
gradio_ui/gradio_app.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  import torch
6
  import numpy as np
7
  from PIL import Image, ImageDraw
 
8
 
9
  from model.model import ImageToDigitTransformer
10
  from utils.tokenizer import START, FINISH, decode
@@ -12,8 +13,10 @@ from gradio_ui.preprocess import preprocess_canvases
12
 
13
  # Load model
14
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
 
 
15
  model = ImageToDigitTransformer(vocab_size=13).to(device)
16
- model.load_state_dict(torch.load("checkpoints/transformer_mnist.pt", map_location=device))
17
  model.eval()
18
 
19
  def split_into_quadrants(image):
 
5
  import torch
6
  import numpy as np
7
  from PIL import Image, ImageDraw
8
+ from huggingface_hub import hf_hub_download
9
 
10
  from model.model import ImageToDigitTransformer
11
  from utils.tokenizer import START, FINISH, decode
 
13
 
14
  # Load model
15
  device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
16
+
17
+ model_path = hf_hub_download(repo_id="nico-x/vit-transformer-mnist", filename="transformer_mnist.pt")
18
  model = ImageToDigitTransformer(vocab_size=13).to(device)
19
+ model.load_state_dict(torch.load(model_path, map_location=device))
20
  model.eval()
21
 
22
  def split_into_quadrants(image):