# Page header captured with the source -- Spaces: Sleeping
import numpy as np # linear algebra | |
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) | |
from huggingface_hub import snapshot_download | |
from datasets import load_dataset | |
from gensim.models import FastText | |
from s2sphere import CellId, Cell, LatLng | |
from collections import defaultdict | |
import folium | |
from folium import Map | |
import gradio as gr | |
from gradio_folium import Folium | |
from sklearn.cluster import KMeans | |
def extract_restaurant_embeddings(model, processed_df):
    """
    Look up one embedding vector per distinct restaurant cell.

    Each distinct value of ``processed_df['res_cell_id']`` is converted to a
    string (no prefix is added) and used as the lookup token in ``model.wv``.
    Cells whose lookup raises ``KeyError`` are reported on stdout and left
    out of the result.

    Returns a dict keyed by the original (unconverted) cell ID.
    """
    embeddings_by_cell = {}
    for cell_id in processed_df['res_cell_id'].unique():
        try:
            embeddings_by_cell[cell_id] = model.wv[str(cell_id)]
        except KeyError:
            print(f"Warning: Restaurant {cell_id} not found in vocabulary")
    return embeddings_by_cell
def cluster_embeddings(restaurant_embeddings, algo):
    """
    Assign a cluster label to every restaurant embedding.

    The vectors are stacked, in the dict's iteration order, into a single
    NumPy array and handed to ``algo.fit_predict``. The labels it returns
    are paired positionally with the restaurant IDs.

    Returns a dict mapping restaurant ID -> cluster label.
    """
    ids = []
    vectors = []
    for res_id, vector in restaurant_embeddings.items():
        ids.append(res_id)
        vectors.append(vector)

    labels = algo.fit_predict(np.array(vectors))
    return dict(zip(ids, labels))
def s2_cell_to_geojson(cell_id_token_or_int):
    """
    Build a GeoJSON Feature outlining a single S2 cell.

    Args:
        cell_id_token_or_int: The cell to draw. A ``str`` is parsed with
            ``CellId.from_token``; anything else is treated as an
            integer-like cell ID (a Python ``int`` or a NumPy integer
            scalar such as those produced by pandas ``.unique()``).

    Returns:
        dict: A GeoJSON ``Feature`` whose ``Polygon`` ring holds the four
        cell vertices as ``[lng, lat]`` pairs in degrees, followed by the
        first vertex again to close the ring. ``properties`` carries
        ``"cell_id"`` (``str()`` of the ``CellId``) and ``"level"``.

    Raises:
        TypeError, ValueError: If a non-string argument cannot be converted
            with ``int()``.
    """
    if isinstance(cell_id_token_or_int, str):
        cell_id = CellId.from_token(cell_id_token_or_int)
    else:
        # NumPy integer scalars are not ``int`` instances; normalise so that
        # CellId always receives a plain Python int rather than a
        # fixed-width NumPy value.
        cell_id = CellId(int(cell_id_token_or_int))
    cell = Cell(cell_id)

    # Collect the four cell corners; GeoJSON orders coordinates [lng, lat].
    ring = []
    for i in range(4):
        corner = LatLng.from_point(cell.get_vertex(i))
        ring.append([corner.lng().degrees, corner.lat().degrees])
    ring.append(ring[0])  # Close the polygon

    return {
        "type": "Feature",
        "geometry": {
            "type": "Polygon",
            "coordinates": [ring],
        },
        "properties": {
            "cell_id": str(cell_id),
            "level": cell_id.level(),
        },
    }
def map_cluster_to_restaurants(restaurant_clusters):
    """
    Invert a restaurant -> cluster mapping into cluster -> [restaurants].

    Restaurants appear in each list in the iteration order of
    ``restaurant_clusters``. The result is a ``defaultdict(list)``.
    """
    members_by_cluster = defaultdict(list)
    for restaurant, label in restaurant_clusters.items():
        members_by_cluster[label].append(restaurant)
    return members_by_cluster
