"""Gradio app: K-means clustering of restaurant S2-cell embeddings, drawn on a Folium map."""

from collections import defaultdict

import folium
import gradio as gr
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
from datasets import load_dataset
from folium import Map
from gensim.models import FastText
from gradio_folium import Folium
from huggingface_hub import snapshot_download
from s2sphere import CellId, Cell, LatLng
from sklearn.cluster import KMeans

# Initial view used for every map this app creates.
DEFAULT_MAP_CENTER = (12.935656, 77.543204)
DEFAULT_MAP_ZOOM = 12

# Upper bound of both sliders.
# NOTE(review): presumably the number of restaurants in the dataset - confirm.
MAX_CLUSTERS = 3460

ABOUT_MARKDOWN = """
## About this app
This app demonstrates K-means clustering on restaurant cell embeddings.
The algorithm groups similar restaurants together based on cell embeddings.

### How to use:
1. Adjust the number of clusters using the slider
2. Click "Run Clustering" to see the results
3. Analyze the visualization and metrics
"""


def _blank_map():
    """Return a fresh map at the default center and zoom with no layers added."""
    return Map(location=list(DEFAULT_MAP_CENTER), zoom_start=DEFAULT_MAP_ZOOM)


def extract_restaurant_embeddings(model, processed_df):
    """
    Extract the embeddings for all restaurants.

    Args:
        model: Object exposing ``model.wv[token]`` lookups (a FastText model
            in this file).
        processed_df: DataFrame with a ``res_cell_id`` column.

    Returns:
        Dict mapping each unique ``res_cell_id`` value to its embedding.
        Ids whose lookup raises ``KeyError`` are reported on stdout and
        left out of the result.
    """
    unique_restaurants = processed_df['res_cell_id'].unique()

    restaurant_embeddings = {}
    for restaurant_id in unique_restaurants:
        token = str(restaurant_id)  # No prefix, just the cell ID
        try:
            restaurant_embeddings[restaurant_id] = model.wv[token]
        except KeyError:
            print(f"Warning: Restaurant {restaurant_id} not found in vocabulary")

    return restaurant_embeddings


def cluster_embeddings(restaurant_embeddings, algo):
    """
    Cluster the embeddings with ``algo`` and label each restaurant.

    Args:
        restaurant_embeddings: Dict of restaurant id -> embedding vector.
        algo: Clustering object exposing ``fit_predict(matrix)``.

    Returns:
        Dict mapping restaurant id -> label returned by ``algo``.
    """
    restaurant_ids = list(restaurant_embeddings.keys())
    embedding_matrix = np.array(
        [restaurant_embeddings[res_id] for res_id in restaurant_ids]
    )
    labels = algo.fit_predict(embedding_matrix)
    return dict(zip(restaurant_ids, labels))


def s2_cell_to_geojson(cell_id_token_or_int):
    """
    Build a GeoJSON ``Feature`` polygon for one S2 cell.

    Args:
        cell_id_token_or_int: A string (parsed with ``CellId.from_token``)
            or an integer-like cell id.

    Returns:
        Dict with a closed 5-point ``Polygon`` ring in ``[lng, lat]`` order
        and ``cell_id`` / ``level`` properties.
    """
    # Convert to CellId
    if isinstance(cell_id_token_or_int, str):
        cell_id = CellId.from_token(str(cell_id_token_or_int))
    else:
        # Ids coming from a pandas ``.unique()`` call are NumPy scalars;
        # convert to a plain Python int so CellId's own integer arithmetic
        # is not subject to fixed-width NumPy overflow or float promotion.
        cell_id = CellId(int(cell_id_token_or_int))

    cell = Cell(cell_id)

    # Get cell corner coordinates
    coords = []
    for i in range(4):
        latlng = LatLng.from_point(cell.get_vertex(i))
        # GeoJSON uses [lng, lat]
        coords.append([latlng.lng().degrees, latlng.lat().degrees])
    coords.append(coords[0])  # Close the polygon

    # Build GeoJSON
    return {
        "type": "Feature",
        "geometry": {
            "type": "Polygon",
            "coordinates": [coords],
        },
        "properties": {
            "cell_id": str(cell_id),
            "level": cell_id.level(),
        },
    }


def map_cluster_to_restaurants(restaurant_clusters):
    """
    Reverse a restaurant -> cluster mapping.

    Args:
        restaurant_clusters: Dict of restaurant id -> cluster id.

    Returns:
        ``defaultdict(list)`` of cluster id -> list of restaurant ids.
    """
    # Reverse mapping: cluster_id -> list of restaurant_ids
    cluster_to_restaurants = defaultdict(list)
    for res_id, cluster_id in restaurant_clusters.items():
        cluster_to_restaurants[cluster_id].append(res_id)
    return cluster_to_restaurants


def get_cluster_jsons(cluster_to_restaurants):
    """
    Build one GeoJSON ``FeatureCollection`` per cluster.

    Args:
        cluster_to_restaurants: Dict of cluster id -> list of cell ids.

    Returns:
        List of FeatureCollection dicts, in the iteration order of the
        input dict. Cells that fail to convert are reported on stdout and
        skipped, so one bad id does not drop the whole cluster.
    """
    clusters_jsons = []
    for res_ids in cluster_to_restaurants.values():
        features = []
        for cell_id in res_ids:
            try:
                features.append(s2_cell_to_geojson(cell_id))
            except Exception as e:  # best effort: keep the remaining cells
                print(f"Error converting {cell_id}: {e}")

        # Build GeoJSON FeatureCollection
        clusters_jsons.append({
            "type": "FeatureCollection",
            "features": features,
        })
    return clusters_jsons


