"""Gradio app (also served as an MCP tool) for country-level image geolocation.

A StreetCLIP checkpoint scores an uploaded street view image against a fixed
list of country names, and the highest-scoring countries are shown in a
``gr.Label`` component.
"""

import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cpu"

# Hugging Face checkpoint used for both the processor and the model.
MODEL_ID = "geolocal/StreetCLIP"
# Number of countries returned by the classifier and displayed in the UI.
TOP_K = 10

processor = CLIPProcessor.from_pretrained(MODEL_ID)
# Keep the model on the same device the inputs are moved to in
# classify_country, so changing `device` above is enough to switch hardware.
model = CLIPModel.from_pretrained(MODEL_ID).to(device)

LIST_COUNTRIES = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria", "Bangladesh",
    "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana", "Brazil",
    "Bulgaria", "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia",
    "Czech Republic", "Denmark", "Dominican Republic", "Ecuador", "Estonia",
    "Finland", "France", "Germany", "Ghana", "Greece", "Greenland", "Guam",
    "Guatemala", "Hungary", "Iceland", "India", "Indonesia", "Ireland",
    "Israel", "Italy", "Japan", "Jordan", "Kenya", "Kyrgyzstan", "Laos",
    "Latvia", "Lesotho", "Lithuania", "Luxembourg", "Macedonia", "Madagascar",
    "Malaysia", "Malta", "Mexico", "Monaco", "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "Norway", "Pakistan",
    "Palestine", "Peru", "Philippines", "Poland", "Portugal", "Puerto Rico",
    "Romania", "Russia", "Rwanda", "Senegal", "Serbia", "Singapore",
    "Slovakia", "Slovenia", "South Africa", "South Korea", "Spain",
    "Sri Lanka", "Swaziland", "Sweden", "Switzerland", "Taiwan", "Thailand",
    "Tunisia", "Turkey", "Uganda", "Ukraine", "United Arab Emirates",
    "United Kingdom", "United States", "Uruguay",
]


def classify_country(image_path: str) -> dict:
    """
    Classify a street view image to geolocate it at country level.

    Args:
        image_path (str): filename of the uploaded image

    Returns:
        dict: A dictionary containing scores for the top-10 countries,
            ordered from the highest score to the lowest.
            Ex: {'United Arab Emirates': 0.12621667981147766,
            'Denmark': 0.12389234453439713,
            'United States': 0.1217150092124939,
            'Estonia': 0.09820151329040527,
            'Andorra': 0.07836015522480011,
            'Belgium': 0.049682170152664185,
            'Taiwan': 0.033752717077732086,
            'Australia': 0.023750243708491325,
            'Guatemala': 0.020587120205163956,
            'Netherlands': 0.016420230269432068}

    Raises:
        gr.Error: If no image was provided.
    """
    # The button can be clicked before anything is uploaded, in which case
    # the image component yields no path; report that instead of crashing
    # inside PIL.
    if not image_path:
        raise gr.Error("Please upload an image before clicking Geolocate.")

    # Pillow reads pixel data lazily, so preprocessing must happen while the
    # file is still open; the context manager then releases the file handle.
    with Image.open(image_path) as image:
        inputs = processor(
            text=LIST_COUNTRIES,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # One image was passed in, so row 0 holds its score for every country
    # prompt; softmax turns those logits into a distribution over the list.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0].tolist()
    confidences = dict(zip(LIST_COUNTRIES, probabilities))

    ranked = sorted(confidences.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked[:TOP_K])


theme = gr.themes.Citrus()

with gr.Blocks(
    title="Geoloc-mcp",
    theme=theme,
    fill_height=True,
    analytics_enabled=False,
) as demo:
    # NOTE(review): the heading strings below contain no HTML tags in the
    # source as reviewed -- restore the intended markup if it was lost.
    gr.HTML("\n\nGeolocation of images\n\n")
    gr.HTML(f"\n\nList countries supported:\n{', '.join(LIST_COUNTRIES)}\n\n")

    with gr.Row():
        with gr.Column():
            gr.HTML("\n\nInput a street view image\n\n")
            input_image = gr.Image(show_label=False, type="filepath", height="400")
            image_country_button = gr.Button("🌎 Geolocate", variant="primary")
        with gr.Column():
            gr.HTML("\n\nClassification score for top-10 countries\n\n")
            result = gr.Label(num_top_classes=TOP_K, label="Top 10 countries")

    image_country_button.click(
        fn=classify_country,
        inputs=[input_image],
        outputs=[result],
        api_name="classify_country",
        concurrency_limit=10,
    )

demo.launch(mcp_server=True)