Spaces:
Running
on
Zero
Running
on
Zero
File size: 9,530 Bytes
570eaa9 a528449 570eaa9 2af55e5 570eaa9 a528449 570eaa9 a528449 570eaa9 a528449 570eaa9 a528449 570eaa9 a528449 570eaa9 a528449 570eaa9 a528449 570eaa9 a528449 570eaa9 2af55e5 a528449 2af55e5 a528449 2af55e5 a528449 2af55e5 a528449 2af55e5 a528449 2af55e5 570eaa9 a528449 570eaa9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
import os
import gradio as gr
import torch
import itertools # Import itertools for color cycling
from bytelatent.data.file_util import get_fs
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.args import TrainArgs
from download_blt_weights import main as ensure_present
# --- Global Setup (Consider loading models outside if necessary) ---
# Kept inside the function for simplicity as before.
# Define colors for patches (similar to the image style)
# Using colors from a qualitative colormap (e.g., Colorbrewer Set3 or Paired)
# These hex codes seed the "Patch N" -> color mapping handed to
# gr.HighlightedText; the palette is cycled when a prompt yields more
# patches than there are colors.
PATCH_COLORS = [
    "#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c",
    "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#ffff99", "#b15928"
]  # Add more if you expect many patches
def create_highlighted_text_data(tokenizer, patch_lengths_tensor, tokens_tensor, colors):
    """
    Build the data structure needed for gr.HighlightedText from patch boundaries.

    Walks the token sequence patch by patch, decodes each patch's tokens to
    text, and labels it "Patch {i}". The label — not anything computed here —
    is what keys into the ``color_map`` given to gr.HighlightedText, which is
    why the ``colors`` argument is unused (kept only for backward
    compatibility; the previous implementation cycled it into a local that
    was never read).

    Args:
        tokenizer: Tokenizer exposing ``decode(list[int]) -> str``
            (a BltTokenizer in this app).
        patch_lengths_tensor: 1-D tensor of per-patch lengths, in tokens.
        tokens_tensor: 1-D tensor of token IDs for the entire sequence.
        colors: Unused; see above.

    Returns:
        A list of ``(text, label)`` tuples for gr.HighlightedText, or None
        if either tensor is missing or there are no patches.
    """
    if patch_lengths_tensor is None or tokens_tensor is None or patch_lengths_tensor.numel() == 0:
        return None

    patch_lengths = patch_lengths_tensor.tolist()
    all_tokens = tokens_tensor.tolist()
    highlighted_data = []
    current_token_index = 0

    for i, length in enumerate(patch_lengths):
        if length <= 0:  # Skip empty patches if they somehow occur
            continue
        patch_token_ids = all_tokens[current_token_index : current_token_index + length]
        if not patch_token_ids:  # Patch runs past the end of the token list
            continue
        patch_text = tokenizer.decode(patch_token_ids)
        patch_label = f"Patch {i+1}"  # Unique label; keys into the UI color map
        highlighted_data.append((patch_text, patch_label))
        current_token_index += length

    # Sanity check: the patch lengths should exactly cover the token sequence.
    # A mismatch indicates a logic issue upstream; decode the leftovers anyway
    # so no text silently disappears from the visualization.
    if current_token_index != len(all_tokens):
        print(f"Warning: Token mismatch. Consumed {current_token_index}, total {len(all_tokens)}")
        remaining_tokens = all_tokens[current_token_index:]
        if remaining_tokens:
            remaining_text = tokenizer.decode(remaining_tokens)
            highlighted_data.append((remaining_text, "Remainder"))  # Generic label

    return highlighted_data
