# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "datasets",
#     "flashinfer-python",
#     "hf_transfer",
#     "huggingface-hub[hf_xet]",
#     "polars",
#     "stamina",
#     "transformers",
#     "vllm",
#     "tqdm",
#     "setuptools",
# ]
# ///

import argparse
import logging
import os
import sys
from typing import Optional

# Set environment variables before importing HF/vLLM libraries so they take
# effect: hf_transfer speeds up model downloads; FlashInfer is used as the
# vLLM attention backend.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

import polars as pl
from datasets import Dataset, load_dataset
from huggingface_hub import login, snapshot_download
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)


def format_prompt(content: str, card_type: str, tokenizer) -> str:
    """Format card content as a chat prompt for the model.

    Both card types currently use the same prompt shape; the content is
    truncated to its first 4,000 characters to keep prompts within the
    context window. ``card_type`` is kept so the two cases can diverge later.
    """
    messages = [{"role": "user", "content": content[:4000]}]
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, tokenize=False
    )


def load_and_filter_data(
    dataset_id: str, card_type: str, min_likes: int = 1, min_downloads: int = 1
) -> pl.DataFrame:
    """Load and filter dataset/model card data."""
    logger.info(f"Loading data from {dataset_id}")
    ds = load_dataset(dataset_id, split="train")
    df = ds.to_polars().lazy()

    # Extract the content after the YAML frontmatter (the leading `--- ... ---` block)
    df = df.with_columns(
        [
            pl.col("card")
            .str.replace_all(r"^---\n[\s\S]*?\n---\n", "", literal=False)
            .str.strip_chars()
            .alias("post_yaml_content")
        ]
    )

    # Drop cards that are too short to summarize or excessively long
    df = df.filter(pl.col("post_yaml_content").str.len_bytes() > 200)
    df = df.filter(pl.col("post_yaml_content").str.len_bytes() < 120_000)

    # Popularity filters only apply to model cards
    if card_type == "model":
        df = df.filter(pl.col("likes") >= min_likes)
        df = df.filter(pl.col("downloads") >= min_downloads)

    df_filtered = df.collect()
    logger.info(f"Filtered dataset has {len(df_filtered)} items")
    return df_filtered
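
# For illustration only (not executed; the card below is a made-up example):
# a raw card such as
#
#   ---
#   license: mit
#   tags: [example]
#   ---
#   # My dataset
#   A short description...
#
# comes out of the pipeline above with `post_yaml_content` set to
# "# My dataset\nA short description...", i.e. just the markdown body with
# the frontmatter stripped and surrounding whitespace trimmed.
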
def generate_summaries(
    model_id: str,
    input_dataset_id: str,
    output_dataset_id: str,
    card_type: str = "dataset",
    max_tokens: int = 120,
    temperature: float = 0.6,
    batch_size: int = 1000,
    min_likes: int = 1,
    min_downloads: int = 1,
    hf_token: Optional[str] = None,
):
    """Main function to generate summaries."""
    # Login if a token is provided (falls back to the HF_TOKEN env var)
    HF_TOKEN = hf_token or os.environ.get("HF_TOKEN")
    if HF_TOKEN:
        login(token=HF_TOKEN)

    # Load and filter data
    df_filtered = load_and_filter_data(
        input_dataset_id, card_type, min_likes, min_downloads
    )

    # Download the model to a local directory first so vLLM loads from disk
    # (downloads resume automatically in recent huggingface_hub versions)
    logger.info(f"Downloading model {model_id} to local directory...")
    local_model_path = snapshot_download(repo_id=model_id)
    logger.info(f"Model downloaded to: {local_model_path}")

    # Initialize model and tokenizer from the local path
    logger.info(f"Initializing vLLM model from local path: {local_model_path}")
    llm = LLM(model=local_model_path)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path)

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # Prepare prompts
    logger.info("Preparing prompts")
    post_yaml_contents = df_filtered["post_yaml_content"].to_list()
    prompts = [
        format_prompt(content, card_type, tokenizer)
        for content in tqdm(post_yaml_contents, desc="Formatting prompts")
    ]

    # Generate summaries in batches
    logger.info(f"Generating summaries for {len(prompts)} items")
    all_outputs = []
    for i in tqdm(range(0, len(prompts), batch_size), desc="Generating summaries"):
        batch_prompts = prompts[i : i + batch_size]
        outputs = llm.generate(batch_prompts, sampling_params)
        all_outputs.extend(outputs)

    # Extract the generated text from each request's first completion
    clean_results = [output.outputs[0].text.strip() for output in all_outputs]

    # Create dataset and add summaries
    ds = Dataset.from_polars(df_filtered)
    ds = ds.add_column("summary", clean_results)

    # Push to hub
    logger.info(f"Pushing dataset to hub: {output_dataset_id}")
    ds.push_to_hub(output_dataset_id, token=HF_TOKEN)
    logger.info("Dataset successfully pushed to hub")


def main():
    parser = argparse.ArgumentParser(
        description="Generate summaries for Hugging Face datasets or models using vLLM"
    )
    parser.add_argument(
        "model_id",
        help="Model ID for summary generation (e.g., davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02)",
    )
    parser.add_argument(
        "input_dataset_id",
        help="Input dataset ID (e.g., librarian-bots/dataset_cards_with_metadata)",
    )
    parser.add_argument(
        "output_dataset_id", help="Output dataset ID where results will be saved"
    )
    parser.add_argument(
        "--card-type",
        choices=["dataset", "model"],
        default="dataset",
        help="Type of cards to process (default: dataset)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=120,
        help="Maximum tokens for summary generation (default: 120)",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.6,
        help="Temperature for generation (default: 0.6)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=1000,
        help="Batch size for processing (default: 1000)",
    )
    parser.add_argument(
        "--min-likes",
        type=int,
        default=1,
        help="Minimum likes filter for models (default: 1)",
    )
    parser.add_argument(
        "--min-downloads",
        type=int,
        default=1,
        help="Minimum downloads filter for models (default: 1)",
    )
    parser.add_argument(
        "--hf-token", help="Hugging Face token (uses HF_TOKEN env var if not provided)"
    )

    args = parser.parse_args()

    generate_summaries(
        model_id=args.model_id,
        input_dataset_id=args.input_dataset_id,
        output_dataset_id=args.output_dataset_id,
        card_type=args.card_type,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        batch_size=args.batch_size,
        min_likes=args.min_likes,
        min_downloads=args.min_downloads,
        hf_token=args.hf_token,
    )


if __name__ == "__main__":
    if len(sys.argv) == 1:
        # Show an example hfjobs command when run without arguments
        print("Example hfjobs command:")
        print(
            "hfjobs run --flavor l4x1 --secret HF_TOKEN=hf_*** ghcr.io/astral-sh/uv:debian /bin/bash -c '"
        )
        print("apt-get update && apt-get install -y python3-dev gcc && \\")
        print("export HOME=/tmp && \\")
        print("export USER=dummy && \\")
        print("export TORCHINDUCTOR_CACHE_DIR=/tmp/torch-inductor && \\")
        print("uv run generate_summaries_uv.py \\")
        print("  davanstrien/Smol-Hub-tldr \\")
        print("  librarian-bots/dataset_cards_with_metadata \\")
        print("  your-username/datasets_with_summaries \\")
        print("  --card-type dataset \\")
        print("  --batch-size 2000")
        print("' --project summary-generation --name dataset-summaries")
        print()
        print("For models:")
        print("uv run generate_summaries_uv.py \\")
        print("  davanstrien/SmolLM2-135M-tldr-sft-2025-03-12_19-02 \\")
        print("  librarian-bots/model_cards_with_metadata \\")
        print("  your-username/models_with_summaries \\")
        print("  --card-type model \\")
        print("  --min-likes 5 \\")
        print("  --min-downloads 1000")
    else:
        main()
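
# Inspecting the result after a run (a sketch; the repo ID is a placeholder):
#
#   from datasets import load_dataset
#   ds = load_dataset("your-username/datasets_with_summaries", split="train")
#   print(ds[0]["summary"])
#
# The pushed dataset keeps every column of the filtered input cards dataset
# and adds a single new "summary" column with the generated text, one row
# per card.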