Spaces:
Sleeping
Sleeping
File size: 8,089 Bytes
89588fc b1951fe 89588fc 4afe335 89588fc 4afe335 89588fc 4afe335 b1951fe 89588fc b1951fe 89588fc 4afe335 b1951fe 4afe335 b1951fe 4afe335 89588fc 4afe335 b1951fe 4afe335 89588fc b1951fe 89588fc b1951fe 89588fc 4afe335 89588fc 4afe335 89588fc 4afe335 b1951fe 89588fc 4afe335 b1951fe 4afe335 89588fc 4afe335 89588fc 4afe335 b1951fe 4afe335 b1951fe 4afe335 89588fc 4afe335 89588fc 4afe335 89588fc b1951fe 89588fc b1951fe 89588fc b1951fe 4afe335 b1951fe 4afe335 89588fc 4afe335 89588fc b1951fe 4afe335 89588fc b1951fe 89588fc 4afe335 89588fc 4afe335 b1951fe 89588fc b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 4afe335 b1951fe 89588fc b1951fe 4afe335 b1951fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from rdkit import Chem
from rdkit.Chem import Draw
from streamlit_ketcher import st_ketcher
import torch
# --- Page Configuration ---
# Gather the page-level settings in one mapping, then apply them in a single
# call before any other Streamlit element is emitted by this script.
_PAGE_SETTINGS = {
    "page_title": "Chemical Reaction Predictor",
    "page_icon": "🧪",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)
