File size: 6,252 Bytes
1b1d8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import pandas as pd
import numpy as np
from sklearn.manifold import TSNE
import json
import base64

def generate_tsne_embedding(input_file, output_file):
    """Project stored embeddings to 2-D with t-SNE and write them to JSON.

    Reads the Parquet file at ``input_file``, using its ``embedding`` and
    ``image`` columns, runs a 2-component t-SNE (``random_state=42``) over
    the embeddings, and writes a JSON list to ``output_file``. Each entry
    is ``{'x': float, 'y': float, 'image': str}`` where ``image`` is the
    base64 text of that row's ``image`` value.

    Args:
        input_file: Path of the Parquet file to read.
        output_file: Path of the JSON file to write (overwritten).

    Note:
        ``base64.b64encode`` needs bytes-like input, so the ``image``
        column is assumed to hold raw image bytes -- TODO confirm against
        the code that produces the Parquet file.
    """
    # Load the Parquet file
    df = pd.read_parquet(input_file)

    # Extract embeddings and convert to numpy array
    embeddings = np.array(df['embedding'].tolist())

    # Perform t-SNE
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(embeddings)

    # Pair each projected point with its image *by position*.
    # ``df['image'][i]`` would be a label lookup, which raises KeyError or
    # picks the wrong row whenever the frame's index is not exactly 0..n-1.
    image_blobs = df['image'].tolist()
    output_data = [
        {
            'x': float(x),
            'y': float(y),
            'image': base64.b64encode(blob).decode('utf-8'),
        }
        for (x, y), blob in zip(tsne_results, image_blobs)
    ]

    # Save results to JSON file
    with open(output_file, 'w') as f:
        json.dump(output_data, f)

## ----------------------------
## Dash app
## ----------------------------

import os
import base64
import json
import numpy as np
from dash import dcc, html, Input, Output, no_update, Dash
import numpy as np
from sklearn.cluster import KMeans
from scipy.spatial.distance import cdist
import plotly.graph_objects as go
from PIL import Image
import random
import socket

def find_free_port():
    """Return a random port in 49152-65535 that could be bound just now.

    Candidates are drawn at random from the dynamic/private port range and
    probed by binding a throwaway TCP socket on all interfaces; the first
    candidate that binds successfully is returned. The probe socket is
    always closed before returning, so the port is free (not reserved)
    when the caller receives it.
    """
    while True:
        candidate = random.randint(49152, 65535)  # dynamic/private range
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.bind(('', candidate))
        except OSError:
            # Bind failed for this candidate; draw another one.
            continue
        else:
            return candidate
        finally:
            probe.close()

def create_dash_app(fig, images):
    """Build a Dash app that shows ``fig`` with an image tooltip on hover.

    Args:
        fig: Figure displayed in the app's single graph component.
        images: Sequence of base64 image strings; the hovered point's
            ``pointNumber`` is used as the index into it.

    Returns:
        The configured ``Dash`` application (not yet running).
    """
    app = Dash(__name__)

    graph = dcc.Graph(id="graph", figure=fig, clear_on_unhover=True)
    tooltip = dcc.Tooltip(id="graph-tooltip", direction='bottom')
    app.layout = html.Div(className="container", children=[graph, tooltip])

    thumbnail_style = {
        "width": "200px",
        "height": "200px",
        'display': 'block',
        'margin': '0 auto',
    }

    @app.callback(
        Output("graph-tooltip", "show"),
        Output("graph-tooltip", "bbox"),
        Output("graph-tooltip", "children"),
        Input("graph", "hoverData"),
    )
    def display_hover(hoverData):
        # Nothing hovered: hide the tooltip and leave its contents alone.
        if hoverData is not None:
            point = hoverData["points"][0]
            bbox = point["bbox"]
            encoded = images[point["pointNumber"]]
            picture = html.Img(
                src=f"data:image/jpeg;base64,{encoded}",
                style=thumbnail_style,
            )
            return True, bbox, [html.Div([picture])]

        return False, no_update, no_update

    return app

def perform_kmeans(data, k=20):
    """Fit k-means on the 2-D coordinates of ``data``.

    Args:
        data: Sequence of dicts, each with numeric ``'x'`` and ``'y'`` keys.
        k: Requested number of clusters. It is capped at the number of
            points, because KMeans rejects ``n_clusters > n_samples``; a
            small dataset therefore yields fewer clusters instead of an
            error.

    Returns:
        The fitted ``KMeans`` estimator (``random_state=42``).

    Raises:
        ValueError: If ``data`` is empty.
    """
    # Extract x, y coordinates
    coords = np.array([[point['x'], point['y']] for point in data])
    if len(coords) == 0:
        raise ValueError("perform_kmeans requires at least one data point")

    # Never ask for more clusters than there are samples.
    n_clusters = min(k, len(coords))

    # Perform k-means clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    kmeans.fit(coords)

    return kmeans

