vykanand commited on
Commit
5dda5d9
Β·
verified Β·
1 Parent(s): 699fe26

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -1,13 +1,15 @@
1
  import gradio as gr
 
2
  from torchvision import models, transforms
3
  from PIL import Image
4
- import torch
 
5
 
6
- # Load the pre-trained MobileNetV2 model
7
  model = models.mobilenet_v2(pretrained=True)
8
  model.eval()
9
 
10
- # Image transformation for input
11
  transform = transforms.Compose([
12
  transforms.Resize(256),
13
  transforms.CenterCrop(224),
@@ -15,17 +17,27 @@ transform = transforms.Compose([
15
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
16
  ])
17
 
18
- def classify_image(image):
19
- # Apply transformations
20
- img_tensor = transform(image).unsqueeze(0)
 
 
 
 
 
 
21
 
22
  # Perform inference
23
  with torch.no_grad():
24
- outputs = model(img_tensor)
25
- _, predicted_class = torch.max(outputs, 1)
26
-
27
- return predicted_class.item()
 
 
28
 
29
  # Gradio interface
30
- interface = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="text")
31
- interface.launch()
 
 
 
1
  import gradio as gr
2
+ import torch
3
  from torchvision import models, transforms
4
  from PIL import Image
5
+ import requests
6
+ from io import BytesIO
7
 
8
+ # Load pre-trained MobileNetV2 model (you can choose another model as needed)
9
  model = models.mobilenet_v2(pretrained=True)
10
  model.eval()
11
 
12
+ # Define the image transformation (resize, normalization)
13
  transform = transforms.Compose([
14
  transforms.Resize(256),
15
  transforms.CenterCrop(224),
 
17
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
18
  ])
19
 
20
+ # Download the ImageNet class labels (you can replace this with your own if needed)
21
+ LABELS_URL = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
22
+ class_idx = requests.get(LABELS_URL).json()
23
+ idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]
24
+
25
+ # Function to perform image inference
26
+ def predict_image(image):
27
+ image = Image.open(BytesIO(image)).convert("RGB")
28
+ image = transform(image).unsqueeze(0)
29
 
30
  # Perform inference
31
  with torch.no_grad():
32
+ output = model(image)
33
+
34
+ # Get the predicted label
35
+ _, predicted_class = torch.max(output, 1)
36
+ label = idx2label[predicted_class.item()]
37
+ return label
38
 
39
  # Gradio interface
40
+ with gr.Interface(fn=predict_image,
41
+ inputs=gr.inputs.Image(type="bytes"),
42
+ outputs=gr.outputs.Textbox()) as demo:
43
+ demo.launch(debug=True)