# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
# gr.DataFrame accepts a list-of-lists directly, so pandas does not need to be added explicitly.
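# A minimal requirements.txt for this Space could look like this (illustrative,
# unpinned; add version pins as your environment requires):
#
#     gradio
#     transformers
#     torch
#     accelerate        # needed because device_map="auto" is used below
#     huggingface_hub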
import json  # Parse the downloaded prompt templates file
import os  # Used for os.getenv (HF_TOKEN) and os.path.commonprefix

import gradio as gr
import torch  # Or tensorflow/flax, depending on the installed backend
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
#MODEL_NAME = "google/txgemma-9b-predict"
PROMPT_FILENAME = "tdc_prompts.json"
MODEL_CACHE = "model_cache" # Optional: define a cache directory
# MAX_EXAMPLES no longer limits the display (the DataFrame scrolls); kept in case a cap is wanted later.
MAX_EXAMPLES = 600
EXAMPLE_SMILES = "C1=CC=CC=C1" # Default SMILES for examples (Benzene)
DATAFRAME_HEADERS = ["Task Name", "Prompt Template"]
DATAFRAME_ROW_COUNT = 8 # Number of rows to display initially in the DataFrame
# Authenticate with the Hub (needed if the model weights are gated).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN is not set; downloading a gated model may fail.")
# --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None # Initialize as None
dataframe_data = [] # Initialize empty list for DataFrame content
try:
# Check if GPU is available and use it, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
print("Tokenizer loaded.")
# Load the model
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
cache_dir=MODEL_CACHE,
device_map="auto" # Automatically distribute model across available devices (GPU/CPU)
)
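    # Note: device_map="auto" relies on the `accelerate` package being installed.
    # A lower-memory option (untested sketch) would be to also pass
    # torch_dtype=torch.bfloat16 to from_pretrained above on supported GPUs.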
print("Model loaded.")
# Download and load the prompts JSON file
print(f"Downloading {PROMPT_FILENAME}...")
prompts_file_path = hf_hub_download(
repo_id=MODEL_NAME,
filename=PROMPT_FILENAME,
cache_dir=MODEL_CACHE,
)
print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")
# Load the JSON data
with open(prompts_file_path, 'r') as f:
tdc_prompts_data = json.load(f)
print(f"Loaded prompts data from {PROMPT_FILENAME}.")
# --- Prepare data for Gradio DataFrame ---
    # Parse the dictionary in tdc_prompts.json into a list of lists for the
    # DataFrame: [[task_name, prompt_template], ...]
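    # Illustrative shape of tdc_prompts.json (an assumption, not verified against
    # the downloaded file): a flat mapping of task name -> prompt template, e.g.
    #   {"BBB_Martins": "...will {Drug SMILES} cross the blood-brain barrier?..."}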
if isinstance(tdc_prompts_data, dict):
print(f"Processing {len(tdc_prompts_data)} prompts from dictionary for DataFrame...")
for task_name, prompt_template in tdc_prompts_data.items():
if isinstance(prompt_template, str) and isinstance(task_name, str):
# Add task name and the raw template to the list
dataframe_data.append([task_name, prompt_template])
else:
print(f"Warning: Skipping invalid item in prompts dictionary: key={task_name}, value_type={type(prompt_template)}")
print(f"Prepared {len(dataframe_data)} rows for DataFrame.")
else:
print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.")
# dataframe_data remains empty
except Exception as e:
print(f"Error loading model, tokenizer, or prompts: {e}")
# Ensure dataframe_data is empty on error during setup
dataframe_data = []
raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}")
# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
"""
    Generate text from the input prompt with the loaded model, then strip the
    echoed prompt from the output before returning it.
"""
print(f"Received prompt: {prompt}")
print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")
try:
# Prepare the input for the model
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device
# Generate text
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=int(max_new_tokens), # Ensure it's an integer
temperature=float(temperature), # Ensure it's a float
                do_sample=float(temperature) > 0,  # sample only when temperature > 0
pad_token_id=tokenizer.eos_token_id # Set pad token id
)
# Decode the generated tokens
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Generated text (raw): {generated_text}")
# Remove the prompt from the beginning of the generated text
if generated_text.startswith(prompt):
prompt_length = len(prompt)
result_text = generated_text[prompt_length:].lstrip()
else:
common_prefix = os.path.commonprefix([prompt, generated_text])
if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8:
result_text = generated_text[len(common_prefix):].lstrip()
else:
result_text = generated_text
print(f"Generated text (processed): {result_text}")
return result_text
except Exception as e:
print(f"Error during prediction: {e}")
return f"An error occurred during generation: {e}"
# --- Function to handle DataFrame selection ---
def select_prompt_from_df(evt: gr.SelectData):
"""
Triggered when a row is selected in the DataFrame.
Updates the main prompt input with the selected template, replacing the placeholder.
"""
if evt.index is None or evt.index[0] >= len(dataframe_data):
print("Invalid selection event or index out of bounds.")
return gr.update() # No change
selected_row_index = evt.index[0]
# Get the prompt template from the second column (index 1) of the selected row
prompt_template = dataframe_data[selected_row_index][1]
# Replace the placeholder with the example SMILES string
selected_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
print(f"Selected prompt template from row {selected_row_index}, updated input.")
# Return the processed prompt to update the prompt_input textbox
return selected_prompt
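# Note: some TDC templates may use other placeholders (e.g. {Drug1 SMILES} for
# drug-pair tasks); only {Drug SMILES} is substituted here, so any other
# placeholder would pass through to the prompt box unchanged.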
# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"""
        # 🤖 TxGemma Property Prediction
Enter a prompt below, or select a task from the table to load its template, and the model ({MODEL_NAME}) will generate text.
Adjust the parameters for different results. Prompt templates loaded from `{PROMPT_FILENAME}`.
        Selecting a template fills its `{{Drug SMILES}}` placeholder with the example SMILES `{EXAMPLE_SMILES}` (benzene).
"""
)
with gr.Row():
with gr.Column(scale=2):
prompt_input = gr.Textbox(
label="Your Prompt",
placeholder="Enter your text prompt here, or select a template from the table below...",
lines=5,
elem_id="prompt_input_box" # Add elem_id for clarity if needed
)
with gr.Row():
max_tokens_slider = gr.Slider(
minimum=10,
maximum=500,
value=100,
step=10,
label="Max New Tokens",
info="Maximum number of tokens to generate after the prompt."
)
temperature_slider = gr.Slider(
minimum=0.0,
maximum=1.5,
value=0.7,
step=0.05,
label="Temperature",
info="Controls randomness (0=deterministic, >0=random)."
)
submit_button = gr.Button("Generate Text", variant="primary")
with gr.Column(scale=3):
output_text = gr.Textbox(
label="Generated Text",
lines=10, # Adjust height if needed
interactive=False
)
# --- Add DataFrame for Prompt Templates ---
gr.Markdown("### Select a Prompt Template")
prompt_df = gr.DataFrame(
value=dataframe_data,
headers=DATAFRAME_HEADERS,
row_count=(DATAFRAME_ROW_COUNT, "dynamic"), # Show fixed rows initially, allow scrolling
col_count=(len(DATAFRAME_HEADERS), "fixed"), # Fixed number of columns
wrap=True, # Wrap text in cells
label="Prompt Templates"
)
# --- Connect Components ---
# Connect submit button to prediction function
submit_button.click(
fn=predict,
inputs=[prompt_input, max_tokens_slider, temperature_slider],
outputs=output_text,
api_name="predict"
)
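    # Because api_name="predict" is set, the same endpoint can be called remotely,
    # e.g. with gradio_client (sketch; "user/space-name" is a placeholder Space id):
    #   from gradio_client import Client
    #   client = Client("user/space-name")
    #   client.predict("my prompt", 100, 0.7, api_name="/predict")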
# Connect DataFrame selection to update prompt input
# The `select` event triggers the `select_prompt_from_df` function.
# The event data (evt: gr.SelectData) is implicitly passed to the function.
# The function returns the value to update the `prompt_input` component.
prompt_df.select(
fn=select_prompt_from_df,
inputs=None, # No explicit inputs needed, event data is passed automatically
outputs=prompt_input,
show_progress="hidden" # Hide progress bar for this quick update
)
# --- Launch the App ---
print("Launching Gradio app...")
demo.queue().launch(debug=True) # Set debug=False for production