Avijit Ghosh commited on
Commit
961c6fe
·
1 Parent(s): 27c66d1

add cached data and preprocessing code

Browse files
Files changed (5) hide show
  1. .DS_Store +0 -0
  2. .gitattributes +1 -0
  3. app.py +488 -497
  4. models.csv → models_processed.parquet +2 -2
  5. preprocess.py +371 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  org_to_artifacts_2l_stats.json filter=lfs diff=lfs merge=lfs -text
37
  models.csv filter=lfs diff=lfs merge=lfs -text
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  org_to_artifacts_2l_stats.json filter=lfs diff=lfs merge=lfs -text
37
  models.csv filter=lfs diff=lfs merge=lfs -text
38
+ models_processed.parquet filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,550 +1,541 @@
 
 
1
  import json
2
  import gradio as gr
3
  import pandas as pd
4
  import plotly.express as px
5
  import os
6
- import numpy as np
7
- import io
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # Define pipeline tags
10
  PIPELINE_TAGS = [
11
- 'text-generation',
12
- 'text-to-image',
13
- 'text-classification',
14
- 'text2text-generation',
15
- 'audio-to-audio',
16
- 'feature-extraction',
17
- 'image-classification',
18
- 'translation',
19
- 'reinforcement-learning',
20
- 'fill-mask',
21
- 'text-to-speech',
22
- 'automatic-speech-recognition',
23
- 'image-text-to-text',
24
- 'token-classification',
25
- 'sentence-similarity',
26
- 'question-answering',
27
- 'image-feature-extraction',
28
- 'summarization',
29
- 'zero-shot-image-classification',
30
- 'object-detection',
31
- 'image-segmentation',
32
- 'image-to-image',
33
- 'image-to-text',
34
- 'audio-classification',
35
- 'visual-question-answering',
36
- 'text-to-video',
37
- 'zero-shot-classification',
38
- 'depth-estimation',
39
- 'text-ranking',
40
- 'image-to-video',
41
- 'multiple-choice',
42
- 'unconditional-image-generation',
43
- 'video-classification',
44
- 'text-to-audio',
45
- 'time-series-forecasting',
46
- 'any-to-any',
47
- 'video-text-to-text',
48
  'table-question-answering',
49
  ]
50
 
51
- # Model size categories in GB
52
- MODEL_SIZE_RANGES = {
53
- "Small (<1GB)": (0, 1),
54
- "Medium (1-5GB)": (1, 5),
55
- "Large (5-20GB)": (5, 20),
56
- "X-Large (20-50GB)": (20, 50),
57
- "XX-Large (>50GB)": (50, float('inf'))
58
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Filter functions for tags
61
- def is_audio_speech(row):
62
- tags = row.get("tags", [])
63
- pipeline_tag = row.get("pipeline_tag", "")
64
 
65
- return (pipeline_tag and ("audio" in pipeline_tag.lower() or "speech" in pipeline_tag.lower())) or \
66
- any("audio" in tag.lower() for tag in tags) or \
67
- any("speech" in tag.lower() for tag in tags)
68
-
69
- def is_music(row):
70
- tags = row.get("tags", [])
71
- return any("music" in tag.lower() for tag in tags)
72
-
73
- def is_robotics(row):
74
- tags = row.get("tags", [])
75
- return any("robot" in tag.lower() for tag in tags)
76
-
77
- def is_biomed(row):
78
- tags = row.get("tags", [])
79
- return any("bio" in tag.lower() for tag in tags) or \
80
- any("medic" in tag.lower() for tag in tags)
81
-
82
- def is_timeseries(row):
83
- tags = row.get("tags", [])
84
- return any("series" in tag.lower() for tag in tags)
85
-
86
- def is_science(row):
87
- tags = row.get("tags", [])
88
- return any("science" in tag.lower() and "bigscience" not in tag for tag in tags)
89
-
90
- def is_video(row):
91
- tags = row.get("tags", [])
92
- return any("video" in tag.lower() for tag in tags)
93
-
94
- def is_image(row):
95
- tags = row.get("tags", [])
96
- return any("image" in tag.lower() for tag in tags)
97
-
98
- def is_text(row):
99
- tags = row.get("tags", [])
100
- return any("text" in tag.lower() for tag in tags)
101
-
102
- # Add model size filter function
103
- def is_in_size_range(row, size_range):
104
- if size_range is None:
105
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
- min_size, max_size = MODEL_SIZE_RANGES[size_range]
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- # Get model size in GB from params column
110
- if "params" in row and pd.notna(row["params"]):
111
- try:
112
- # Convert to GB (assuming params are in bytes or scientific notation)
113
- size_gb = float(row["params"]) / (1024 * 1024 * 1024)
114
- return min_size <= size_gb < max_size
115
- except (ValueError, TypeError):
116
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- return False
119
-
120
- TAG_FILTER_FUNCS = {
121
- "Audio & Speech": is_audio_speech,
122
- "Time series": is_timeseries,
123
- "Robotics": is_robotics,
124
- "Music": is_music,
125
- "Video": is_video,
126
- "Images": is_image,
127
- "Text": is_text,
128
- "Biomedical": is_biomed,
129
- "Sciences": is_science,
130
- }
 
 
 
 
 
 
 
 
 
 
131
 
132
- def extract_org_from_id(model_id):
133
- """Extract organization name from model ID"""
134
- if "/" in model_id:
135
- return model_id.split("/")[0]
136
- return "unaffiliated"
 
 
 
137
 
 
 
 
 
 
 
 
 
 
138
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
139
- """Process DataFrame into treemap format with filters applied"""
140
- # Create a copy to avoid modifying the original
141
  filtered_df = df.copy()
 
 
 
142
 
143
- # Apply filters
144
- if tag_filter and tag_filter in TAG_FILTER_FUNCS:
145
- filter_func = TAG_FILTER_FUNCS[tag_filter]
146
- filtered_df = filtered_df[filtered_df.apply(filter_func, axis=1)]
147
-
 
 
 
 
 
 
 
 
 
 
 
 
148
  if pipeline_filter:
149
- filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
150
-
151
- if size_filter and size_filter in MODEL_SIZE_RANGES:
152
- # Create a function to check if a model is in the size range
153
- def check_size(row):
154
- return is_in_size_range(row, size_filter)
155
-
156
- filtered_df = filtered_df[filtered_df.apply(check_size, axis=1)]
157
-
158
- # Add organization column
159
- filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
160
-
161
- # Skip organizations if specified
162
  if skip_orgs and len(skip_orgs) > 0:
163
- filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
164
-
165
- # Aggregate by organization
166
- org_totals = filtered_df.groupby("organization")[count_by].sum().reset_index()
167
- org_totals = org_totals.sort_values(by=count_by, ascending=False)
168
-
169
- # Get top organizations
170
- top_orgs = org_totals.head(top_k)["organization"].tolist()
171
-
172
- # Filter to only include models from top organizations
173
- filtered_df = filtered_df[filtered_df["organization"].isin(top_orgs)]
174
-
175
- # Prepare data for treemap
176
- treemap_data = filtered_df[["id", "organization", count_by]].copy()
177
-
178
- # Add a root node
179
- treemap_data["root"] = "models"
180
-
181
- # Ensure numeric values
182
- treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
183
-
184
  return treemap_data
185
 
186
  def create_treemap(treemap_data, count_by, title=None):
187
- """Create a Plotly treemap from the prepared data"""
188
  if treemap_data.empty:
189
- # Create an empty figure with a message
190
- fig = px.treemap(
191
- names=["No data matches the selected filters"],
192
- values=[1]
193
- )
194
- fig.update_layout(
195
- title="No data matches the selected filters",
196
- margin=dict(t=50, l=25, r=25, b=25)
197
- )
198
  return fig
199
-
200
- # Create the treemap
201
  fig = px.treemap(
202
- treemap_data,
203
- path=["root", "organization", "id"],
204
- values=count_by,
205
  title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
206
- color_discrete_sequence=px.colors.qualitative.Plotly
207
- )
208
-
209
- # Update layout
210
- fig.update_layout(
211
- margin=dict(t=50, l=25, r=25, b=25)
212
- )
213
-
214
- # Update traces for better readability
215
- fig.update_traces(
216
- textinfo="label+value+percent root",
217
- hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
218
  )
219
-
 
220
  return fig
221
 
222
- def load_models_csv():
223
- # Read the CSV file
224
- df = pd.read_csv('models.csv')
225
-
226
- # Process the tags column
227
- def process_tags(tags_str):
228
- if pd.isna(tags_str):
229
- return []
230
-
231
- # Clean the string and convert to a list
232
- tags_str = tags_str.strip("[]").replace("'", "")
233
- tags = [tag.strip() for tag in tags_str.split() if tag.strip()]
234
- return tags
235
-
236
- df['tags'] = df['tags'].apply(process_tags)
237
-
238
- # Add more sample data for better visualization
239
- add_sample_data(df)
240
-
241
- return df
242
-
243
- def add_sample_data(df):
244
- """Add more sample data to make the visualization more interesting"""
245
- # Top organizations to include
246
- orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'nvidia', 'huggingface',
247
- 'deepseek-ai', 'stability-ai', 'mistralai', 'cerebras', 'databricks', 'together',
248
- 'facebook', 'amazon', 'deepmind', 'cohere', 'bigscience', 'eleutherai']
249
-
250
- # Common model name formats
251
- model_name_patterns = [
252
- "model-{size}-{version}",
253
- "{prefix}-{size}b",
254
- "{prefix}-{size}b-{variant}",
255
- "llama-{size}b-{variant}",
256
- "gpt-{variant}-{size}b",
257
- "{prefix}-instruct-{size}b",
258
- "{prefix}-chat-{size}b",
259
- "{prefix}-coder-{size}b",
260
- "stable-diffusion-{version}",
261
- "whisper-{size}",
262
- "bert-{size}-{variant}",
263
- "roberta-{size}",
264
- "t5-{size}",
265
- "{prefix}-vision-{size}b"
266
- ]
267
-
268
- # Common name parts
269
- prefixes = ["falcon", "llama", "mistral", "gpt", "phi", "gemma", "qwen", "yi", "mpt", "bloom"]
270
- sizes = ["7", "13", "34", "70", "1", "3", "7b", "13b", "70b", "8b", "2b", "1b", "0.5b", "small", "base", "large", "huge"]
271
- variants = ["chat", "instruct", "base", "v1.0", "v2", "beta", "turbo", "fast", "xl", "xxl"]
272
-
273
- # Generate sample data
274
- sample_data = []
275
- for org_idx, org in enumerate(orgs):
276
- # Create 5-10 models per organization
277
- num_models = np.random.randint(5, 11)
278
-
279
- for i in range(num_models):
280
- # Create realistic model name
281
- pattern = np.random.choice(model_name_patterns)
282
- prefix = np.random.choice(prefixes)
283
- size = np.random.choice(sizes)
284
- version = f"v{np.random.randint(1, 4)}"
285
- variant = np.random.choice(variants)
286
-
287
- model_name = pattern.format(
288
- prefix=prefix,
289
- size=size,
290
- version=version,
291
- variant=variant
292
- )
293
-
294
- model_id = f"{org}/{model_name}"
295
-
296
- # Select a realistic pipeline tag based on name
297
- if "diffusion" in model_name or "image" in model_name:
298
- pipeline_tag = np.random.choice(["text-to-image", "image-to-image", "image-segmentation"])
299
- elif "whisper" in model_name or "speech" in model_name:
300
- pipeline_tag = np.random.choice(["automatic-speech-recognition", "text-to-speech"])
301
- elif "coder" in model_name or "code" in model_name:
302
- pipeline_tag = "text-generation"
303
- elif "bert" in model_name or "roberta" in model_name:
304
- pipeline_tag = np.random.choice(["fill-mask", "text-classification", "token-classification"])
305
- elif "vision" in model_name:
306
- pipeline_tag = np.random.choice(["image-classification", "image-to-text", "visual-question-answering"])
307
- else:
308
- pipeline_tag = "text-generation" # Most common
309
-
310
- # Generate realistic tags
311
- tags = [pipeline_tag]
312
-
313
- if "text-generation" in pipeline_tag:
314
- tags.extend(["language-model", "text", "gpt", "llm"])
315
- if "instruct" in model_name:
316
- tags.append("instruction-following")
317
- if "chat" in model_name:
318
- tags.append("chat")
319
- elif "speech" in pipeline_tag:
320
- tags.extend(["audio", "speech", "voice"])
321
- elif "image" in pipeline_tag:
322
- tags.extend(["vision", "image", "diffusion"])
323
-
324
- # Add language tags
325
- if np.random.random() < 0.8: # 80% chance for English
326
- tags.append("en")
327
- if np.random.random() < 0.3: # 30% chance for multilingual
328
- tags.append("multilingual")
329
-
330
- # Generate downloads and likes (weighted by org position for variety)
331
- # Earlier orgs get more downloads to make the visualization interesting
332
- popularity_factor = (len(orgs) - org_idx) / len(orgs) # 1.0 to 0.0
333
- base_downloads = 10000 * (10 ** (2 * popularity_factor))
334
- downloads = int(base_downloads * np.random.uniform(0.3, 3.0))
335
- likes = int(downloads * np.random.uniform(0.01, 0.1)) # 1-10% like ratio
336
-
337
- # Generate model size (in bytes for params)
338
- # Model size should correlate somewhat with the size in the name
339
- size_indicator = 1
340
- for s in ["70b", "13b", "7b", "3b", "2b", "1b", "large", "huge", "xl", "xxl"]:
341
- if s in model_name.lower():
342
- size_indicator = float(s.replace("b", "")) if s[0].isdigit() else 3
343
- break
344
-
345
- # Size in bytes
346
- params = int(np.random.uniform(0.5, 2.0) * size_indicator * 1e9)
347
-
348
- # Create model entry
349
- model = {
350
- "id": model_id,
351
- "author": org,
352
- "downloads": downloads,
353
- "likes": likes,
354
- "pipeline_tag": pipeline_tag,
355
- "tags": tags,
356
- "params": params
357
- }
358
-
359
- sample_data.append(model)
360
-
361
- # Convert sample data to DataFrame and append to original
362
- sample_df = pd.DataFrame(sample_data)
363
- return pd.concat([df, sample_df], ignore_index=True)
364
 
