shravvvv commited on
Commit
024824d
·
1 Parent(s): a63d916

Updated app.py and requirements.txt

Browse files
Files changed (2) hide show
  1. app.py +4 -4
  2. requirements.txt +0 -2
app.py CHANGED
@@ -3,11 +3,11 @@ import torch
3
  from PIL import Image
4
  from transformers import AutoModelForImageClassification, AutoImageProcessor
5
 
6
- # Load model and processor
7
- model = AutoModelForImageClassification.from_pretrained("shravvvv/SAG-ViT")
8
- processor = AutoImageProcessor.from_pretrained("shravvvv/SAG-ViT")
9
 
10
- # CIFAR-10 class labels
11
  class_labels = [
12
  'airplane', 'automobile', 'bird', 'cat', 'deer',
13
  'dog', 'frog', 'horse', 'ship', 'truck'
 
3
  from PIL import Image
4
  from transformers import AutoModelForImageClassification, AutoImageProcessor
5
 
6
+ # Load model and processor with custom code enabled
7
+ model = AutoModelForImageClassification.from_pretrained("shravvvv/SAG-ViT", trust_remote_code=True)
8
+ processor = AutoImageProcessor.from_pretrained("shravvvv/SAG-ViT", trust_remote_code=True)
9
 
10
+ # Define CIFAR-10 class labels
11
  class_labels = [
12
  'airplane', 'automobile', 'bird', 'cat', 'deer',
13
  'dog', 'frog', 'horse', 'ship', 'truck'
requirements.txt CHANGED
@@ -3,8 +3,6 @@ pandas==2.2.3
3
  matplotlib==3.7.5
4
  seaborn==0.12.2
5
  tqdm==4.66.4
6
- psutil==5.9.3
7
- pynvml==11.4.1
8
  scikit-learn==1.2.2
9
  torch==2.4.0
10
  torch-geometric==2.6.1
 
3
  matplotlib==3.7.5
4
  seaborn==0.12.2
5
  tqdm==4.66.4
 
 
6
  scikit-learn==1.2.2
7
  torch==2.4.0
8
  torch-geometric==2.6.1