Updated app.py and requirements.txt
Browse files- app.py +4 -4
- requirements.txt +0 -2
app.py
CHANGED
@@ -3,11 +3,11 @@ import torch
|
|
3 |
from PIL import Image
|
4 |
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
5 |
|
6 |
-
# Load model and processor
|
7 |
-
model = AutoModelForImageClassification.from_pretrained("shravvvv/SAG-ViT")
|
8 |
-
processor = AutoImageProcessor.from_pretrained("shravvvv/SAG-ViT")
|
9 |
|
10 |
-
# CIFAR-10 class labels
|
11 |
class_labels = [
|
12 |
'airplane', 'automobile', 'bird', 'cat', 'deer',
|
13 |
'dog', 'frog', 'horse', 'ship', 'truck'
|
|
|
3 |
from PIL import Image
|
4 |
from transformers import AutoModelForImageClassification, AutoImageProcessor
|
5 |
|
6 |
+
# Load model and processor with custom code enabled
|
7 |
+
model = AutoModelForImageClassification.from_pretrained("shravvvv/SAG-ViT", trust_remote_code=True)
|
8 |
+
processor = AutoImageProcessor.from_pretrained("shravvvv/SAG-ViT", trust_remote_code=True)
|
9 |
|
10 |
+
# Define CIFAR-10 class labels
|
11 |
class_labels = [
|
12 |
'airplane', 'automobile', 'bird', 'cat', 'deer',
|
13 |
'dog', 'frog', 'horse', 'ship', 'truck'
|
requirements.txt
CHANGED
@@ -3,8 +3,6 @@ pandas==2.2.3
|
|
3 |
matplotlib==3.7.5
|
4 |
seaborn==0.12.2
|
5 |
tqdm==4.66.4
|
6 |
-
psutil==5.9.3
|
7 |
-
pynvml==11.4.1
|
8 |
scikit-learn==1.2.2
|
9 |
torch==2.4.0
|
10 |
torch-geometric==2.6.1
|
|
|
3 |
matplotlib==3.7.5
|
4 |
seaborn==0.12.2
|
5 |
tqdm==4.66.4
|
|
|
|
|
6 |
scikit-learn==1.2.2
|
7 |
torch==2.4.0
|
8 |
torch-geometric==2.6.1
|