365
- # Create Gradio interface
366
- with gr.Blocks() as demo:
367
- models_data = gr.State() # To store loaded data
368
-
369
  with gr.Row():
370
- gr.Markdown("""
371
- # HuggingFace Models TreeMap Visualization
372
-
373
- This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
374
- Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
375
-
376
- The treemap visualizes models grouped by organization, with the size of each box representing the selected metric (downloads or likes).
377
- """)
378
-
379
  with gr.Row():
380
- with gr.Column(scale=1):
381
- count_by_dropdown = gr.Dropdown(
382
- label="Metric",
383
- choices=["downloads", "likes"],
384
- value="downloads",
385
- info="Select the metric to determine box sizes"
386
- )
387
-
388
- filter_choice_radio = gr.Radio(
389
- label="Filter Type",
390
- choices=["None", "Tag Filter", "Pipeline Filter"],
391
- value="None",
392
- info="Choose how to filter the models"
393
- )
394
-
395
- tag_filter_dropdown = gr.Dropdown(
396
- label="Select Tag",
397
- choices=list(TAG_FILTER_FUNCS.keys()),
398
- value=None,
399
- visible=False,
400
- info="Filter models by domain/category"
401
- )
402
-
403
- pipeline_filter_dropdown = gr.Dropdown(
404
- label="Select Pipeline Tag",
405
- choices=PIPELINE_TAGS,
406
- value=None,
407
- visible=False,
408
- info="Filter models by specific pipeline"
409
- )
410
 
411
- size_filter_dropdown = gr.Dropdown(
412
- label="Model Size Filter",
413
- choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
414
- value="None",
415
- info="Filter models by their size (using params column)"
416
- )
417
 
418
- top_k_slider = gr.Slider(
419
- label="Number of Top Organizations",
420
- minimum=5,
421
- maximum=50,
422
- value=25,
423
- step=5,
424
- info="Number of top organizations to include"
425
- )
426
-
427
- skip_orgs_textbox = gr.Textbox(
428
- label="Organizations to Skip (comma-separated)",
429
- placeholder="e.g., openai, meta, huggingface",
430
- info="Enter names of organizations to exclude from the visualization"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
431
  )
432
 
433
- generate_plot_button = gr.Button("Generate Plot", variant="primary")
 
 
 
 
 
 
 
 
 
 
434
 
435
- with gr.Column(scale=3):
436
- plot_output = gr.Plot()
437
- stats_output = gr.Markdown("*Generate a plot to see statistics*")
 
 
 
 
 
438
 
