# geoloc-mcp / app.py
# Author: NemesisAlm
# Commit 19be735: "1st commit MCP server"
import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
# Device used for inference; both the model and its inputs are placed here.
device = "cpu"

# StreetCLIP checkpoint from the Hugging Face Hub: the processor tokenizes the
# country prompts and preprocesses images, the model scores image/text pairs.
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
# Move the model to `device` as well: classify_country sends its inputs to
# `device`, so model and inputs must live on the same device if it changes.
model = CLIPModel.from_pretrained("geolocal/StreetCLIP").to(device)
# Candidate countries, kept in alphabetical order. Each entry is passed verbatim
# to the CLIP processor as a text prompt and is also the label returned to the
# user, so renaming an entry changes what the model is asked to match.
LIST_COUNTRIES = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana", "Brazil", "Bulgaria",
    "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia", "Czech Republic",
    "Denmark", "Dominican Republic",
    "Ecuador", "Estonia",
    "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guam", "Guatemala",
    "Hungary",
    "Iceland", "India", "Indonesia", "Ireland", "Israel", "Italy",
    "Japan", "Jordan",
    "Kenya", "Kyrgyzstan",
    "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco", "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland", "Portugal", "Puerto Rico",
    "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia", "South Africa", "South Korea",
    "Spain", "Sri Lanka", "Swaziland", "Sweden", "Switzerland",
    "Taiwan", "Thailand", "Tunisia", "Turkey",
    "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom", "United States", "Uruguay",
]
def classify_country(image_path: str) -> dict:
    """
    Classify a street view image to geolocate it at country level.

    Args:
        image_path (str): filename of the uploaded image

    Returns:
        dict: A dictionary mapping each of the top-10 countries to its classification score, ordered from highest to lowest score. Ex: {'United Arab Emirates': 0.12621667981147766, 'Denmark': 0.12389234453439713, 'United States': 0.1217150092124939, 'Estonia': 0.09820151329040527, 'Andorra': 0.07836015522480011, 'Belgium': 0.049682170152664185, 'Taiwan': 0.033752717077732086, 'Australia': 0.023750243708491325, 'Guatemala': 0.020587120205163956, 'Netherlands': 0.016420230269432068}
    """
    # gr.Image(type="filepath") yields None when the button is clicked before
    # an image is uploaded; report that clearly instead of failing in PIL.
    if not image_path:
        raise gr.Error("Please upload a street view image first.")

    # The context manager closes the file handle once the processor has
    # consumed the pixels (the original left it open).
    with Image.open(image_path) as image:
        inputs = processor(
            text=LIST_COUNTRIES, images=image, return_tensors="pt", padding=True
        ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image has one row per image and one column per country prompt;
    # softmax over the country axis, then take the row for our single image.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0].tolist()

    # Stable sort by score (descending); keep the 10 most likely countries.
    ranked = sorted(
        zip(LIST_COUNTRIES, probabilities), key=lambda pair: pair[1], reverse=True
    )
    return dict(ranked[:10])
theme = gr.themes.Citrus()

# Single-page UI: an image input and a button on the left, the top-10 country
# scores on the right. The click handler is also exposed under the API name
# "classify_country", which the MCP server (enabled in launch below) can serve.
with gr.Blocks(title="Geoloc-mcp", theme=theme, fill_height=True, analytics_enabled=False) as demo:
    gr.HTML("<h1>Geolocation of images</h1>")
    gr.HTML(f"<p>List countries supported:<br><i>{', '.join(LIST_COUNTRIES)}</i></p>")
    with gr.Row():
        with gr.Column():
            gr.HTML("<h2>Input a street view image</h2>")
            # A number is interpreted as pixels; a string is used as a raw CSS
            # value, and the unitless "400" the original passed is not a valid
            # CSS length.
            input_image = gr.Image(show_label=False, type="filepath", height=400)
            image_country_button = gr.Button("🌎 Geolocate", variant="primary")
        with gr.Column():
            gr.HTML("<h2>Classification score for top-10 countries</h2>")
            result = gr.Label(num_top_classes=10, label="Top 10 countries")
    image_country_button.click(
        fn=classify_country,
        inputs=[input_image],
        outputs=[result],
        api_name="classify_country",
        concurrency_limit=10,
    )

# Only start the server when run as a script (e.g. `python app.py`), so the
# module can be imported without side effects.
if __name__ == "__main__":
    demo.launch(mcp_server=True)