def find_nearest_images(data, kmeans):
    """Pick, for every cluster center, the image of the closest data point.

    Args:
        data: Sequence of dicts with ``'x'``, ``'y'`` and ``'image'`` keys.
        kmeans: Object exposing ``cluster_centers_`` as an array of
            (x, y) rows.

    Returns:
        A tuple ``(nearest_images, cluster_centers)`` where
        ``nearest_images[j]`` is the ``'image'`` value of the point with
        the smallest Euclidean distance to center ``j``, and
        ``cluster_centers`` is ``kmeans.cluster_centers_``.
    """
    centers = kmeans.cluster_centers_
    points_xy = np.array([[entry['x'], entry['y']] for entry in data])
    all_images = [entry['image'] for entry in data]

    # Rows are data points, columns are cluster centers.
    dist_matrix = cdist(points_xy, centers, metric='euclidean')

    # Column-wise argmin: row index of the closest point per center.
    closest_rows = np.argmin(dist_matrix, axis=0)

    picked = []
    for row in closest_rows:
        picked.append(all_images[row])

    return picked, centers

def create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title):
    """Build the cluster scatter plot with one thumbnail per cluster center.

    Args:
        data: Sequence of dicts with ``'x'``, ``'y'`` and ``'image'`` keys.
        kmeans_result: Fitted clustering object; its ``labels_`` are used
            to colour the markers.
        nearest_images: Base64 image strings, indexed in the same order as
            ``cluster_centers``.
        cluster_centers: Iterable of ``(x, y)`` pairs in data coordinates.
        title: Figure title.

    Returns:
        A tuple ``(fig, images)``: the Plotly figure and the list of every
        point's ``'image'`` value in ``data`` order.

    Raises:
        ValueError: If ``data`` is empty (no extent can be computed).
    """
    # Extract x, y coordinates
    x = [point['x'] for point in data]
    y = [point['y'] for point in data]
    images = [point['image'] for point in data]

    # Use one square window for both axes: half of the larger extent,
    # centred on the midpoint of the data.
    x_min, x_max = min(x), max(x)
    y_min, y_max = min(y), max(y)
    max_range = max(x_max - x_min, y_max - y_min) / 2
    center_x = (x_max + x_min) / 2
    center_y = (y_max + y_min) / 2

    # Create the scatter plot
    fig = go.Figure()

    # Add data points
    fig.add_trace(go.Scatter(
        x=x,
        y=y,
        mode='markers',
        marker=dict(
            size=5,
            color=kmeans_result.labels_,
            colorscale='Viridis',
            showscale=False,
        ),
        name='Data Points',
    ))

    # Square, equal-scale axes with ticks and axis lines hidden.
    fig.update_layout(
        title=title,
        width=1000, height=1000,
        xaxis=dict(
            range=[center_x - max_range, center_x + max_range],
            scaleanchor="y",
            scaleratio=1,
            visible=False,
        ),
        yaxis=dict(
            range=[center_y - max_range, center_y + max_range],
            visible=False,
        ),
        showlegend=False,
    )

    # Disable the built-in hover label; hovering is handled by the app.
    fig.update_traces(
        hoverinfo="none",
        hovertemplate=None,
    )

    # Add one thumbnail per cluster center.
    for i, (cx, cy) in enumerate(cluster_centers):
        fig.add_layout_image(
            dict(
                # "image/jpeg" is the registered MIME type and matches the
                # data URI used for the hover tooltip.
                source=f"data:image/jpeg;base64,{nearest_images[i]}",
                x=cx,
                y=cy,
                xref="x",
                yref="y",
                # Layout images anchor at their left/top corner by default;
                # centre them so the thumbnail sits on the cluster center.
                xanchor="center",
                yanchor="middle",
                sizex=10,
                sizey=10,
                sizing="contain",
                opacity=1,
                layer="below",
            )
        )

    return fig, images

def make_dash_kmeans(data, title, k=40):
    """Cluster ``data``, build the figure and serve it as a Dash app.

    Runs k-means on the points, picks a representative image per cluster,
    builds the figure and app, then starts the server on a randomly chosen
    free port. Starting the server is done by this call, so ``app`` is
    only returned once the server call itself returns.

    Args:
        data: Sequence of dicts with ``'x'``, ``'y'`` and ``'image'`` keys.
        title: Title shown on the figure.
        k: Number of clusters requested from k-means.

    Returns:
        The Dash application that was served.
    """
    kmeans_result = perform_kmeans(data, k=k)
    nearest_images, cluster_centers = find_nearest_images(data, kmeans_result)
    fig, images = create_dash_fig(data, kmeans_result, nearest_images, cluster_centers, title)
    app = create_dash_app(fig, images)
    port = find_free_port()
    print(f"Serving on http://127.0.0.1:{port}/")
    print(f"To serve this over the Internet, run `ngrok http {port}`")
    # Dash 3 removed ``run_server`` in favour of ``run``; older releases
    # may only provide ``run_server``. Prefer ``run`` and fall back, so the
    # removed attribute is never touched on new versions.
    start_server = getattr(app, "run", None) or app.run_server
    start_server(port=port)
    return app

if __name__ == "__main__":
    # Script entry point: compute the t-SNE projection for the processed
    # dataset next to this script, reload it from disk, and serve the
    # clustered view.
    name = "style"
    dataset_folder = os.path.dirname('./')
    image_embedding_path = os.path.join(dataset_folder, "processed_dataset.parquet")
    tsne_path = os.path.join(dataset_folder, "processed_dataset.json")

    # Writes the JSON file that is read back just below.
    generate_tsne_embedding(image_embedding_path, tsne_path)

    with open(tsne_path, "r") as f:
        data = json.load(f)

    make_dash_kmeans(data, name, k=40)