439
- def generate_plot_on_click(count_by, filter_choice, tag_filter, pipeline_filter, size_filter, top_k, skip_orgs_text, data_df):
440
- print(f"Generating plot with: Metric={count_by}, Filter={filter_choice}, Tag={tag_filter}, Pipeline={pipeline_filter}, Size={size_filter}, Top K={top_k}")
441
-
442
- if data_df is None or len(data_df) == 0:
443
- return None, "Error: No data available. Please try again."
444
-
445
- selected_tag_filter = None
446
- selected_pipeline_filter = None
447
- selected_size_filter = None
448
-
449
- if filter_choice == "Tag Filter":
450
- selected_tag_filter = tag_filter
451
- elif filter_choice == "Pipeline Filter":
452
- selected_pipeline_filter = pipeline_filter
453
-
454
- if size_filter != "None":
455
- selected_size_filter = size_filter
456
-
457
- # Process skip organizations list
458
- skip_orgs = []
459
- if skip_orgs_text and skip_orgs_text.strip():
460
- skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
461
- print(f"Skipping organizations: {skip_orgs}")
462
-
463
- # Process data for treemap
464
- treemap_data = make_treemap_data(
465
- df=data_df,
466
- count_by=count_by,
467
- top_k=top_k,
468
- tag_filter=selected_tag_filter,
469
- pipeline_filter=selected_pipeline_filter,
470
- size_filter=selected_size_filter,
471
- skip_orgs=skip_orgs
472
- )
473
-
474
- # Create plot
475
- fig = create_treemap(
476
- treemap_data=treemap_data,
477
- count_by=count_by,
478
- title=f"HuggingFace Models - {count_by.capitalize()} by Organization"
479
- )
480
 
481
- # Generate statistics
482
- if treemap_data.empty:
483
- stats_md = "No data matches the selected filters."
484
- else:
485
- total_models = len(treemap_data)
486
- total_value = treemap_data[count_by].sum()
487
- top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
488
-
489
- # Format the statistics using clean markdown
490
- stats_md = f"""
491
- ## Statistics
492
- - **Total models shown**: {total_models:,}
493
- - **Total {count_by}**: {int(total_value):,}
494
 
495
- ## Top Organizations by {count_by.capitalize()}
496
 
497
- | Organization | {count_by.capitalize()} | % of Total |
498
- |--------------|--------:|--------:|"""
499
-
500
- # Add each organization as a row in the table
501
- for org, value in top_5_orgs.items():
502
- percentage = (value / total_value) * 100
503
- stats_md += f"\n| {org} | {int(value):,} | {percentage:.2f}% |"
504
-
505
- # Add note about skipped organizations if any
506
- if skip_orgs:
507
- stats_md += f"\n\n*Note: {len(skip_orgs)} organization(s) excluded: {', '.join(skip_orgs)}*"
 
 
 
 
 
 
 
508
 
509
- return fig, stats_md
510
-
511
- def update_filter_visibility(filter_choice):
512
- if filter_choice == "Tag Filter":
513
- return gr.update(visible=True), gr.update(visible=False)
514
- elif filter_choice == "Pipeline Filter":
515
- return gr.update(visible=False), gr.update(visible=True)
516
- else: # "None"
517
- return gr.update(visible=False), gr.update(visible=False)
518
-
519
- filter_choice_radio.change(
520
- fn=update_filter_visibility,
521
- inputs=[filter_choice_radio],
522
- outputs=[tag_filter_dropdown, pipeline_filter_dropdown]
523
- )
524
-
525
- # Load data once at startup
526
  demo.load(
527
- fn=load_models_csv,
 
 
 
 
 
 
 
528
  inputs=[],
529
- outputs=[models_data]
530
  )
531
 
532
- # Button click event to generate plot
533
  generate_plot_button.click(
534
- fn=generate_plot_on_click,
535
- inputs=[
536
- count_by_dropdown,
537
- filter_choice_radio,
538
- tag_filter_dropdown,
539
- pipeline_filter_dropdown,
540
- size_filter_dropdown,
541
- top_k_slider,
542
- skip_orgs_textbox,
543
- models_data
544
- ],
545
- outputs=[plot_output, stats_output]
546
  )
547
 
548
-
549
  if __name__ == "__main__":
550
- demo.launch()
 
 
 
 
 
 
 
1
+ # --- START OF FILE app.py ---
2
+
3
  import json
4
  import gradio as gr
5
  import pandas as pd
6
  import plotly.express as px
7
  import os
8
+ import numpy as np # Make sure NumPy is imported
9
+ import duckdb
10
+ from tqdm.auto import tqdm # Standard tqdm for console, gr.Progress will track it
11
+ import time
12
+ import ast # For safely evaluating string representations of lists/dicts
13
+
14
+ # --- Constants ---
15
+ MODEL_SIZE_RANGES = {
16
+ "Small (<1GB)": (0, 1), "Medium (1-5GB)": (1, 5), "Large (5-20GB)": (5, 20),
17
+ "X-Large (20-50GB)": (20, 50), "XX-Large (>50GB)": (50, float('inf'))
18
+ }
19
+ PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
20
+ HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'
21
+
22
+ TAG_FILTER_CHOICES = [
23
+ "Audio & Speech", "Time series", "Robotics", "Music", "Video", "Images",
24
+ "Text", "Biomedical", "Sciences"
25
+ ]
26
 
 
27
  PIPELINE_TAGS = [
28
+ 'text-generation', 'text-to-image', 'text-classification', 'text2text-generation',
29
+ 'audio-to-audio', 'feature-extraction', 'image-classification', 'translation',
30
+ 'reinforcement-learning', 'fill-mask', 'text-to-speech', 'automatic-speech-recognition',
31
+ 'image-text-to-text', 'token-classification', 'sentence-similarity', 'question-answering',
32
+ 'image-feature-extraction', 'summarization', 'zero-shot-image-classification',
33
+ 'object-detection', 'image-segmentation', 'image-to-image', 'image-to-text',
34
+ 'audio-classification', 'visual-question-answering', 'text-to-video',
35
+ 'zero-shot-classification', 'depth-estimation', 'text-ranking', 'image-to-video',
36
+ 'multiple-choice', 'unconditional-image-generation', 'video-classification',
37
+ 'text-to-audio', 'time-series-forecasting', 'any-to-any', 'video-text-to-text',
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  'table-question-answering',
39
  ]
40
 
41
+ # --- Utility Functions ---
42
+ def extract_model_size(safetensors_data): # Renamed for consistency if used, preprocessor uses extract_model_file_size_gb
43
+ try:
44
+ if pd.isna(safetensors_data): return 0.0
45
+ data_to_parse = safetensors_data
46
+ if isinstance(safetensors_data, str):
47
+ try:
48
+ if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
49
+ (safetensors_data.startswith('[') and safetensors_data.endswith(']')):
50
+ data_to_parse = ast.literal_eval(safetensors_data)
51
+ else: data_to_parse = json.loads(safetensors_data)
52
+ except: return 0.0
53
+ if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
54
+ try:
55
+ total_bytes_val = data_to_parse['total']
56
+ size_bytes = float(total_bytes_val)
57
+ return size_bytes / (1024 * 1024 * 1024)
58
+ except (ValueError, TypeError): pass
59
+ return 0.0
60
+ except: return 0.0
61
+
62
+ def extract_org_from_id(model_id):
63
+ if pd.isna(model_id): return "unaffiliated"
64
+ model_id_str = str(model_id)
65
+ return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
66
 
67
+ # --- THIS IS THE CORRECTED process_tags_for_series from preprocess.py ---
68
+ def process_tags_for_series(series_of_tags_values, tqdm_cls=None): # Added tqdm_cls for Gradio progress
69
+ processed_tags_accumulator = []
 
70
 
