# cell_cluster / app.py
# ankush-003's picture
# Update app.py
# c0a6c3c verified
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from huggingface_hub import snapshot_download
from datasets import load_dataset
from gensim.models import FastText
from s2sphere import CellId, Cell, LatLng
from collections import defaultdict
import folium
from folium import Map
import gradio as gr
from gradio_folium import Folium
from sklearn.cluster import KMeans
def extract_restaurant_embeddings(model, processed_df):
    """
    Build a lookup from each distinct ``res_cell_id`` to its embedding vector.

    Every unique value of ``processed_df['res_cell_id']`` is stringified and
    looked up in ``model.wv``. Ids whose token raises ``KeyError`` are
    reported on stdout and left out of the result.

    Returns a dict keyed by the original (non-stringified) id.
    """
    vectors_by_id = {}
    for cell in processed_df['res_cell_id'].unique():
        # The vocabulary token is the bare cell id, with no prefix.
        key = str(cell)
        try:
            vector = model.wv[key]
        except KeyError:
            print(f"Warning: Restaurant {cell} not found in vocabulary")
        else:
            vectors_by_id[cell] = vector
    return vectors_by_id
def cluster_embeddings(restaurant_embeddings, algo):
    """
    Assign a cluster label to every restaurant embedding.

    Args:
        restaurant_embeddings: Mapping of restaurant id -> embedding vector.
        algo: Object exposing ``fit_predict(matrix)`` that returns one label
            per row of the matrix (for example a scikit-learn ``KMeans``).

    Returns:
        dict mapping each restaurant id to the label produced for its row.
        An empty input mapping yields an empty dict and ``algo`` is not
        invoked.
    """
    if not restaurant_embeddings:
        # With no rows the stacked array has no feature dimension, which
        # clustering estimators reject; there is nothing to label anyway.
        return {}

    restaurant_ids = list(restaurant_embeddings)
    # Row order follows restaurant_ids so labels can be zipped back onto ids.
    embedding_matrix = np.array(
        [restaurant_embeddings[res_id] for res_id in restaurant_ids]
    )
    labels = algo.fit_predict(embedding_matrix)
    return dict(zip(restaurant_ids, labels))
def s2_cell_to_geojson(cell_id_token_or_int):
    """
    Convert one S2 cell into a GeoJSON ``Feature`` holding its outline.

    Args:
        cell_id_token_or_int: Either an S2 token (``str``), parsed with
            ``CellId.from_token``, or an integer-like cell id, passed to the
            ``CellId`` constructor.

    Returns:
        dict: a GeoJSON Feature whose geometry is a closed Polygon through
        the cell's four vertices, with ``cell_id`` (``str(cell_id)``) and
        ``level`` in its properties.
    """
    if isinstance(cell_id_token_or_int, str):
        cell_id = CellId.from_token(cell_id_token_or_int)
    else:
        # Ids read back from a pandas column are numpy scalars rather than
        # built-in ints; coerce so CellId always receives a plain int.
        cell_id = CellId(int(cell_id_token_or_int))
    cell = Cell(cell_id)

    # Walk the four corners; GeoJSON positions are ordered [lng, lat].
    ring = []
    for corner in range(4):
        latlng = LatLng.from_point(cell.get_vertex(corner))
        ring.append([latlng.lng().degrees, latlng.lat().degrees])
    # A GeoJSON linear ring must end on its starting position.
    ring.append(ring[0])

    return {
        "type": "Feature",
        "geometry": {
            "type": "Polygon",
            "coordinates": [ring],
        },
        "properties": {
            "cell_id": str(cell_id),
            "level": cell_id.level(),
        },
    }
def map_cluster_to_restaurants(restaurant_clusters):
    """
    Invert a restaurant -> cluster mapping into cluster -> [restaurants].

    Restaurants keep the order in which they appear in the input. The result
    is a ``defaultdict(list)``, so looking up an unknown cluster id yields an
    empty list.
    """
    members_by_cluster = defaultdict(list)
    for restaurant, label in restaurant_clusters.items():
        bucket = members_by_cluster[label]
        bucket.append(restaurant)
    return members_by_cluster
def get_cluster_jsons(cluster_to_restaurants):
    """
    Produce one GeoJSON FeatureCollection per cluster.

    Each cell id of a cluster is converted with ``s2_cell_to_geojson``. A cell
    whose conversion raises is reported on stdout and skipped, so a single bad
    id does not drop the whole cluster.

    Returns a list of FeatureCollection dicts, in the iteration order of
    ``cluster_to_restaurants``.
    """
    collections = []
    for members in cluster_to_restaurants.values():
        polygons = []
        for member in members:
            try:
                polygons.append(s2_cell_to_geojson(member))
            except Exception as e:
                print(f"Error converting {member}: {e}")
        # Wrap this cluster's polygons as a FeatureCollection.
        collections.append({"type": "FeatureCollection", "features": polygons})
    return collections
def visualise_on_map(jsons):
    """
    Draw each cluster FeatureCollection as its own layer on a folium map.

    The map starts at a fixed centre and zoom level. Every layer is named
    "Cluster <index>" and coloured from its index. A layer that fails to be
    added is reported on stdout and skipped. A layer control is attached so
    layers can be toggled.

    Returns the populated ``Map``.
    """
    cluster_map = Map(location=[12.935656, 77.543204], zoom_start=12)

    for idx, collection in enumerate(jsons):
        try:
            label = f"Cluster {idx}"
            # Deterministic colour derived from the cluster index.
            shade = f"#{idx*123456%0xFFFFFF:06x}"
            layer = folium.GeoJson(
                collection,
                name=label,
                tooltip=label,
                # Bind the colour as a default so each layer keeps its own.
                style_function=lambda feature, color=shade: {
                    "fillColor": color,
                    "color": color,
                    "weight": 1,
                    "fillOpacity": 0.4,
                },
            )
            layer.add_to(cluster_map)
        except Exception as e:
            print(f"Failed to add cluster {idx}: {e}")

    # Let the viewer switch individual cluster layers on and off.
    folium.LayerControl().add_to(cluster_map)
    return cluster_map