def visualise_on_map(jsons):
    """
    Draw each GeoJSON collection as its own colored, toggleable layer.

    Args:
        jsons: Sequence of GeoJSON dicts; position ``i`` is labeled
            ``"Cluster {i}"``.

    Returns:
        A folium ``Map`` with one layer per collection plus a layer control.
    """
    # Create map (you can center it later using a known location or one of the features)
    m = _blank_map()

    # Loop through all cluster GeoJSONs and add them to the map
    for i, geojson in enumerate(jsons):
        try:
            folium.GeoJson(
                geojson,
                name=f"Cluster {i}",
                tooltip=f"Cluster {i}",
                # The color is bound as a default argument so each layer keeps
                # its own value instead of the loop's final one.
                style_function=lambda feature, color=f"#{i*123456%0xFFFFFF:06x}": {
                    "fillColor": color,
                    "color": color,
                    "weight": 1,
                    "fillOpacity": 0.4,
                },
            ).add_to(m)
        except Exception as e:  # best effort: skip layers folium rejects
            print(f"Failed to add cluster {i}: {e}")

    # Optional: Add a layer control to toggle clusters
    folium.LayerControl().add_to(m)
    return m


# NOTE(review): REPO_ID (and the snapshot_download import) are not used in the
# code visible here; the model below is loaded from a local path.
REPO_ID = "ankush-003/fastCell"

dataset = load_dataset("ankush-003/Cells_Data")
df = dataset['train'].to_pandas()

model = FastText.load(
    "cell_embedddings_model"
)

restaurant_embeddings = extract_restaurant_embeddings(model, df)

# GeoJSON collections from the most recent clustering run (None until one runs).
clusters_jsons = None


def run_clustering(num_clusters, clusters_to_display):
    """
    Run K-means on the restaurant embeddings and render the result.

    Side effects: writes a ``cluster`` column into the module-level ``df``
    and replaces the module-level ``clusters_jsons`` cache.

    Args:
        num_clusters: Requested K. Cast to int and clamped to the number of
            available embeddings, since K-means cannot fit more clusters
            than samples.
        clusters_to_display: How many clusters to draw, in cluster-id order.

    Returns:
        Tuple ``(analysis_markdown, folium_map)``.
    """
    global clusters_jsons

    if not restaurant_embeddings:
        # Nothing to fit; report it instead of letting KMeans raise.
        return (
            "## Clustering Analysis\n\n"
            "No restaurant embeddings are available to cluster.",
            _blank_map(),
        )

    # Slider values may arrive as floats; KMeans and slicing both need ints.
    k = max(1, min(int(num_clusters), len(restaurant_embeddings)))
    shown = max(0, int(clusters_to_display))

    kmeans = KMeans(n_clusters=k, random_state=42)
    restaurant_clusters = cluster_embeddings(restaurant_embeddings, kmeans)
    df['cluster'] = df['res_cell_id'].map(restaurant_clusters)

    # Count rows of df per cluster
    cluster_sizes = df['cluster'].value_counts().sort_index()
    avg_size = cluster_sizes.mean()
    min_size = cluster_sizes.min()
    max_size = cluster_sizes.max()

    analysis = "\n".join([
        "",
        f"## Clustering Analysis (K={k})",
        f"- Total restaurants: {len(df)}",
        f"- Number of clusters: {k}",
        f"- Average restaurants per cluster: {avg_size:.1f}",
        f"- Smallest cluster size: {min_size}",
        f"- Largest cluster size: {max_size}",
        f"- Empty clusters: {k - len(cluster_sizes)}",
        "",
    ])

    c_to_r = map_cluster_to_restaurants(restaurant_clusters)
    # Order by cluster id so the map's "Cluster {i}" label lines up with the
    # K-means label rather than with dict insertion order.
    ordered = dict(sorted(c_to_r.items(), key=lambda item: item[0]))
    clusters_jsons = get_cluster_jsons(ordered)

    # Show map (slicing already caps the count at the number of clusters)
    m = visualise_on_map(clusters_jsons[:shown])
    return analysis, m


def update_display(clusters_to_display):
    """
    Redraw the map with the first ``clusters_to_display`` cached clusters.

    Args:
        clusters_to_display: Number of clusters to draw; cast to int.

    Returns:
        A folium ``Map``; a blank default map if no clustering has run yet.
    """
    if clusters_jsons is None:
        return _blank_map()

    # Slicing never goes past the end, so no explicit upper clamp is needed.
    shown = max(0, int(clusters_to_display))
    return visualise_on_map(clusters_jsons[:shown])


# Create Gradio interface
with gr.Blocks(title="Restaurant Clustering Tool") as app:
    gr.Markdown("# Restaurant Cell Embeddings Clustering Analysis")

    with gr.Row():
        with gr.Column(scale=1):
            num_clusters_input = gr.Slider(
                minimum=2,
                maximum=MAX_CLUSTERS,
                value=300,
                step=1,
                label="Total Number of Clusters (K)"
            )
            display_clusters_input = gr.Slider(
                minimum=1,
                maximum=MAX_CLUSTERS,
                value=10,
                step=1,
                label="Number of Clusters to Display"
            )

    with gr.Row():
        cluster_btn = gr.Button("Run Clustering")

    with gr.Row():
        output_text = gr.Markdown()

    with gr.Row():
        output_plot = Folium(value=_blank_map(), height=1000)

    cluster_btn.click(
        fn=run_clustering,
        inputs=[num_clusters_input, display_clusters_input],
        outputs=[output_text, output_plot]
    )

    display_clusters_input.change(
        update_display,
        inputs=[display_clusters_input],
        outputs=[output_plot]
    )

    gr.Markdown(ABOUT_MARKDOWN)

if __name__ == "__main__":
    app.launch()