def get_cluster_jsons(cluster_to_restaurants):
    """
    Turn each cluster's cell IDs into a GeoJSON FeatureCollection.

    One FeatureCollection dict is produced per cluster, in the mapping's
    iteration order. A cell whose conversion via ``s2_cell_to_geojson``
    raises is reported on stdout and skipped, so one bad cell does not
    discard the rest of its cluster.

    Returns a list of FeatureCollection dicts.
    """
    def _features_for(cell_ids):
        # Best-effort conversion: keep every cell that converts cleanly.
        converted = []
        for cell_id in cell_ids:
            try:
                converted.append(s2_cell_to_geojson(cell_id))
            except Exception as e:
                print(f"Error converting {cell_id}: {e}")
        return converted

    return [
        {"type": "FeatureCollection", "features": _features_for(cell_ids)}
        for cell_ids in cluster_to_restaurants.values()
    ]
def visualise_on_map(jsons):
    """
    Draw a list of cluster GeoJSON objects on a new folium map.

    The map is centred on the fixed location [12.935656, 77.543204] at zoom
    level 12. Entry ``i`` of ``jsons`` becomes a layer named and tooltipped
    "Cluster i", coloured with a hex colour derived from ``i``. A layer that
    fails to build or attach is reported on stdout and skipped. A layer
    control is added last so individual clusters can be toggled.

    Returns the folium ``Map``.
    """
    cluster_map = Map(location=[12.935656, 77.543204], zoom_start=12)

    def _style_with(hex_color):
        # Bind the colour now so every layer keeps its own value.
        def _style(feature):
            return {
                "fillColor": hex_color,
                "color": hex_color,
                "weight": 1,
                "fillOpacity": 0.4,
            }
        return _style

    for idx, collection in enumerate(jsons):
        try:
            label = f"Cluster {idx}"
            layer = folium.GeoJson(
                collection,
                name=label,
                tooltip=label,
                style_function=_style_with(f"#{idx*123456%0xFFFFFF:06x}"),
            )
            layer.add_to(cluster_map)
        except Exception as e:
            print(f"Failed to add cluster {idx}: {e}")

    # Optional: lets the user toggle individual cluster layers.
    folium.LayerControl().add_to(cluster_map)
    return cluster_map
# Hugging Face repo identifier (not referenced by the code visible in this file).
REPO_ID = "ankush-003/fastCell"

# Load the cells dataset and convert its "train" split to a DataFrame.
dataset = load_dataset("ankush-003/Cells_Data")
df = dataset["train"].to_pandas()

# Load the FastText model from the path below.
model = FastText.load("cell_embedddings_model")

# One embedding per distinct restaurant cell, computed once at start-up.
restaurant_embeddings = extract_restaurant_embeddings(model, df)