71
+ # Determine the iterable (use tqdm if tqdm_cls is provided, else direct iteration)
72
+ iterable = series_of_tags_values
73
+ if tqdm_cls and tqdm_cls != tqdm : # Check if it's Gradio's progress tracker
74
+ iterable = tqdm_cls(series_of_tags_values, desc="Standardizing Tags (App)", unit="row")
75
+ elif tqdm_cls == tqdm: # For direct console tqdm if passed
76
+ iterable = tqdm(series_of_tags_values, desc="Standardizing Tags (App)", unit="row", leave=False)
77
+
78
+
79
+ for i, tags_value_from_series in enumerate(iterable):
80
+ temp_processed_list_for_row = []
81
+ current_value_for_error_msg = str(tags_value_from_series)[:200]
82
+
83
+ try:
84
+ if isinstance(tags_value_from_series, list):
85
+ current_tags_in_list = []
86
+ for tag_item in tags_value_from_series:
87
+ try:
88
+ if pd.isna(tag_item): continue
89
+ str_tag = str(tag_item)
90
+ stripped_tag = str_tag.strip()
91
+ if stripped_tag:
92
+ current_tags_in_list.append(stripped_tag)
93
+ except Exception as e_inner_list_proc:
94
+ print(f"APP ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original: {current_value_for_error_msg}")
95
+ temp_processed_list_for_row = current_tags_in_list
96
+
97
+ elif isinstance(tags_value_from_series, np.ndarray):
98
+ current_tags_in_list = []
99
+ for tag_item in tags_value_from_series.tolist():
100
+ try:
101
+ if pd.isna(tag_item): continue
102
+ str_tag = str(tag_item)
103
+ stripped_tag = str_tag.strip()
104
+ if stripped_tag:
105
+ current_tags_in_list.append(stripped_tag)
106
+ except Exception as e_inner_array_proc:
107
+ print(f"APP ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original: {current_value_for_error_msg}")
108
+ temp_processed_list_for_row = current_tags_in_list
109
+
110
+ elif tags_value_from_series is None or pd.isna(tags_value_from_series):
111
+ temp_processed_list_for_row = []
112
+
113
+ elif isinstance(tags_value_from_series, str):
114
+ processed_str_tags = []
115
+ if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
116
+ (tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
117
+ try:
118
+ evaluated_tags = ast.literal_eval(tags_value_from_series)
119
+ if isinstance(evaluated_tags, (list, tuple)):
120
+ current_eval_list = []
121
+ for tag_item in evaluated_tags:
122
+ if pd.isna(tag_item): continue
123
+ str_tag = str(tag_item).strip()
124
+ if str_tag: current_eval_list.append(str_tag)
125
+ processed_str_tags = current_eval_list
126
+ except (ValueError, SyntaxError):
127
+ pass
128
+
129
+ if not processed_str_tags:
130
+ try:
131
+ json_tags = json.loads(tags_value_from_series)
132
+ if isinstance(json_tags, list):
133
+ current_json_list = []
134
+ for tag_item in json_tags:
135
+ if pd.isna(tag_item): continue
136
+ str_tag = str(tag_item).strip()
137
+ if str_tag: current_json_list.append(str_tag)
138
+ processed_str_tags = current_json_list
139
+ except json.JSONDecodeError:
140
+ processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
141
+ except Exception as e_json_other:
142
+ print(f"APP ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
143
+ processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
144
+
145
+ temp_processed_list_for_row = processed_str_tags
146
+
147
+ else:
148
+ if pd.isna(tags_value_from_series):
149
+ temp_processed_list_for_row = []
150
+ else:
151
+ str_val = str(tags_value_from_series).strip()
152
+ temp_processed_list_for_row = [str_val] if str_val else []
153
+
154
+ processed_tags_accumulator.append(temp_processed_list_for_row)
155
+
156
+ except Exception as e_outer_tag_proc:
157
+ print(f"APP CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
158
+ processed_tags_accumulator.append([])
159
+
160
+ return processed_tags_accumulator
161
+ # --- END OF CORRECTED process_tags_for_series ---
162
+
163
+
164
+ def load_models_data(force_refresh=False, tqdm_cls=None): # tqdm_cls for Gradio progress
165
+ # ... (initial part of load_models_data for loading pre-processed parquet is the same) ...
166
+ if tqdm_cls is None: tqdm_cls = tqdm # Default to standard tqdm if None
167
+ overall_start_time = time.time()
168
+ print(f"Gradio load_models_data called with force_refresh={force_refresh}")
169
+
170
+ expected_cols_in_processed_parquet = [
171
+ 'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params',
172
+ 'size_category', 'organization', 'has_audio', 'has_speech', 'has_music',
173
+ 'has_robot', 'has_bio', 'has_med', 'has_series', 'has_video', 'has_image',
174
+ 'has_text', 'has_science', 'is_audio_speech', 'is_biomed',
175
+ 'data_download_timestamp'
176
+ ]
177
+
178
+ if not force_refresh and os.path.exists(PROCESSED_PARQUET_FILE_PATH):
179
+ print(f"Attempting to load pre-processed data from: {PROCESSED_PARQUET_FILE_PATH}")
180
+ try:
181
+ df = pd.read_parquet(PROCESSED_PARQUET_FILE_PATH)
182
+ elapsed = time.time() - overall_start_time
183
+ missing_cols = [col for col in expected_cols_in_processed_parquet if col not in df.columns]
184
+ if missing_cols:
185
+ raise ValueError(f"Pre-processed Parquet is missing columns: {missing_cols}. Please run preprocessor or refresh data in app.")
186
+
187
+ if 'has_robot' in df.columns:
188
+ robot_count_parquet = df['has_robot'].sum()
189
+ print(f"DIAGNOSTIC (App - Parquet Load): 'has_robot' column found. Number of True values: {robot_count_parquet}")
190
+ else:
191
+ print("DIAGNOSTIC (App - Parquet Load): 'has_robot' column NOT FOUND.")
192
+
193
+ msg = f"Successfully loaded pre-processed data in {elapsed:.2f}s. Shape: {df.shape}"
194
+ print(msg)
195
+ return df, True, msg
196
+ except Exception as e:
197
+ print(f"Could not load pre-processed Parquet: {e}. ")
198
+ if force_refresh: print("Proceeding to fetch fresh data as force_refresh=True.")
199
+ else:
200
+ err_msg = (f"Pre-processed data could not be loaded: {e}. "
201
+ "Please use 'Refresh Data from Hugging Face' button.")
202
+ return pd.DataFrame(), False, err_msg
203
+
204
+ df_raw = None
205
+ raw_data_source_msg = ""
206
+ if force_refresh:
207
+ print("force_refresh=True (Gradio). Fetching fresh data...")
208
+ fetch_start = time.time()
209
+ try:
210
+ query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"
211
+ df_raw = duckdb.sql(query).df()
212
+ if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
213
+ raw_data_source_msg = f"Fetched by Gradio in {time.time() - fetch_start:.2f}s. Rows: {len(df_raw)}"
214
+ print(raw_data_source_msg)
215
+ except Exception as e_hf:
216
+ return pd.DataFrame(), False, f"Fatal error fetching from Hugging Face (Gradio): {e_hf}"
217
+ else:
218
+ err_msg = (f"Pre-processed data '{PROCESSED_PARQUET_FILE_PATH}' not found/invalid. "
219
+ "Run preprocessor or use 'Refresh Data' button.")
220
+ return pd.DataFrame(), False, err_msg
221
+
222
+ print(f"Initiating processing for data newly fetched by Gradio. {raw_data_source_msg}")
223
+ df = pd.DataFrame() # This will be our processed DataFrame
224
+ proc_start = time.time()
225
 
226
+ core_cols = {'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
227
+ 'pipeline_tag': str, 'tags': object, 'safetensors': object}
228
+ for col, dtype in core_cols.items():
229
+ if col in df_raw.columns:
230
+ df[col] = df_raw[col] # Assign raw data first
231
+ if dtype == float: df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
232
+ elif dtype == str: df[col] = df[col].astype(str).fillna('')
233
+ # For 'tags' and 'safetensors' (object type), no specific conversion here, done later
234
+ else: # If column is missing in raw data
235
+ if col in ['downloads', 'downloadsAllTime', 'likes']: df[col] = 0.0
236
+ elif col == 'pipeline_tag': df[col] = ''
237
+ elif col == 'tags': df[col] = pd.Series([[] for _ in range(len(df_raw))]) # Default to empty lists
238
+ elif col == 'safetensors': df[col] = None # Default to None
239
+ elif col == 'id': return pd.DataFrame(), False, "Critical: 'id' column missing."
240
 
241
+ output_filesize_col_name = 'params'
242
+ if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
243
+ df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
244
+ elif 'safetensors' in df.columns:
245
+ # Use tqdm_cls for progress tracking if available (Gradio's gr.Progress.tqdm)
246
+ safetensors_iter = df['safetensors']
247
+ if tqdm_cls and tqdm_cls != tqdm: # Check if it's Gradio's progress tracker
248
+ safetensors_iter = tqdm_cls(df['safetensors'], desc="Extracting model sizes (GB)", unit="row")
249
+ elif tqdm_cls == tqdm: # For direct console tqdm if passed
250
+ safetensors_iter = tqdm(df['safetensors'], desc="Extracting model sizes (GB)", unit="row", leave=False)
251
+
252
+ df[output_filesize_col_name] = [extract_model_size(s) for s in safetensors_iter]
253
+ df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
254
+ else:
255
+ df[output_filesize_col_name] = 0.0
256
+
257
+ def get_size_category_gradio(size_gb_val):
258
+ try: numeric_size_gb = float(size_gb_val)
259
+ except (ValueError, TypeError): numeric_size_gb = 0.0
260
+ if pd.isna(numeric_size_gb): numeric_size_gb = 0.0
261
+ if 0 <= numeric_size_gb < 1: return "Small (<1GB)"
262
+ elif 1 <= numeric_size_gb < 5: return "Medium (1-5GB)"
263
+ elif 5 <= numeric_size_gb < 20: return "Large (5-20GB)"
264
+ elif 20 <= numeric_size_gb < 50: return "X-Large (20-50GB)"
265
+ elif numeric_size_gb >= 50: return "XX-Large (>50GB)"
266
+ else: return "Small (<1GB)" # Default
267
+ df['size_category'] = df[output_filesize_col_name].apply(get_size_category_gradio)
268
+
269
+ # >>> USE THE CORRECTED process_tags_for_series HERE <<<
270
+ df['tags'] = process_tags_for_series(df['tags'], tqdm_cls=tqdm_cls)
271
 
272
+ df['temp_tags_joined'] = df['tags'].apply(
273
+ lambda tl: '~~~'.join(str(t).lower().strip() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
274
+ )
275
+ tag_map = {
276
+ 'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
277
+ 'has_robot': ['robot', 'robotics'],
278
+ 'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
279
+ 'has_series': ['series', 'time-series', 'timeseries'],
280
+ 'has_video': ['video'], 'has_image': ['image', 'vision'],
281
+ 'has_text': ['text', 'nlp', 'llm']
282
+ }
283
+ for col, kws in tag_map.items():
284
+ pattern = '|'.join(kws)
285
+ df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
286
+ df['has_science'] = (
287
+ df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
288
+ ~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
289
+ )
290
+ del df['temp_tags_joined']
291
+ df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
292
+ df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
293
+ df['is_biomed'] = df['has_bio'] | df['has_med']
294
+ df['organization'] = df['id'].apply(extract_org_from_id)
295
 
296
+ # Drop safetensors if params was calculated from it, and params didn't pre-exist as numeric
297
+ if 'safetensors' in df.columns and \
298
+ not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
299
+ df = df.drop(columns=['safetensors'], errors='ignore')
300
+
301
+ if force_refresh and 'has_robot' in df.columns:
302
+ robot_count_app_proc = df['has_robot'].sum()
303
+ print(f"DIAGNOSTIC (App - Force Refresh Processing): 'has_robot' column processed. Number of True values: {robot_count_app_proc}")
304
 
305
+ print(f"Data processing by Gradio completed in {time.time() - proc_start:.2f}s.")
306
+
307
+ total_elapsed = time.time() - overall_start_time
308
+ final_msg = f"{raw_data_source_msg}. Processing by Gradio took {time.time() - proc_start:.2f}s. Total: {total_elapsed:.2f}s. Shape: {df.shape}"
309
+ print(final_msg)
310
+ return df, True, final_msg
311
+
312
+
313
+ # ... (make_treemap_data, create_treemap functions remain unchanged) ...
314
  def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
315
+ if df is None or df.empty: return pd.DataFrame()
 
316
  filtered_df = df.copy()
317
+ col_map = { "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
318
+ "Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
319
+ "Video": "has_video", "Images": "has_image", "Text": "has_text"}
320
 
321
+ if 'has_robot' in filtered_df.columns:
322
+ initial_robot_count = filtered_df['has_robot'].sum()
323
+ # print(f"DIAGNOSTIC (make_treemap_data entry): Input df has {initial_robot_count} 'has_robot' models.") # Can be noisy
324
+ # else:
325
+ # print("DIAGNOSTIC (make_treemap_data entry): 'has_robot' column NOT in input df.")
326
+
327
+ if tag_filter and tag_filter in col_map:
328
+ target_col = col_map[tag_filter]
329
+ if target_col in filtered_df.columns:
330
+ # if tag_filter == "Robotics":
331
+ # count_before_robot_filter = filtered_df[target_col].sum()
332
+ # print(f"DIAGNOSTIC (make_treemap_data): Applying 'Robotics' filter. Models with '{target_col}'=True: {count_before_robot_filter}")
333
+ filtered_df = filtered_df[filtered_df[target_col]]
334
+ # if tag_filter == "Robotics":
335
+ # print(f"DIAGNOSTIC (make_treemap_data): After 'Robotics' filter ({target_col}), df rows: {len(filtered_df)}")
336
+ else:
337
+ print(f"Warning: Tag filter column '{col_map[tag_filter]}' not found in DataFrame.")
338
  if pipeline_filter:
339
+ if "pipeline_tag" in filtered_df.columns:
340
+ filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
341
+ else:
342
+ print(f"Warning: 'pipeline_tag' column not found for filtering.")
343
+ if size_filter and size_filter != "None" and size_filter in MODEL_SIZE_RANGES.keys():
344
+ if 'size_category' in filtered_df.columns:
345
+ filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
346
+ else:
347
+ print("Warning: 'size_category' column not found for filtering.")
 
 
 
 
348
  if skip_orgs and len(skip_orgs) > 0:
349
+ if "organization" in filtered_df.columns:
350
+ filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
351
+ else:
352
+ print("Warning: 'organization' column not found for filtering.")
353
+ if filtered_df.empty: return pd.DataFrame()
354
+ # Ensure count_by column is numeric, coercing if necessary
355
+ if count_by not in filtered_df.columns or not pd.api.types.is_numeric_dtype(filtered_df[count_by]):
356
+ # print(f"Warning: Column '{count_by}' for treemap values is not numeric or missing. Coercing to numeric, filling NaNs with 0.")
357
+ filtered_df[count_by] = pd.to_numeric(filtered_df.get(count_by), errors="coerce").fillna(0.0)
358
+
359
+ org_totals = filtered_df.groupby("organization")[count_by].sum().nlargest(top_k, keep='first') # Default keep='first'
360
+ top_orgs_list = org_totals.index.tolist()
361
+ treemap_data = filtered_df[filtered_df["organization"].isin(top_orgs_list)][["id", "organization", count_by]].copy()
362
+ treemap_data["root"] = "models" # For treemap structure
363
+ treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0.0) # Ensure numeric again after subsetting
 
 
 
 
 
 
364
  return treemap_data
365
 
366
  def create_treemap(treemap_data, count_by, title=None):
 
367
  if treemap_data.empty:
368
+ fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1]) # Placeholder for empty data
369
+ fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
 
 
 
 
 
 
 
370
  return fig
 
 
371
  fig = px.treemap(
372
+ treemap_data, path=["root", "organization", "id"], values=count_by,
 
 
373
  title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
374
+ color_discrete_sequence=px.colors.qualitative.Plotly # Example color sequence
 
 
 
 
 
 
 
 
 
 
 
375
  )
376
+ fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
377
+ fig.update_traces(textinfo="label+value+percent root", hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>")
378
  return fig
379
 
380
+ # --- Gradio UI and Controllers ---
381
+ with gr.Blocks(title="HuggingFace Model Explorer") as demo:
382
+ models_data_state = gr.State(pd.DataFrame())
383
+ loading_complete_state = gr.State(False) # To control button interactivity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
 
 
 
 
 
385
  with gr.Row():
386
+ gr.Markdown("# HuggingFace Models TreeMap Visualization")
 
 
 
 
 
 
 
 
387
  with gr.Row():
388
+ with gr.Column(scale=1): # Controls column
389
+ count_by_dropdown = gr.Dropdown(label="Metric", choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")], value="downloads")
390
+ filter_choice_radio = gr.Radio(label="Filter Type", choices=["None", "Tag Filter", "Pipeline Filter"], value="None")
391
+ tag_filter_dropdown = gr.Dropdown(label="Select Tag", choices=TAG_FILTER_CHOICES, value=None, visible=False)
392
+ pipeline_filter_dropdown = gr.Dropdown(label="Select Pipeline Tag", choices=PIPELINE_TAGS, value=None, visible=False)
393
+ size_filter_dropdown = gr.Dropdown(label="Model Size Filter", choices=["None"] + list(MODEL_SIZE_RANGES.keys()), value="None")
394
+ top_k_slider = gr.Slider(label="Number of Top Organizations", minimum=5, maximum=50, value=25, step=5)
395
+ skip_orgs_textbox = gr.Textbox(label="Organizations to Skip (comma-separated)", value="TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski") # Common large orgs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
+ generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False) # Starts disabled
398
+ refresh_data_button = gr.Button(value="Refresh Data from Hugging Face", variant="secondary")
 
 
 
 
399
 
400
+ with gr.Column(scale=3): # Plot and info column
401
+ plot_output = gr.Plot()
402
+ status_message_md = gr.Markdown("Initializing...") # For general status updates
403
+ data_info_md = gr.Markdown("") # For detailed data stats
404
+
405
+ # Enable generate button only after data is loaded
406
+ def _update_button_interactivity(is_loaded_flag):
407
+ return gr.update(interactive=is_loaded_flag)
408
+ loading_complete_state.change(fn=_update_button_interactivity, inputs=loading_complete_state, outputs=generate_plot_button)
409
+
410
+ # Show/hide tag/pipeline filters based on radio choice
411
+ def _toggle_filters_visibility(choice):
412
+ return gr.update(visible=choice == "Tag Filter"), gr.update(visible=choice == "Pipeline Filter")
413
+ filter_choice_radio.change(fn=_toggle_filters_visibility, inputs=filter_choice_radio, outputs=[tag_filter_dropdown, pipeline_filter_dropdown])
414
+
415
+
416
+ def ui_load_data_controller(force_refresh_ui_trigger=False, progress=gr.Progress(track_tqdm=True)): # Gradio progress tracker
417
+ print(f"ui_load_data_controller called with force_refresh_ui_trigger={force_refresh_ui_trigger}")
418
+ status_msg_ui = "Loading data..."
419
+ data_info_text = ""
420
+ current_df = pd.DataFrame()
421
+ load_success_flag = False
422
+ data_as_of_date_display = "N/A"
423
+
424
+ try:
425
+ # Pass gr.Progress.tqdm to load_models_data if it's a Gradio call
426
+ current_df, load_success_flag, status_msg_from_load = load_models_data(
427
+ force_refresh=force_refresh_ui_trigger, tqdm_cls=progress.tqdm if progress else tqdm
428
  )
429
 
430
+ if load_success_flag:
431
+ if force_refresh_ui_trigger: # Data was just fetched by Gradio
432
+ data_as_of_date_display = pd.Timestamp.now(tz='UTC').strftime('%B %d, %Y, %H:%M:%S %Z')
433
+ # If loaded from pre-processed parquet, check for its timestamp column
434
+ elif 'data_download_timestamp' in current_df.columns and not current_df.empty and pd.notna(current_df['data_download_timestamp'].iloc[0]):
435
+ timestamp_from_parquet = pd.to_datetime(current_df['data_download_timestamp'].iloc[0])
436
+ if timestamp_from_parquet.tzinfo is None: # If no timezone, assume UTC from preprocessor
437
+ timestamp_from_parquet = timestamp_from_parquet.tz_localize('UTC')
438
+ data_as_of_date_display = timestamp_from_parquet.strftime('%B %d, %Y, %H:%M:%S %Z')
439
+ else: # Pre-processed data but no timestamp column or it's NaT
440
+ data_as_of_date_display = "Pre-processed (date unavailable)"
441
 
442
+ # Build data info string
443
+ size_dist_lines = []
444
+ if 'size_category' in current_df.columns:
445
+ for cat in MODEL_SIZE_RANGES.keys():
446
+ count = (current_df['size_category'] == cat).sum()
447
+ size_dist_lines.append(f" - {cat}: {count:,} models")
448
+ else: size_dist_lines.append(" - Size category information not available.")
449
+ size_dist = "\n".join(size_dist_lines)
450
 
451
+ data_info_text = (f"### Data Information\n"
452
+ f"- Overall Status: {status_msg_from_load}\n"
453
+ f"- Total models loaded: {len(current_df):,}\n"
454
+ f"- Data as of: {data_as_of_date_display}\n"
455
+ f"- Size categories:\n{size_dist}")
456
+
457
+ if not current_df.empty and 'has_robot' in current_df.columns:
458
+ robot_true_count = current_df['has_robot'].sum()
459
+ data_info_text += f"\n- **Models flagged 'has_robot'**: {robot_true_count}"
460
+ if 0 < robot_true_count <= 10:
461
+ sample_robot_ids = current_df[current_df['has_robot']]['id'].head(5).tolist()
462
+ data_info_text += f"\n - Sample 'has_robot' model IDs: `{', '.join(sample_robot_ids)}`"
463
+ elif not current_df.empty:
464
+ data_info_text += "\n- **Models flagged 'has_robot'**: 'has_robot' column not found."
465
+
466
+ status_msg_ui = "Data loaded successfully. Ready to generate plot."
467
+ else: # load_success_flag is False
468
+ data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
469
+ status_msg_ui = status_msg_from_load # Pass error message from load_models_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
+ except Exception as e:
472
+ status_msg_ui = f"An unexpected error occurred in ui_load_data_controller: {str(e)}"
473
+ data_info_text = f"### Critical Error\n- {status_msg_ui}"
474
+ print(f"Critical error in ui_load_data_controller: {e}")
475
+ load_success_flag = False # Ensure this is false on error
 
 
 
 
 
 
 
 
476
 
477
+ return current_df, load_success_flag, data_info_text, status_msg_ui
478
 
479
+ def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
480
+ size_choice, k_orgs, skip_orgs_input, df_current_models):
481
+ if df_current_models is None or df_current_models.empty:
482
+ empty_fig = create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded")
483
+ error_msg = "Model data is not loaded or is empty. Please load or refresh data first."
484
+ gr.Warning(error_msg) # Display a Gradio warning
485
+ return empty_fig, error_msg
486
+
487
+ tag_to_use = tag_choice if filter_type == "Tag Filter" else None
488
+ pipeline_to_use = pipeline_choice if filter_type == "Pipeline Filter" else None
489
+ size_to_use = size_choice if size_choice != "None" else None # Handle "None" string
490
+ orgs_to_skip = [org.strip() for org in skip_orgs_input.split(',') if org.strip()] if skip_orgs_input else []
491
+
492
+ # if 'has_robot' in df_current_models.columns:
493
+ # robot_count_before_treemap = df_current_models['has_robot'].sum()
494
+ # print(f"DIAGNOSTIC (ui_generate_plot_controller): df_current_models entering make_treemap_data has {robot_count_before_treemap} 'has_robot' models.")
495
+
496
+ treemap_df = make_treemap_data(df_current_models, metric_choice, k_orgs, tag_to_use, pipeline_to_use, size_to_use, orgs_to_skip)
497
 
498
+ title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
499
+ chart_title = f"HuggingFace Models - {title_labels.get(metric_choice, metric_choice)} by Organization"
500
+ plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)
501
+
502
+ if treemap_df.empty:
503
+ plot_stats_md = "No data matches the selected filters. Try adjusting your filters."
504
+ else:
505
+ total_items_in_plot = len(treemap_df['id'].unique()) # Count unique models in plot
506
+ total_value_in_plot = treemap_df[metric_choice].sum() # Sum of metric in plot
507
+ plot_stats_md = (f"## Plot Statistics\n- **Models shown**: {total_items_in_plot:,}\n- **Total {metric_choice}**: {int(total_value_in_plot):,}")
508
+
509
+ return plotly_fig, plot_stats_md
510
+
511
+ # --- Event Handlers ---
512
+ # Initial data load on app start
 
 
513
  demo.load(
514
+ fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=False, progress=progress),
515
+ inputs=[], # No inputs for initial load
516
+ outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
517
+ )
518
+
519
+ # Refresh data button
520
+ refresh_data_button.click(
521
+ fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=True, progress=progress),
522
  inputs=[],
