# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
# gr.DataFrame accepts a list-of-lists directly, so pandas does not need to be
# listed separately here.
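#
# An illustrative requirements.txt for this Space (package names only; pinning
# versions, and adding 'accelerate' because device_map="auto" is used below,
# are assumptions to verify for your setup):
#
#     gradio
#     transformers
#     torch
#     accelerate
#     huggingface_hub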
import json  # For parsing the prompts file
import os    # For environment variables and os.path helpers

import gradio as gr
import torch  # Or tensorflow/flax depending on backend
from huggingface_hub import hf_hub_download, login
from transformers import AutoModelForCausalLM, AutoTokenizer


# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
#MODEL_NAME = "google/txgemma-9b-predict"
PROMPT_FILENAME = "tdc_prompts.json"
MODEL_CACHE = "model_cache" # Optional: define a cache directory
# MAX_EXAMPLES no longer limits the display; the DataFrame scrolls through all rows.
MAX_EXAMPLES = 600 # Kept in case a cap on loaded prompts is needed later
EXAMPLE_SMILES = "C1=CC=CC=C1" # Default SMILES for examples (Benzene)
DATAFRAME_HEADERS = ["Task Name", "Prompt Template"]
DATAFRAME_ROW_COUNT = 8 # Number of rows to display initially in the DataFrame

# Authenticate with the Hub (needed to download gated models such as TxGemma).
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN is not set; downloading gated models may fail.")


# --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None # Initialize as None
dataframe_data = [] # Initialize empty list for DataFrame content
try:
    # Check if GPU is available and use it, otherwise use CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
    print("Tokenizer loaded.")

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=MODEL_CACHE,
        device_map="auto" # Automatically distribute model across available devices (GPU/CPU)
    )
    print("Model loaded.")

    # Download and load the prompts JSON file
    print(f"Downloading {PROMPT_FILENAME}...")
    prompts_file_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename=PROMPT_FILENAME,
        cache_dir=MODEL_CACHE,
    )
    print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")

    # Load the JSON data
    with open(prompts_file_path, 'r') as f:
        tdc_prompts_data = json.load(f)
    print(f"Loaded prompts data from {PROMPT_FILENAME}.")

    # --- Prepare data for Gradio DataFrame ---
    # Updated logic: Parse the dictionary format from tdc_prompts.json
    # Create a list of lists for the DataFrame: [[task_name, prompt_template], ...]
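    # Assumed shape of tdc_prompts.json, inferred from the parsing below:
    # a flat mapping of task name -> prompt template containing a
    # "{Drug SMILES}" placeholder, e.g. (illustrative entry only):
    #   {"BBB_Martins": "...Drug SMILES: {Drug SMILES}\nAnswer:", ...}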
    if isinstance(tdc_prompts_data, dict):
        print(f"Processing {len(tdc_prompts_data)} prompts from dictionary for DataFrame...")
        for task_name, prompt_template in tdc_prompts_data.items():
            if isinstance(prompt_template, str) and isinstance(task_name, str):
                # Add task name and the raw template to the list
                dataframe_data.append([task_name, prompt_template])
            else:
                print(f"Warning: Skipping invalid item in prompts dictionary: key={task_name}, value_type={type(prompt_template)}")
        print(f"Prepared {len(dataframe_data)} rows for DataFrame.")

    else:
        print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.")
        # dataframe_data remains empty


except Exception as e:
    print(f"Error loading model, tokenizer, or prompts: {e}")
    # Ensure dataframe_data is empty on error during setup
    dataframe_data = []
    raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}")


# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
    """
    Generates text based on the input prompt using the loaded model and returns
    the completion with the echoed prompt stripped from the start.
    """
    print(f"Received prompt: {prompt}")
    print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")

    try:
        # Prepare the input for the model
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens), # Ensure it's an integer
                temperature=float(temperature),   # Ensure it's a float
                do_sample=float(temperature) > 0, # Only sample if temperature > 0
                pad_token_id=tokenizer.eos_token_id # Set pad token id
            )

        # Decode the generated tokens
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated text (raw): {generated_text}")

        # Remove the prompt from the beginning of the generated text
        if generated_text.startswith(prompt):
            prompt_length = len(prompt)
            result_text = generated_text[prompt_length:].lstrip()
        else:
            # Fallback: strip the longest common prefix when the decoded text
            # does not start with the prompt verbatim.
            common_prefix = os.path.commonprefix([prompt, generated_text])
            if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8:
                result_text = generated_text[len(common_prefix):].lstrip()
            else:
                result_text = generated_text

        print(f"Generated text (processed): {result_text}")
        return result_text

    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"An error occurred during generation: {e}"

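# Illustrative smoke test for predict(), kept commented out so nothing runs at
# Space startup; the prompt text below is a made-up example, not a real TDC
# template:
#   _ = predict(
#       "Instructions: Answer the question about the drug.\n"
#       "Drug SMILES: C1=CC=CC=C1\nQuestion: Is it toxic? Answer:",
#       max_new_tokens=8,
#       temperature=0.0,
#   )
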
# --- Function to handle DataFrame selection ---
def select_prompt_from_df(evt: gr.SelectData):
    """
    Triggered when a row is selected in the DataFrame.
    Updates the main prompt input with the selected template, replacing the placeholder.
    """
    if evt.index is None or evt.index[0] >= len(dataframe_data):
        print("Invalid selection event or index out of bounds.")
        return gr.update() # No change

    selected_row_index = evt.index[0]
    # Get the prompt template from the second column (index 1) of the selected row
    prompt_template = dataframe_data[selected_row_index][1]

    # Replace the placeholder with the example SMILES string
    selected_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
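    # e.g. an (illustrative) template "...Drug SMILES: {Drug SMILES}\nAnswer:"
    # becomes "...Drug SMILES: C1=CC=CC=C1\nAnswer:"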
    print(f"Selected prompt template from row {selected_row_index}, updated input.")

    # Return the processed prompt to update the prompt_input textbox
    return selected_prompt


# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🤖 TxGemma Property Prediction

        Enter a prompt below, or select a task from the table to load its template, and the model ({MODEL_NAME}) will generate text.
        Adjust the parameters for different results. Prompt templates loaded from `{PROMPT_FILENAME}`.
        Selected templates will use the SMILES string `{EXAMPLE_SMILES}` (Benzene) as a placeholder.
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Enter your text prompt here, or select a template from the table below...",
                lines=5,
                elem_id="prompt_input_box" # Add elem_id for clarity if needed
            )
            with gr.Row():
                 max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate after the prompt."
                 )
                 temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.7,
                    step=0.05,
                    label="Temperature",
                    info="Controls randomness (0=deterministic, >0=random)."
                 )
            submit_button = gr.Button("Generate Text", variant="primary")
        with gr.Column(scale=3):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10, # Adjust height if needed
                interactive=False
            )

    # --- Add DataFrame for Prompt Templates ---
    gr.Markdown("### Select a Prompt Template")
    prompt_df = gr.DataFrame(
        value=dataframe_data,
        headers=DATAFRAME_HEADERS,
        row_count=(DATAFRAME_ROW_COUNT, "dynamic"), # Show fixed rows initially, allow scrolling
        col_count=(len(DATAFRAME_HEADERS), "fixed"), # Fixed number of columns
        wrap=True, # Wrap text in cells
        label="Prompt Templates"
    )

    # --- Connect Components ---
    # Connect submit button to prediction function
    submit_button.click(
        fn=predict,
        inputs=[prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
        api_name="predict"
    )

    # Connect DataFrame selection to update prompt input
    # The `select` event triggers the `select_prompt_from_df` function.
    # The event data (evt: gr.SelectData) is implicitly passed to the function.
    # The function returns the value to update the `prompt_input` component.
    prompt_df.select(
        fn=select_prompt_from_df,
        inputs=None, # No explicit inputs needed, event data is passed automatically
        outputs=prompt_input,
        show_progress="hidden" # Hide progress bar for this quick update
    )


# --- Launch the App ---
print("Launching Gradio app...")
demo.queue().launch(debug=True) # Set debug=False for production
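
# Illustrative client-side call against the named "predict" endpoint once the
# Space is running (requires the separate gradio_client package; the Space id
# below is a placeholder):
#   from gradio_client import Client
#   client = Client("your-username/your-space-name")
#   result = client.predict(
#       "Drug SMILES: C1=CC=CC=C1 ...",  # prompt (illustrative)
#       100,                             # max_new_tokens
#       0.7,                             # temperature
#       api_name="/predict",
#   )
#   print(result)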