# GeoJSON collections from the most recent clustering run; None until the
# first run completes.
clusters_jsons = None
def run_clustering(num_clusters, clusters_to_display):
    """
    Cluster the restaurant embeddings with K-means and render the result.

    Side effects: writes the ``'cluster'`` column of the module-level ``df``
    and replaces the module-level ``clusters_jsons`` with the GeoJSON
    collections of this run.

    Args:
        num_clusters: Number of K-means clusters (K). Coerced to ``int``.
        clusters_to_display: How many cluster layers to draw on the map.
            Coerced to ``int``; values larger than the number of clusters
            simply show every cluster.

    Returns:
        tuple: ``(analysis, m)`` where ``analysis`` is a Markdown summary
        string and ``m`` is the folium map from ``visualise_on_map``.
    """
    global clusters_jsons

    # Slider values are documented as floats, but KMeans needs an integral
    # n_clusters and list slicing needs an integer bound.
    num_clusters = int(num_clusters)
    clusters_to_display = int(clusters_to_display)

    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    restaurant_clusters = cluster_embeddings(restaurant_embeddings, kmeans)
    df['cluster'] = df['res_cell_id'].map(restaurant_clusters)

    # Rows per cluster label. Rows whose cell has no embedding map to NaN,
    # which value_counts() leaves out.
    # NOTE(review): these figures count dataframe rows; they equal
    # restaurant counts only if df has one row per restaurant -- confirm.
    cluster_sizes = df['cluster'].value_counts().sort_index()
    avg_size = cluster_sizes.mean()
    min_size = cluster_sizes.min()
    max_size = cluster_sizes.max()

    analysis = (
        "\n"
        f"## Clustering Analysis (K={num_clusters})\n"
        f"- Total restaurants: {len(df)}\n"
        f"- Number of clusters: {num_clusters}\n"
        f"- Average restaurants per cluster: {avg_size:.1f}\n"
        f"- Smallest cluster size: {min_size}\n"
        f"- Largest cluster size: {max_size}\n"
        f"- Empty clusters: {num_clusters - len(cluster_sizes)}\n"
    )

    c_to_r = map_cluster_to_restaurants(restaurant_clusters)
    clusters_jsons = get_cluster_jsons(c_to_r)

    # Slicing already stops at the end of the list, so no explicit clamp is
    # needed when more clusters are requested than exist.
    m = visualise_on_map(clusters_jsons[:clusters_to_display])
    return analysis, m
def update_display(clusters_to_display):
    """
    Redraw the map with the first N clusters of the latest clustering run.

    Args:
        clusters_to_display: How many cluster layers to draw. Coerced to
            ``int``; values larger than the number of stored clusters show
            every cluster.

    Returns:
        A folium ``Map``: an empty default map when no clustering has been
        run yet (``clusters_jsons`` is ``None``), otherwise the map built by
        ``visualise_on_map``.
    """
    if clusters_jsons is None:
        return Map(location=[12.935656, 77.543204], zoom_start=12)

    # Slider values are documented as floats, and a float slice bound raises
    # TypeError. Slicing itself stops at the end of the list, so no explicit
    # clamp is needed.
    shown = clusters_jsons[:int(clusters_to_display)]
    return visualise_on_map(shown)
# Build the Gradio UI: two sliders, a run button, a Markdown summary and the map.
with gr.Blocks(title="Restaurant Clustering Tool") as app:
    gr.Markdown("# Restaurant Cell Embeddings Clustering Analysis")

    with gr.Row():
        with gr.Column(scale=1):
            num_clusters_input = gr.Slider(
                label="Total Number of Clusters (K)",
                minimum=2, maximum=3460, step=1, value=300,
            )
            display_clusters_input = gr.Slider(
                label="Number of Clusters to Display",
                minimum=1, maximum=3460, step=1, value=10,
            )

    with gr.Row():
        cluster_btn = gr.Button("Run Clustering")

    with gr.Row():
        output_text = gr.Markdown()

    with gr.Row():
        output_plot = Folium(
            value=Map(location=[12.935656, 77.543204], zoom_start=12),
            height=1000,
        )

    # The button recomputes the clustering; it refreshes both summary and map.
    cluster_btn.click(
        fn=run_clustering,
        inputs=[num_clusters_input, display_clusters_input],
        outputs=[output_text, output_plot],
    )
    # Moving the display slider only redraws the map from the stored result.
    display_clusters_input.change(
        update_display,
        inputs=[display_clusters_input],
        outputs=[output_plot],
    )

    gr.Markdown(
        "\n"
        "## About this app\n"
        "This app demonstrates K-means clustering on restaurant cell embeddings. "
        "The algorithm groups similar restaurants together based on cell embeddings.\n"
        "### How to use:\n"
        "1. Adjust the number of clusters using the slider\n"
        '2. Click "Run Clustering" to see the results\n'
        "3. Analyze the visualization and metrics\n"
    )

if __name__ == "__main__":
    app.launch()