# --- Model Loading ---
# Use st.cache_resource to load the model only once
@st.cache_resource
def _load_model_cached():
    """Load the model and tokenizer from Hugging Face (cached on success).

    Errors are deliberately NOT handled here. ``st.cache_resource`` stores the
    function's return value, so catching the failure inside the cached
    function and returning ``(None, None)`` would cache the failure itself
    and the app would never retry. Letting the exception propagate keeps a
    failed attempt out of the cache.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model_name = "sagawa/ReactionT5v2-forward-USPTO_MIT"
    # Use Auto* classes for robustness
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer


def load_model():
    """
    Loads the T5 model and tokenizer from Hugging Face.
    Uses AutoModel for better compatibility.

    A successful load is cached by ``_load_model_cached``; a failed load is
    reported in the UI and retried on the next script run.

    Returns:
        ``(model, tokenizer)`` on success, ``(None, None)`` if loading failed.
    """
    try:
        return _load_model_cached()
    except Exception as e:  # UI boundary: report the problem instead of crashing
        # Provide more detailed error information
        st.error("An error occurred while loading the model.")
        st.error(f"Error Type: {type(e).__name__}")
        st.error(f"Error Details: {e}")
        # Add a hint about potential memory issues on Hugging Face Spaces
        st.info("Hint: Free tiers on Hugging Face Spaces have limited memory (RAM). "
                "If the app fails to load the model, it might be due to an Out-of-Memory error. "
                "Consider upgrading your Space for more resources.")
        return None, None
# --- Core Functions ---
def predict_product(reactants, reagents, model, tokenizer, num_predictions):
    """Predicts the reaction product using the T5 model.

    Args:
        reactants: SMILES of the reactants (dot-separated for several).
        reagents: SMILES of the reagents; may be empty or ``None``.
        model: Object exposing ``generate`` (the loaded seq2seq model).
        tokenizer: Object exposing ``encode`` and ``decode``.
        num_predictions: Number of candidate products requested.

    Returns:
        A list with one cleaned SMILES string per sequence returned by
        ``model.generate``.
    """
    reactants = (reactants or "").strip()
    reagents = (reagents or "").strip()

    # The model card documents the prompt layout "REACTANT:<smiles>REAGENT:<smiles>".
    # NOTE(review): the single-space placeholder used when no reagent is given
    # mirrors the upstream ReactionT5 preprocessing as far as I can tell --
    # confirm against the model repository.
    input_text = f"REACTANT:{reactants}REAGENT:{reagents or ' '}"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Generate predictions using beam search
    outputs = model.generate(
        input_ids,
        max_length=512,
        num_beams=num_predictions * 2,  # Generate more beams for better diversity
        num_return_sequences=num_predictions,
        early_stopping=True,
    )

    # Decode predictions. Following the model card, remove the spaces the
    # tokenizer leaves between tokens and any trailing '.' so the result is a
    # plain SMILES string; text after a space would otherwise not be parsed
    # as part of the molecule.
    return [
        tokenizer.decode(output, skip_special_tokens=True).replace(" ", "").rstrip(".")
        for output in outputs
    ]
def display_molecule(smiles_string, legend):
    """Draw the molecule described by a SMILES string into the page.

    Args:
        smiles_string: SMILES text to parse with RDKit.
        legend: Caption passed to the RDKit drawing call.

    Shows a Streamlit warning instead of an image when the string is empty,
    when RDKit returns a falsy molecule for it, or when drawing raises.
    """
    if not smiles_string:
        st.warning("Received an empty SMILES string.")
        return

    parsed = Chem.MolFromSmiles(smiles_string)
    if not parsed:
        st.warning(f"Invalid SMILES string provided: {smiles_string}")
        return

    try:
        picture = Draw.MolToImage(parsed, size=(300, 300), legend=legend)
        st.image(picture, use_column_width='auto')
    except Exception as e:
        st.warning(f"Could not generate image for SMILES: {smiles_string}. Error: {e}")
# --- Initialize Session State ---
# Seed each input only when it is missing, so values set during earlier runs
# of the script are preserved across reruns.
_INITIAL_INPUTS = {
    "reactants": "CCO.O=C(O)C",  # Start with a default example
    "reagents": "",
}
for _state_key, _initial_value in _INITIAL_INPUTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# --- Sidebar UI ---
with st.sidebar:
    st.title("🧪 Reaction Predictor")
    st.markdown("---")
    st.header("Controls and Information")

    # Preset reactions: display name -> (reactants SMILES, reagents SMILES).
    example_reactions = {
        "Esterification": ("CCO.O=C(O)C", ""),
        "Amide Formation": ("CCN.O=C(Cl)C", ""),
        "Suzuki Coupling": ("[B-](C1=CC=CC=C1)(F)(F)F.[K+].CC1=CC=C(Br)C=C1", "c1ccc(B(O)O)cc1"),
        "Clear Inputs": ("", ""),
    }

    def load_example():
        """Copy the example chosen in the selectbox into the session state."""
        chosen = example_reactions[st.session_state.example_select]
        st.session_state.reactants = chosen[0]
        st.session_state.reagents = chosen[1]

    st.selectbox(
        "Load an Example Reaction",
        options=list(example_reactions),
        key="example_select",
        on_change=load_example,
    )

    st.markdown("---")
    st.subheader("Prediction Parameters")
    # Slider from 1 to 5 with a default of 1.
    num_predictions = st.slider(
        "Number of Predictions to Generate",
        min_value=1,
        max_value=5,
        value=1,
        help="How many potential products should the model suggest?",
    )

    st.markdown("---")
    st.subheader("About")
    st.info(
        "This app uses the sagawa/ReactionT5v2-forward-USPTO_MIT model to predict chemical reaction products."
    )
    st.markdown("[View Model on Hugging Face](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT)")
# --- Main Application UI ---
def _sketch_into_state(state_field, widget_key):
    """Show a Ketcher editor for one input and store what it returns.

    The editor is seeded with the current session-state value. When it
    returns something different, the new value is stored and the script is
    rerun so every other widget picks it up.
    """
    drawn = st_ketcher(st.session_state[state_field], key=widget_key)
    if drawn != st.session_state[state_field]:
        st.session_state[state_field] = drawn
        st.rerun()  # Use the modern rerun command


def _store_reactant_text():
    """on_change callback: copy the reactants text box into the session state."""
    st.session_state.reactants = st.session_state.reactant_text


def _store_reagent_text():
    """on_change callback: copy the reagents text box into the session state."""
    st.session_state.reagents = st.session_state.reagent_text


st.title("Chemical Reaction Predictor")
st.markdown("A tool to predict chemical reactions using a state-of-the-art Transformer model.")

# --- Model Loading and Main Logic ---
with st.spinner("Loading the prediction model... This may take a moment on first startup."):
    model, tokenizer = load_model()

# Only proceed if the model loaded successfully
if model and tokenizer:
    st.success("Model loaded successfully!")

    # Input Section
    st.header("1. Provide Reactants and Reagents")
    drawing_tab, text_tab = st.tabs(["✍️ Chemical Drawing Tool", "⌨️ SMILES Text Input"])

    with drawing_tab:
        reactant_col, reagent_col = st.columns(2)
        with reactant_col:
            st.subheader("Reactants")
            _sketch_into_state("reactants", "ketcher_reactants")
        with reagent_col:
            st.subheader("Reagents (Optional)")
            _sketch_into_state("reagents", "ketcher_reagents")

    with text_tab:
        st.subheader("Enter SMILES Strings")
        # Each text box writes its content back to the session state on change.
        st.text_input(
            "Reactants SMILES",
            key="reactant_text",
            value=st.session_state.reactants,
            on_change=_store_reactant_text,
        )
        st.text_input(
            "Reagents SMILES",
            key="reagent_text",
            value=st.session_state.reagents,
            on_change=_store_reagent_text,
        )

    # Display the current state clearly
    st.info(f"**Current Reactants:** `{st.session_state.reactants}`")
    st.info(f"**Current Reagents:** `{st.session_state.reagents or 'None'}`")

    # Prediction Button
    st.header("2. Generate Prediction")
    if st.button("Predict Product", type="primary", use_container_width=True):
        if (st.session_state.reactants or "").strip():
            with st.spinner("Running prediction..."):
                predictions = predict_product(
                    st.session_state.reactants,
                    st.session_state.reagents,
                    model,
                    tokenizer,
                    num_predictions,
                )

            st.header("3. Predicted Products")
            if predictions:
                for rank, product_smiles in enumerate(predictions, start=1):
                    st.subheader(f"Top Prediction #{rank}")
                    st.code(product_smiles, language="smiles")
                    display_molecule(product_smiles, f"Predicted Product #{rank}")
            else:
                st.warning("The model did not return any predictions.")
        else:
            st.error("Error: Reactants field cannot be empty. Please provide a molecule.")
else:
    st.error("Application could not start because the model failed to load. Please check the error messages above.")