# Repository identifier; it is not referenced again in the code shown here.
REPO_ID = "ankush-003/fastCell"
# Load the "ankush-003/Cells_Data" dataset and turn its 'train' split into a pandas DataFrame.
dataset = load_dataset("ankush-003/Cells_Data")
df = dataset['train'].to_pandas()
# Load the trained FastText model from a relative path.
# NOTE(review): REPO_ID and the imported snapshot_download are unused here, so the
# model files are presumably expected to sit in the working directory — confirm.
model = FastText.load(
"cell_embedddings_model"
)
# Compute one embedding per unique 'res_cell_id' once, when the module is loaded.
restaurant_embeddings = extract_restaurant_embeddings(model, df)
# Per-cluster GeoJSON collections from the latest run_clustering call; None until then.
clusters_jsons = None
def run_clustering(num_clusters, clusters_to_display):
    """
    Cluster the restaurant embeddings with KMeans and render the result.

    Args:
        num_clusters: Number of clusters (K) for KMeans.
        clusters_to_display: How many cluster layers to draw; capped at the
            number of clusters actually produced.

    Returns:
        tuple: (Markdown summary string, folium ``Map`` with the layers).

    Side effects:
        Writes the ``cluster`` column of the module-level ``df`` and replaces
        the module-level ``clusters_jsons`` that ``update_display`` reads.
    """
    global clusters_jsons

    # Slider values may arrive as floats; KMeans' n_clusters and list slicing
    # both require real integers.
    num_clusters = int(num_clusters)
    clusters_to_display = int(clusters_to_display)

    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    restaurant_clusters = cluster_embeddings(restaurant_embeddings, kmeans)
    df['cluster'] = df['res_cell_id'].map(restaurant_clusters)

    # Number of df rows that fall into each cluster label.
    cluster_sizes = df['cluster'].value_counts().sort_index()
    avg_size = cluster_sizes.mean()
    min_size = cluster_sizes.min()
    max_size = cluster_sizes.max()

    analysis = f"""
## Clustering Analysis (K={num_clusters})
- Total restaurants: {len(df)}
- Number of clusters: {num_clusters}
- Average restaurants per cluster: {avg_size:.1f}
- Smallest cluster size: {min_size}
- Largest cluster size: {max_size}
- Empty clusters: {num_clusters - len(cluster_sizes)}
"""

    c_to_r = map_cluster_to_restaurants(restaurant_clusters)
    clusters_jsons = get_cluster_jsons(c_to_r)

    # Never ask for more layers than there are clusters.
    shown = min(clusters_to_display, len(clusters_jsons))
    m = visualise_on_map(clusters_jsons[:shown])
    return analysis, m
def update_display(clusters_to_display):
    """
    Redraw the map with a different number of cluster layers.

    Uses the GeoJSON collections stored by the last ``run_clustering`` call
    and does not re-run the clustering.

    Args:
        clusters_to_display: How many cluster layers to draw; capped at the
            number of stored clusters.

    Returns:
        A folium ``Map``: the default empty map when no clustering has been
        run yet, otherwise a map with the requested layers.
    """
    # Read-only access to the module-level cache, so no `global` is needed.
    if clusters_jsons is None:
        return Map(location=[12.935656, 77.543204], zoom_start=12)

    # Slider values may arrive as floats, which cannot be used as a slice
    # bound; also cap the request at what is available.
    visible = min(int(clusters_to_display), len(clusters_jsons))
    return visualise_on_map(clusters_jsons[:visible])
# Assemble the Gradio UI: two sliders, a trigger button, a Markdown report
# and a Folium map, then wire the callbacks.
with gr.Blocks(title="Restaurant Clustering Tool") as app:
    gr.Markdown("# Restaurant Cell Embeddings Clustering Analysis")

    # Controls: K for KMeans and how many of the resulting clusters to draw.
    with gr.Row():
        with gr.Column(scale=1):
            num_clusters_input = gr.Slider(
                label="Total Number of Clusters (K)",
                minimum=2,
                maximum=3460,
                step=1,
                value=300,
            )
            display_clusters_input = gr.Slider(
                label="Number of Clusters to Display",
                minimum=1,
                maximum=3460,
                step=1,
                value=10,
            )

    with gr.Row():
        cluster_btn = gr.Button("Run Clustering")

    # Outputs: textual summary first, map underneath.
    with gr.Row():
        output_text = gr.Markdown()

    with gr.Row():
        output_plot = Folium(
            value=Map(location=[12.935656, 77.543204], zoom_start=12),
            height=1000,
        )

    # The button re-clusters and refreshes both outputs.
    cluster_btn.click(
        fn=run_clustering,
        inputs=[num_clusters_input, display_clusters_input],
        outputs=[output_text, output_plot],
    )
    # Moving the display slider only redraws the map from the stored clusters.
    display_clusters_input.change(
        update_display,
        inputs=[display_clusters_input],
        outputs=[output_plot],
    )

    gr.Markdown("""
## About this app
This app demonstrates K-means clustering on restaurant cell embeddings. The algorithm groups similar restaurants together based on cell embeddings.
### How to use:
1. Adjust the number of clusters using the slider
2. Click "Run Clustering" to see the results
3. Analyze the visualization and metrics
""")

# Start the server only when this file is executed directly.
if __name__ == "__main__":
    app.launch()