dominguezdaniel committed on
Commit
3da23d5
·
verified ·
1 Parent(s): 29a69bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -9
app.py CHANGED
@@ -1,14 +1,73 @@
 
 
1
  import gradio as gr
2
- from gradio import inputs
3
 
4
- titulo = "Clasificador de imagenes"
5
- descripcion = "Este es un Demo que hace un reconocimiento de imagenes y clasifica según lo que el modelo de reconocimiento de imagenes de huggingface/google/vit-base-patch16-224 encuentre."
6
 
7
- modelo = "huggingface/google/vit-base-patch16-224"
8
- entrada = gr.inputs.Image(label="Carga una imagen aquí")
9
 
10
- # Assuming you're directly loading a model, which might not work as explained earlier.
11
- # You may need to adjust this part according to the actual way to load and use models with Gradio.
12
- iface = gr.Interface.load(modelo, inputs=entrada, title=titulo, description=descripcion)
 
 
13
 
14
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
  import gradio as gr
 
4
 
5
+ from PIL import Image
6
+ from torchvision import transforms
7
 
 
 
8
 
9
+ """
10
+ Built following:
11
+ https://huggingface.co/spaces/pytorch/ResNet/tree/main
12
+ https://www.gradio.app/image_classification_in_pytorch/
13
+ """
14
 
15
# Static assets needed at runtime: the ImageNet class list read by the
# inference code, and a sample image from the PyTorch hub repository.
_ASSETS = (
    ("https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt",
     "imagenet_classes.txt"),
    ("https://github.com/pytorch/hub/raw/master/images/dog.jpg",
     "dog.jpg"),
)

# Use torch.hub's downloader for every asset instead of shelling out to
# ``wget``: it needs no external binary, raises on failure instead of being
# silently ignored, and writes to the exact target name. The existence check
# avoids re-downloading on every restart.
for _url, _filename in _ASSETS:
    if not os.path.exists(_filename):
        torch.hub.download_url_to_file(_url, _filename)

# Load PyTorch model (pretrained ResNet-18) and switch it to inference mode.
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
model.eval()
24
+
25
+ # Inference!
26
_CATEGORIES = None  # lazily populated cache of class names, one per line of the file


def _load_categories():
    """Return the class names from ``imagenet_classes.txt``, reading the file only once."""
    global _CATEGORIES
    if _CATEGORIES is None:
        with open("imagenet_classes.txt", "r", encoding="utf-8") as f:
            _CATEGORIES = [line.strip() for line in f]
    return _CATEGORIES


def inference(input_image):
    """Classify an image and return its five most likely classes.

    Args:
        input_image: A PIL image (the Gradio input component is declared
            with ``type='pil'``).

    Returns:
        dict: Maps each of the top-5 class names to its softmax probability
        as a Python float.
    """
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Normalize is configured for exactly three channels, so force RGB to
    # keep grayscale or RGBA uploads from failing.
    input_tensor = preprocess(input_image.convert("RGB"))
    input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

    # Move the input and model to GPU for speed if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_batch = input_batch.to(device)
    model.to(device)

    with torch.no_grad():
        output = model(input_batch)
    # The output has unnormalized scores; softmax turns them into probabilities.
    probabilities = torch.nn.functional.softmax(output[0], dim=0)

    categories = _load_categories()
    top5_prob, top5_catid = torch.topk(probabilities, 5)
    # .tolist() yields plain floats/ints, so the list is indexed with real
    # integers rather than 0-dim tensors.
    return {
        categories[class_id]: prob
        for prob, class_id in zip(top5_prob.tolist(), top5_catid.tolist())
    }
55
+
56
# Input/output components: a PIL image goes in, the top-5 label confidences come out.
# NOTE(review): the gr.inputs / gr.outputs namespaces are presumably deprecated
# in newer Gradio releases; confirm against the pinned Gradio version.
inputs = gr.inputs.Image(type="pil")
outputs = gr.outputs.Label(type="confidences", num_top_classes=5)

# Text shown on the demo page.
title = "Image Recognition Demo"
description = (
    "This is a prototype application which demonstrates how artifical intelligence based systems can recognize what object(s) is present in an image. "
    "This fundamental task in computer vision known as `Image Classification` has applications stretching from autonomous vehicles to medical imaging. "
    "To use it, simply upload your image, or click one of the examples images to load them, which I took at <a href='https://espacepourlavie.ca/en/biodome' target='_blank'>Montréal Biodôme</a>! "
    "Read more at the links below."
)
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/1512.03385' target='_blank'>Deep Residual Learning for Image Recognition</a>"
    " | "
    "<a href='https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py' target='_blank'>Github Repo</a>"
    "</p>"
)

# Build the interface around `inference` and start serving it.
gr.Interface(
    fn=inference,
    inputs=inputs,
    outputs=outputs,
    examples=["example1.jpg", "example2.jpg"],
    title=title,
    description=description,
    article=article,
    analytics_enabled=False,
).launch()