NemesisAlm commited on
Commit
19be735
·
1 Parent(s): 823929d

1st commit MCP server

Browse files
Files changed (2) hide show
  1. app.py +59 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import CLIPProcessor, CLIPModel
4
+ from PIL import Image
5
+
6
+ device="cpu"
7
+
8
+ processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
9
+ model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
10
+
11
+ LIST_COUNTRIES = ["Albania", "Andorra", "Argentina", "Australia", "Austria", "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana", "Brazil", "Bulgaria", "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia", "Czech Republic", "Denmark", "Dominican Republic", "Ecuador", "Estonia", "Finland", "France", "Germany", "Ghana", "Greece", "Greenland", "Guam", "Guatemala", "Hungary", "Iceland", "India", "Indonesia", "Ireland", "Israel", "Italy", "Japan", "Jordan", "Kenya", "Kyrgyzstan", "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg", "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco", "Mongolia", "Montenegro", "Netherlands", "New Zealand", "Nigeria", "Norway", "Pakistan", "Palestine", "Peru", "Philippines", "Poland", "Portugal", "Puerto Rico", "Romania", "Russia", "Rwanda", "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia", "South Africa", "South Korea", "Spain", "Sri Lanka", "Swaziland", "Sweden", "Switzerland", "Taiwan", "Thailand", "Tunisia", "Turkey", "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay"]
12
+
13
+
14
+ def classify_country(image_path: str) -> dict:
15
+ """
16
+ Classify street view image to geolocate them at country-level.
17
+
18
+ Args:
19
+ image_path (str): filename of the uploaded image
20
+
21
+ Returns:
22
+ dict: A dictionary containing scores for the top-10 countries. Ex: {'United Arab Emirates': 0.12621667981147766, 'Denmark': 0.12389234453439713, 'United States': 0.1217150092124939, 'Estonia': 0.09820151329040527, 'Andorra': 0.07836015522480011, 'Belgium': 0.049682170152664185, 'Taiwan': 0.033752717077732086, 'Australia': 0.023750243708491325, 'Guatemala': 0.020587120205163956, 'Netherlands': 0.016420230269432068}
23
+ """
24
+ image=Image.open(image_path)
25
+ inputs = processor(text=LIST_COUNTRIES, images=image, return_tensors="pt", padding=True).to(device)
26
+ with torch.no_grad():
27
+ outputs = model(**inputs)
28
+ logits_per_image = outputs.logits_per_image
29
+ prediction = logits_per_image.softmax(dim=1)
30
+ confidences = {LIST_COUNTRIES[i]: float(prediction[0][i].item()) for i in range(len(LIST_COUNTRIES))}
31
+ top_10 = dict(sorted(confidences.items(), key=lambda item: item[1], reverse=True)[:10])
32
+ return top_10
33
+
34
+ theme = gr.themes.Citrus()
35
+
36
+ with gr.Blocks(title="Geoloc-mcp", theme=theme, fill_height=True, analytics_enabled=False) as demo:
37
+ gr.HTML("<h1>Geolocation of images</h1>")
38
+ gr.HTML(f"<p>List countries supported:<br><i>{', '.join(LIST_COUNTRIES)}</i></p>")
39
+ with gr.Row():
40
+ with gr.Column():
41
+ gr.HTML("<h2>Input a street view image</h2>")
42
+ input_image = gr.Image(show_label=False, type="filepath", height="400")
43
+ image_country_button = gr.Button("🌎 Geolocate", variant="primary")
44
+ with gr.Column():
45
+ gr.HTML("<h2>Classification score for top-10 countries</h2>")
46
+ result = gr.Label(num_top_classes=10, label="Top 10 countries")
47
+
48
+ image_country_button.click(
49
+ fn=classify_country,
50
+ inputs=[input_image],
51
+ outputs=[result],
52
+ api_name="classify_country",
53
+ concurrency_limit=10
54
+ )
55
+
56
+
57
+ demo.launch(mcp_server=True)
58
+
59
+
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio[mcp]==5.32.1
2
+ torch==2.7.1
3
+ transformers==4.52.4