Spaces:
Running
Running
import gradio as gr | |
import torch | |
from transformers import CLIPProcessor, CLIPModel | |
from PIL import Image | |
# Device used for inference. Both the model weights and every input batch are
# placed here, so changing this one constant is enough to move the app.
device = "cpu"

# The geolocal/StreetCLIP checkpoint is used here for country-level geolocation
# of street-view images. Both objects are loaded once at import time and shared
# by every request.
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
# Move the weights to `device` so they match the inputs, which classify_country
# sends there via `.to(device)`. This is a no-op on CPU, but without it a
# non-CPU device would mismatch the inputs.
model = CLIPModel.from_pretrained("geolocal/StreetCLIP").to(device)
# Candidate labels for the classifier. Each name is passed verbatim to the CLIP
# processor as a text prompt, and the list is also shown in the UI.
# Kept in alphabetical order.
LIST_COUNTRIES = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana",
    "Brazil", "Bulgaria",
    "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia",
    "Czech Republic",
    "Denmark", "Dominican Republic",
    "Ecuador", "Estonia",
    "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guam", "Guatemala",
    "Hungary",
    "Iceland", "India", "Indonesia", "Ireland", "Israel", "Italy",
    "Japan", "Jordan",
    "Kenya", "Kyrgyzstan",
    "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco",
    "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland", "Portugal",
    "Puerto Rico",
    "Romania", "Russia", "Rwanda",
    "Senegal", "Serbia", "Singapore", "Slovakia", "Slovenia",
    "South Africa", "South Korea", "Spain", "Sri Lanka", "Swaziland",
    "Sweden", "Switzerland",
    "Taiwan", "Thailand", "Tunisia", "Turkey",
    "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom",
    "United States", "Uruguay",
]
def classify_country(image_path: str) -> dict:
    """
    Classify a street-view image to geolocate it at country level.

    Args:
        image_path (str): filename of the uploaded image

    Returns:
        dict: A dictionary containing scores for the top-10 countries. Ex: {'United Arab Emirates': 0.12621667981147766, 'Denmark': 0.12389234453439713, 'United States': 0.1217150092124939, 'Estonia': 0.09820151329040527, 'Andorra': 0.07836015522480011, 'Belgium': 0.049682170152664185, 'Taiwan': 0.033752717077732086, 'Australia': 0.023750243708491325, 'Guatemala': 0.020587120205163956, 'Netherlands': 0.016420230269432068}
    """
    # A filepath-type Gradio image input yields None when the button is clicked
    # with nothing uploaded. Report that to the user instead of letting PIL
    # fail with an opaque error.
    if not image_path:
        raise gr.Error("Please upload a street view image first.")

    # Open inside a context manager so the file handle is released once the
    # processor has read the pixels (previously it was never closed).
    with Image.open(image_path) as image:
        inputs = processor(
            text=LIST_COUNTRIES,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # logits_per_image has one row per image (a single image here) and one
    # column per candidate country. Softmax across the countries and take row 0.
    probabilities = outputs.logits_per_image.softmax(dim=1)[0].tolist()

    # Stable descending sort by score, then keep the 10 best.
    ranked = sorted(
        zip(LIST_COUNTRIES, probabilities),
        key=lambda item: item[1],
        reverse=True,
    )
    return dict(ranked[:10])
theme = gr.themes.Citrus()

with gr.Blocks(title="Geoloc-mcp", theme=theme, fill_height=True, analytics_enabled=False) as demo:
    gr.HTML("<h1>Geolocation of images</h1>")
    # List every label the classifier can return, so users know which
    # countries are supported.
    gr.HTML(f"<p>List countries supported:<br><i>{', '.join(LIST_COUNTRIES)}</i></p>")

    with gr.Row():
        # Left column: image upload and the button that triggers classification.
        with gr.Column():
            gr.HTML("<h2>Input a street view image</h2>")
            # type="filepath" hands classify_country a path on disk.
            # height must be an int to mean pixels: a string is treated as a CSS
            # value, and a unitless "400" is not valid CSS.
            input_image = gr.Image(show_label=False, type="filepath", height=400)
            image_country_button = gr.Button("🌍 Geolocate", variant="primary")
        # Right column: scores for the top-10 countries.
        with gr.Column():
            gr.HTML("<h2>Classification score for top-10 countries</h2>")
            result = gr.Label(num_top_classes=10, label="Top 10 countries")

    # api_name is the public endpoint name used by the API (and, with
    # mcp_server=True, by the MCP server). concurrency_limit caps how many
    # runs of this event execute at once.
    image_country_button.click(
        fn=classify_country,
        inputs=[input_image],
        outputs=[result],
        api_name="classify_country",
        concurrency_limit=10,
    )

if __name__ == "__main__":
    demo.launch(mcp_server=True)