523
+ outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
524
  )
525
 
526
+ # Generate plot button
527
  generate_plot_button.click(
528
+ fn=ui_generate_plot_controller,
529
+ inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
530
+ size_filter_dropdown, top_k_slider, skip_orgs_textbox, models_data_state],
531
+ outputs=[plot_output, status_message_md] # Update plot and status message
 
 
 
 
 
 
 
 
532
  )
533
 
 
534
  if __name__ == "__main__":
535
+ if not os.path.exists(PROCESSED_PARQUET_FILE_PATH):
536
+ print(f"WARNING: Pre-processed data file '{PROCESSED_PARQUET_FILE_PATH}' not found.")
537
+ print("It is highly recommended to run the preprocessing script (preprocess.py) first.")
538
+ else:
539
+ print(f"Found pre-processed data file: '{PROCESSED_PARQUET_FILE_PATH}'.")
540
+ demo.launch()
541
+ # --- END OF FILE app.py ---
models.csv → models_processed.parquet RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:72304e5cf712934720d0fb6e422ccee0d6d89331e9fa9c8a899dc8589d408654
3
- size 318247426
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:998afad6c0c4c64f9e98efd8609d1cbab1dd2ac281b9c2e023878ad436c2fbde
3
+ size 96033487
preprocess.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --- START OF FILE preprocess.py ---
2
+
3
+ import pandas as pd
4
+ import numpy as np
5
+ import json
6
+ import ast
7
+ from tqdm.auto import tqdm
8
+ import time
9
+ import os
10
+ import duckdb
11
+ import re # Import re for the manual regex check in debug
12
+
13
+ # --- Constants ---
14
+ PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
15
+ HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'
16
+
17
+ MODEL_SIZE_RANGES = {
18
+ "Small (<1GB)": (0, 1),
19
+ "Medium (1-5GB)": (1, 5),
20
+ "Large (5-20GB)": (5, 20),
21
+ "X-Large (20-50GB)": (20, 50),
22
+ "XX-Large (>50GB)": (50, float('inf'))
23
+ }
24
+
25
+ # --- Debugging Constant ---
26
+ # <<<<<<< SET THE MODEL ID YOU WANT TO DEBUG HERE >>>>>>>
27
+ MODEL_ID_TO_DEBUG = "openvla/openvla-7b"
28
+ # Example: MODEL_ID_TO_DEBUG = "openai-community/gpt2"
29
+ # If you don't have a specific ID, the debug block will just report it's not found.
30
+
31
+ # --- Utility Functions (extract_model_file_size_gb, extract_org_from_id, process_tags_for_series, get_file_size_category - unchanged from previous correct version) ---
32
+ def extract_model_file_size_gb(safetensors_data):
33
+ try:
34
+ if pd.isna(safetensors_data): return 0.0
35
+ data_to_parse = safetensors_data
36
+ if isinstance(safetensors_data, str):
37
+ try:
38
+ if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
39
+ (safetensors_data.startswith('[') and safetensors_data.endswith(']')):
40
+ data_to_parse = ast.literal_eval(safetensors_data)
41
+ else: data_to_parse = json.loads(safetensors_data)
42
+ except Exception: return 0.0
43
+ if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
44
+ total_bytes_val = data_to_parse['total']
45
+ try:
46
+ size_bytes = float(total_bytes_val)
47
+ return size_bytes / (1024 * 1024 * 1024)
48
+ except (ValueError, TypeError): return 0.0
49
+ return 0.0
50
+ except Exception: return 0.0
51
+
52
+ def extract_org_from_id(model_id):
53
+ if pd.isna(model_id): return "unaffiliated"
54
+ model_id_str = str(model_id)
55
+ return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
56
+
57
+ def process_tags_for_series(series_of_tags_values):
58
+ processed_tags_accumulator = []
59
+
60
+ for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
61
+ temp_processed_list_for_row = []
62
+ current_value_for_error_msg = str(tags_value_from_series)[:200] # Truncate for long error messages
63
+
64
+ try:
65
+ # Order of checks is important!
66
+ # 1. Handle explicit Python lists first
67
+ if isinstance(tags_value_from_series, list):
68
+ current_tags_in_list = []
69
+ for idx_tag, tag_item in enumerate(tags_value_from_series):
70
+ try:
71
+ # Ensure item is not NaN before string conversion if it might be a float NaN in a list
72
+ if pd.isna(tag_item): continue
73
+ str_tag = str(tag_item)
74
+ stripped_tag = str_tag.strip()
75
+ if stripped_tag:
76
+ current_tags_in_list.append(stripped_tag)
77
+ except Exception as e_inner_list_proc:
78
+ print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
79
+ temp_processed_list_for_row = current_tags_in_list
80
+
81
+ # 2. Handle NumPy arrays
82
+ elif isinstance(tags_value_from_series, np.ndarray):
83
+ # Convert to list, then process elements, handling potential NaNs within the array
84
+ current_tags_in_list = []
85
+ for idx_tag, tag_item in enumerate(tags_value_from_series.tolist()): # .tolist() is crucial
86
+ try:
87
+ if pd.isna(tag_item): continue # Check for NaN after converting to Python type
88
+ str_tag = str(tag_item)
89
+ stripped_tag = str_tag.strip()
90
+ if stripped_tag:
91
+ current_tags_in_list.append(stripped_tag)
92
+ except Exception as e_inner_array_proc:
93
+ print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
94
+ temp_processed_list_for_row = current_tags_in_list
95
+
96
+ # 3. Handle simple None or pd.NA after lists and arrays (which might contain pd.NA elements handled above)
97
+ elif tags_value_from_series is None or pd.isna(tags_value_from_series): # Now pd.isna is safe for scalars
98
+ temp_processed_list_for_row = []
99
+
100
+ # 4. Handle strings (could be JSON-like, list-like, or comma-separated)
101
+ elif isinstance(tags_value_from_series, str):
102
+ processed_str_tags = []
103
+ # Attempt ast.literal_eval for strings that look like lists/tuples
104
+ if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
105
+ (tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
106
+ try:
107
+ evaluated_tags = ast.literal_eval(tags_value_from_series)
108
+ if isinstance(evaluated_tags, (list, tuple)): # Check if eval result is a list/tuple
109
+ # Recursively process this evaluated list/tuple, as its elements could be complex
110
+ # For simplicity here, assume elements are simple strings after eval
111
+ current_eval_list = []
112
+ for tag_item in evaluated_tags:
113
+ if pd.isna(tag_item): continue
114
+ str_tag = str(tag_item).strip()
115
+ if str_tag: current_eval_list.append(str_tag)
116
+ processed_str_tags = current_eval_list
117
+ except (ValueError, SyntaxError):
118
+ pass # If ast.literal_eval fails, let it fall to JSON or comma split
119
+
120
+ # If ast.literal_eval didn't populate, try JSON
121
+ if not processed_str_tags:
122
+ try:
123
+ json_tags = json.loads(tags_value_from_series)
124
+ if isinstance(json_tags, list):
125
+ # Similar to above, assume elements are simple strings after JSON parsing
126
+ current_json_list = []
127
+ for tag_item in json_tags:
128
+ if pd.isna(tag_item): continue
129
+ str_tag = str(tag_item).strip()
130
+ if str_tag: current_json_list.append(str_tag)
131
+ processed_str_tags = current_json_list
132
+ except json.JSONDecodeError:
133
+ # If not a valid JSON list, fall back to comma splitting as the final string strategy
134
+ processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
135
+ except Exception as e_json_other:
136
+ print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
137
+ processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()] # Fallback
138
+
139
+ temp_processed_list_for_row = processed_str_tags
140
+
141
+ # 5. Fallback for other scalar types (e.g., int, float that are not NaN)
142
+ else:
143
+ # This path is for non-list, non-ndarray, non-None/NaN, non-string types.
144
+ # Or for NaNs that slipped through if they are not None or pd.NA (e.g. float('nan'))
145
+ if pd.isna(tags_value_from_series): # Catch any remaining NaNs like float('nan')
146
+ temp_processed_list_for_row = []
147
+ else:
148
+ str_val = str(tags_value_from_series).strip()
149
+ temp_processed_list_for_row = [str_val] if str_val else []
150
+
151
+ processed_tags_accumulator.append(temp_processed_list_for_row)
152
+
153
+ except Exception as e_outer_tag_proc:
154
+ print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
155
+ processed_tags_accumulator.append([])
156
+
157
+ return processed_tags_accumulator
158
+
159
+ def get_file_size_category(file_size_gb_val):
160
+ try:
161
+ numeric_file_size_gb = float(file_size_gb_val)
162
+ if pd.isna(numeric_file_size_gb): numeric_file_size_gb = 0.0
163
+ except (ValueError, TypeError): numeric_file_size_gb = 0.0
164
+ if 0 <= numeric_file_size_gb < 1: return "Small (<1GB)"
165
+ elif 1 <= numeric_file_size_gb < 5: return "Medium (1-5GB)"
166
+ elif 5 <= numeric_file_size_gb < 20: return "Large (5-20GB)"
167
+ elif 20 <= numeric_file_size_gb < 50: return "X-Large (20-50GB)"
168
+ elif numeric_file_size_gb >= 50: return "XX-Large (>50GB)"
169
+ else: return "Small (<1GB)"
170
+
171
+
172
+ def main_preprocessor():
173
+ print(f"Starting pre-processing script. Output: '{PROCESSED_PARQUET_FILE_PATH}'.")
174
+ overall_start_time = time.time()
175
+
176
+ print(f"Fetching fresh data from Hugging Face: {HF_PARQUET_URL}")
177
+ try:
178
+ fetch_start_time = time.time()
179
+ query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"
180
+ df_raw = duckdb.sql(query).df()
181
+ data_download_timestamp = pd.Timestamp.now(tz='UTC')
182
+
183
+ if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
184
+ if 'id' not in df_raw.columns: raise ValueError("Fetched data must contain 'id' column.")
185
+
186
+ print(f"Fetched data in {time.time() - fetch_start_time:.2f}s. Rows: {len(df_raw)}. Downloaded at: {data_download_timestamp.strftime('%Y-%m-%d %H:%M:%S %Z')}")
187
+ except Exception as e_fetch:
188
+ print(f"ERROR: Could not fetch data from Hugging Face: {e_fetch}.")
189
+ return
190
+
191
+ df = pd.DataFrame()
192
+ print("Processing raw data...")
193
+ proc_start = time.time()
194
+
195
+ expected_cols_setup = {
196
+ 'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
197
+ 'pipeline_tag': str, 'tags': object, 'safetensors': object
198
+ }
199
+ for col_name, target_dtype in expected_cols_setup.items():
200
+ if col_name in df_raw.columns:
201
+ df[col_name] = df_raw[col_name]
202
+ if target_dtype == float: df[col_name] = pd.to_numeric(df[col_name], errors='coerce').fillna(0.0)
203
+ elif target_dtype == str: df[col_name] = df[col_name].astype(str).fillna('')
204
+ else:
205
+ if col_name in ['downloads', 'downloadsAllTime', 'likes']: df[col_name] = 0.0
206
+ elif col_name == 'pipeline_tag': df[col_name] = ''
207
+ elif col_name == 'tags': df[col_name] = pd.Series([[] for _ in range(len(df_raw))]) # Initialize with empty lists
208
+ elif col_name == 'safetensors': df[col_name] = None # Initialize with None
209
+ elif col_name == 'id': print("CRITICAL ERROR: 'id' column missing."); return
210
+
211
+ output_filesize_col_name = 'params'
212
+ if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
213
+ print(f"Using pre-existing '{output_filesize_col_name}' column as file size in GB.")
214
+ df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
215
+ elif 'safetensors' in df.columns:
216
+ print(f"Calculating '{output_filesize_col_name}' (file size in GB) from 'safetensors' data...")
217
+ df[output_filesize_col_name] = df['safetensors'].apply(extract_model_file_size_gb)
218
+ df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
219
+ else:
220
+ print(f"Cannot determine file size. Setting '{output_filesize_col_name}' to 0.0.")
221
+ df[output_filesize_col_name] = 0.0
222
+
223
+ df['data_download_timestamp'] = data_download_timestamp
224
+ print(f"Added 'data_download_timestamp' column.")
225
+
226
+ print("Categorizing models by file size...")
227
+ df['size_category'] = df[output_filesize_col_name].apply(get_file_size_category)
228
+
229
+ print("Standardizing 'tags' column...")
230
+ df['tags'] = process_tags_for_series(df['tags']) # This now uses tqdm internally
231
+
232
+ # --- START DEBUGGING BLOCK ---
233
+ # This block will execute before the main tag processing loop
234
+ if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values: # Check if ID exists
235
+ print(f"\n--- Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---")
236
+
237
+ # 1. Check the 'tags' column content after process_tags_for_series
238
+ model_specific_tags_list = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]
239
+ print(f"1. Tags from df['tags'] (after process_tags_for_series): {model_specific_tags_list}")
240
+ print(f" Type of tags: {type(model_specific_tags_list)}")
241
+ if isinstance(model_specific_tags_list, list):
242
+ for i, tag_item in enumerate(model_specific_tags_list):
243
+ print(f" Tag item {i}: '{tag_item}' (type: {type(tag_item)}, len: {len(str(tag_item))})")
244
+ # Detailed check for 'robotics' specifically
245
+ if 'robotics' in str(tag_item).lower():
246
+ print(f" DEBUG: Found 'robotics' substring in '{tag_item}'")
247
+ print(f" - str(tag_item).lower().strip(): '{str(tag_item).lower().strip()}'")
248
+ print(f" - Is it exactly 'robotics'?: {str(tag_item).lower().strip() == 'robotics'}")
249
+ print(f" - Ordinals: {[ord(c) for c in str(tag_item)]}")
250
+
251
+ # 2. Simulate temp_tags_joined for this specific model
252
+ if isinstance(model_specific_tags_list, list):
253
+ simulated_temp_tags_joined = '~~~'.join(str(t).lower().strip() for t in model_specific_tags_list if pd.notna(t) and str(t).strip())
254
+ else:
255
+ simulated_temp_tags_joined = ''
256
+ print(f"2. Simulated 'temp_tags_joined' for this model: '{simulated_temp_tags_joined}'")
257
+
258
+ # 3. Simulate 'has_robot' check for this model
259
+ robot_keywords = ['robot', 'robotics']
260
+ robot_pattern = '|'.join(robot_keywords)
261
+ manual_robot_check = bool(re.search(robot_pattern, simulated_temp_tags_joined, flags=re.IGNORECASE))
262
+ print(f"3. Manual regex check for 'has_robot' ('{robot_pattern}' in '{simulated_temp_tags_joined}'): {manual_robot_check}")
263
+ print(f"--- End Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---\n")
264
+ elif MODEL_ID_TO_DEBUG:
265
+ print(f"DEBUG: Model ID '{MODEL_ID_TO_DEBUG}' not found in DataFrame for pre-loop debugging.")
266
+ # --- END DEBUGGING BLOCK ---
267
+
268
+
269
+ print("Vectorized creation of cached tag columns...")
270
+ tag_time = time.time()
271
+ # This is the original temp_tags_joined creation:
272
+ df['temp_tags_joined'] = df['tags'].apply(
273
+ lambda tl: '~~~'.join(str(t).lower().strip() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
274
+ )
275
+
276
+ tag_map = {
277
+ 'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
278
+ 'has_robot': ['robot', 'robotics','openvla','vla'],
279
+ 'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
280
+ 'has_series': ['series', 'time-series', 'timeseries'],
281
+ 'has_video': ['video'], 'has_image': ['image', 'vision'],
282
+ 'has_text': ['text', 'nlp', 'llm']
283
+ }
284
+ for col, kws in tag_map.items():
285
+ pattern = '|'.join(kws)
286
+ df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
287
+
288
+ df['has_science'] = (
289
+ df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
290
+ ~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
291
+ )
292
+ del df['temp_tags_joined'] # Clean up temporary column
293
+ df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
294
+ df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
295
+ df['is_biomed'] = df['has_bio'] | df['has_med']
296
+ print(f"Vectorized tag columns created in {time.time() - tag_time:.2f}s.")
297
+
298
+ # --- POST-LOOP DIAGNOSTIC for has_robot & a specific model ---
299
+ if 'has_robot' in df.columns:
300
+ print("\n--- 'has_robot' Diagnostics (Preprocessor - Post-Loop) ---")
301
+ print(df['has_robot'].value_counts(dropna=False))
302
+
303
+ if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values:
304
+ model_has_robot_val = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'has_robot'].iloc[0]
305
+ print(f"Value of 'has_robot' for model '{MODEL_ID_TO_DEBUG}': {model_has_robot_val}")
306
+ if model_has_robot_val:
307
+ print(f" Original tags for '{MODEL_ID_TO_DEBUG}': {df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]}")
308
+
309
+ if df['has_robot'].any():
310
+ print("Sample models flagged as 'has_robot':")
311
+ print(df[df['has_robot']][['id', 'tags', 'has_robot']].head(5))
312
+ else:
313
+ print("No models were flagged as 'has_robot' after processing.")
314
+ print("--------------------------------------------------------\n")
315
+ # --- END POST-LOOP DIAGNOSTIC ---
316
+
317
+
318
+ print("Adding organization column...")
319
+ df['organization'] = df['id'].apply(extract_org_from_id)
320
+
321
+ # Drop safetensors if params was calculated from it, and params didn't pre-exist as numeric
322
+ if 'safetensors' in df.columns and \
323
+ not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
324
+ df = df.drop(columns=['safetensors'], errors='ignore')
325
+
326
+ final_expected_cols = [
327
+ 'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags',
328
+ 'params', 'size_category', 'organization',
329
+ 'has_audio', 'has_speech', 'has_music', 'has_robot', 'has_bio', 'has_med',
330
+ 'has_series', 'has_video', 'has_image', 'has_text', 'has_science',
331
+ 'is_audio_speech', 'is_biomed',
332
+ 'data_download_timestamp'
333
+ ]
334
+ # Ensure all final columns exist, adding defaults if necessary
335
+ for col in final_expected_cols:
336
+ if col not in df.columns:
337
+ print(f"Warning: Final expected column '{col}' is missing! Defaulting appropriately.")
338
+ if col == 'params': df[col] = 0.0
339
+ elif col == 'size_category': df[col] = "Small (<1GB)" # Default size category
340
+ elif 'has_' in col or 'is_' in col : df[col] = False # Default boolean flags to False
341
+ elif col == 'data_download_timestamp': df[col] = pd.NaT # Default timestamp to NaT
342
+
343
+ print(f"Data processing completed in {time.time() - proc_start:.2f}s.")
344
+ try:
345
+ print(f"Saving processed data to: {PROCESSED_PARQUET_FILE_PATH}")
346
+ df_to_save = df[final_expected_cols].copy() # Ensure only expected columns are saved
347
+ df_to_save.to_parquet(PROCESSED_PARQUET_FILE_PATH, index=False, engine='pyarrow')
348
+ print(f"Successfully saved processed data.")
349
+ except Exception as e_save:
350
+ print(f"ERROR: Could not save processed data: {e_save}")
351
+ return
352
+
353
+ total_elapsed_script = time.time() - overall_start_time
354
+ print(f"Pre-processing finished. Total time: {total_elapsed_script:.2f}s. Final Parquet shape: {df_to_save.shape}")
355
+
356
+ if __name__ == "__main__":
357
+ if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
358
+ print(f"Deleting existing '{PROCESSED_PARQUET_FILE_PATH}' to ensure fresh processing...")
359
+ try: os.remove(PROCESSED_PARQUET_FILE_PATH)
360
+ except OSError as e: print(f"Error deleting file: {e}. Please delete manually and rerun."); exit()
361
+
362
+ main_preprocessor()
363
+
364
+ if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
365
+ print(f"\nTo verify, load parquet and check 'has_robot' and its 'tags':")
366
+ print(f"import pandas as pd; df_chk = pd.read_parquet('{PROCESSED_PARQUET_FILE_PATH}')")
367
+ print(f"print(df_chk['has_robot'].value_counts())")
368
+ if MODEL_ID_TO_DEBUG:
369
+ print(f"print(df_chk[df_chk['id'] == '{MODEL_ID_TO_DEBUG}'][['id', 'tags', 'has_robot']])")
370
+ else:
371
+ print(f"print(df_chk[df_chk['has_robot']][['id', 'tags', 'has_robot']].head())")