def process_text(prompt: str, model_name: str = "blt-1b"):
    """
    Processes the input prompt using the ByteLatent model and returns
    an entropy plot and color-coded text data.

    Loads the tokenizer and entropy patcher from the locally downloaded
    weights under ``hf-weights/<model_name>`` on every call (no caching),
    runs the patcher on the prompt, and builds the two visualizations.

    Args:
        prompt: The input text string from the Gradio interface.
        model_name: The name of the model to use; must match a directory
            under ``hf-weights/`` containing ``params.json`` and an
            ``entropy_model`` subdirectory.

    Returns:
        A tuple containing:
        - Matplotlib Figure for the entropy plot (or None on error).
        - List of tuples for gr.HighlightedText (or None on error/no results).
        - Error message string (or None if successful).
    """
    try:
        # --- Model and Tokenizer Loading ---
        consolidated_path = os.path.join("hf-weights", model_name)
        train_args_path = os.path.join(consolidated_path, "params.json")

        if not os.path.exists(train_args_path):
            raise FileNotFoundError(f"Training args not found at {train_args_path}. "
                                    f"Ensure model '{model_name}' is downloaded/available.")

        # params.json is parsed through the project's pydantic TrainArgs model;
        # get_fs resolves the filesystem backend for the path (local here).
        fs = get_fs(train_args_path)
        train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

        tokenizer = train_args.data.tokenizer_args.build()
        assert isinstance(tokenizer, BltTokenizer)

        # Deep-copy the patcher config so the edits below don't mutate train_args.
        patcher_args = train_args.data.patcher_args.model_copy(deep=True)
        patcher_args.realtime_patching = True
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        patcher_args.patching_device = device
        patcher_args.device = device

        print("Loading entropy model and patcher...")
        entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
        if not os.path.exists(entropy_model_dir):
            raise FileNotFoundError(f"Entropy model directory not found at {entropy_model_dir}.")
        patcher_args.entropy_model_checkpoint_dir = entropy_model_dir
        patcher = patcher_args.build()
        # --- End Loading ---

        # --- Processing ---
        # patcher_nocache takes a batch of prompts; we submit a batch of one.
        prompts = [prompt]
        print(f"Processing prompt: '{prompt}'")
        results = patcher_nocache(
            prompts, tokenizer=tokenizer, patcher=patcher
        )

        if not results:
            print("Processing returned no results.")
            return None, None, "Processing completed, but no results were generated."

        # NOTE(review): results appears to unpack as per-batch (patch_lengths,
        # scores, tokens) — confirm against patcher_nocache's return contract.
        batch_patch_lengths, batch_scores, batch_tokens = results
        # Process the first (and only) result in the batch
        patch_lengths = batch_patch_lengths[0]
        scores = batch_scores[0]
        tokens = batch_tokens[0]

        # Decode the full output once for the plot labels (if needed by plot_entropies)
        # Note: BltTokenizer might decode directly to bytes, then utf-8. Ensure it handles errors.
        try:
            # Using the raw tokens tensor for decoding consistency
            decoded_output_for_plot = tokenizer.decode(tokens.tolist())
        except Exception as decode_err:
            print(f"Warning: Error decoding full sequence for plot: {decode_err}")
            # Fallback: attempt to decode the original prompt if possible, or use generic labels
            decoded_output_for_plot = prompt  # Use original prompt as fallback

        # Generate the plot
        fig = plot_entropies(
            patch_lengths,
            scores,
            decoded_output_for_plot,  # Pass the decoded string for plot labels
            threshold=patcher.threshold
        )

        # Generate data for HighlightedText
        highlighted_data = create_highlighted_text_data(
            tokenizer, patch_lengths, tokens, PATCH_COLORS
        )

        print("Processing and visualization data generation complete.")
        # --- End Processing ---

        return fig, highlighted_data, None  # Return plot, highlighted text data, no error

    except FileNotFoundError as e:
        print(f"Error: {e}")
        return None, None, f"Error: {str(e)}"  # Return None for plot/text, error message
    except Exception as e:
        # Broad catch is deliberate: this is the Gradio handler boundary, and
        # any failure should surface in the Status box rather than crash the app.
        print(f"An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()
        return None, None, f"An unexpected error occurred: {e}"  # Return None for plot/text, error message
# --- Gradio Interface ---
# Pre-build the label -> color mapping for gr.HighlightedText. Labels are
# "Patch 1" ... "Patch 50"; the palette repeats once it is exhausted
# (modulo indexing is equivalent to cycling through PATCH_COLORS).
MAX_EXPECTED_PATCHES = 50  # Estimate a reasonable maximum
color_map = {
    f"Patch {idx + 1}": PATCH_COLORS[idx % len(PATCH_COLORS)]
    for idx in range(MAX_EXPECTED_PATCHES)
}
# Grey for the 'Remainder' label that create_highlighted_text_data may emit
# for tokens not covered by any patch.
color_map["Remainder"] = "#808080"
# Declarative UI layout: one input column feeding three outputs
# (status text, highlighted patches, entropy plot).
with gr.Blocks() as iface:
    gr.Markdown("# ByteLatent Entropy Visualizer")  # Title
    gr.Markdown(
        "Process any prompt (limited to 512 bytes) with the 100M entropy patcher model "
        "and visualize the token entropies plot and color-coded patches below.<br><br>"  # Updated description
        "NOTE: this implementation differs slightly by excluding local attention so we limit "
        "the characters limit to 512 to avoid any deviation.",
        line_breaks=True
    )

    with gr.Column():
        prompt_input = gr.Textbox(
            label="Input Prompt",
            value="Daenerys Targaryen is in Game of Thrones, a fantasy epic by George R.R. Martin.",
            placeholder="Enter text here...",
            max_length=512,  # Mirrors the 512-byte limit stated above
            lines=3
        )
        submit_button = gr.Button("Generate Visualization")

        # Output for error messages or status
        status_output = gr.Textbox(label="Status", interactive=False)

        # Output component for the color-coded text; colors come from
        # color_map keyed by the "Patch N" / "Remainder" labels.
        highlighted_output = gr.HighlightedText(
            label="Patched Text Visualization",
            color_map=color_map,
            show_legend=False  # Legend hidden; the inline colors are enough
        )

        # Output component for the plot
        plot_output = gr.Plot(label="Entropy vs. Token Index (with Patch Threshold)")

    # Wire the button to process_text; output order must match its
    # (fig, highlighted_data, error) return tuple.
    submit_button.click(
        fn=process_text,
        inputs=prompt_input,
        outputs=[plot_output, highlighted_output, status_output]  # Order matters!
    )
# --- Launch the Gradio App ---
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Download the default model weights if missing, then start the server.
    ensure_present(["blt-1b"])  # Ensure model is present before launching
    iface.launch()
|