alpata commited on
Commit
b1951fe
·
verified ·
1 Parent(s): 4afe335

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -49
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from transformers import T5ForConditionalGeneration, T5Tokenizer
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw
5
  from streamlit_ketcher import st_ketcher
@@ -14,31 +14,47 @@ st.set_page_config(
14
  )
15
 
16
  # --- Model Loading ---
 
17
  @st.cache_resource
18
  def load_model():
19
- """Loads the T5 model and tokenizer from Hugging Face."""
 
 
 
20
  model_name = "sagawa/ReactionT5v2-forward-USPTO_MIT"
21
  try:
22
- model = T5ForConditionalGeneration.from_pretrained(model_name)
23
- tokenizer = T5Tokenizer.from_pretrained(model_name)
 
24
  return model, tokenizer
25
  except Exception as e:
26
- st.error(f"Error loading model: {e}")
 
 
 
 
 
 
 
27
  return None, None
28
 
29
  # --- Core Functions ---
30
  def predict_product(reactants, reagents, model, tokenizer, num_predictions):
31
  """Predicts the reaction product using the T5 model."""
32
  # Format the input string as required by the model
33
- input_text = f"reactants>{reactants}.reagents>{reagents}>products>"
 
 
 
 
34
 
35
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
36
 
37
- # Generate predictions
38
  outputs = model.generate(
39
  input_ids,
40
  max_length=512,
41
- num_beams=num_predictions * 2, # Generate more beams for better results
42
  num_return_sequences=num_predictions,
43
  early_stopping=True,
44
  )
@@ -49,20 +65,23 @@ def predict_product(reactants, reagents, model, tokenizer, num_predictions):
49
 
50
  def display_molecule(smiles_string, legend):
51
  """Generates and displays a molecule image from a SMILES string."""
 
 
 
52
  mol = Chem.MolFromSmiles(smiles_string)
53
  if mol:
54
  try:
55
- img = Draw.MolToImage(mol, size=(350, 350), legend=legend)
56
  st.image(img, use_column_width='auto')
57
  except Exception as e:
58
  st.warning(f"Could not generate image for SMILES: {smiles_string}. Error: {e}")
59
  else:
60
  st.warning(f"Invalid SMILES string provided: {smiles_string}")
61
 
62
-
63
  # --- Initialize Session State ---
 
64
  if 'reactants' not in st.session_state:
65
- st.session_state.reactants = ""
66
  if 'reagents' not in st.session_state:
67
  st.session_state.reagents = ""
68
 
@@ -74,13 +93,14 @@ with st.sidebar:
74
 
75
  # Example Reactions
76
  example_reactions = {
77
- "Select an example...": ("", ""),
78
- "Esterification": ("CCO.O=C(O)C", "C(C)(=O)O"),
79
  "Amide Formation": ("CCN.O=C(Cl)C", ""),
80
  "Suzuki Coupling": ("[B-](C1=CC=CC=C1)(F)(F)F.[K+].CC1=CC=C(Br)C=C1", "c1ccc(B(O)O)cc1"),
 
81
  }
82
 
83
- def on_example_change():
 
84
  example_key = st.session_state.example_select
85
  reactants, reagents = example_reactions[example_key]
86
  st.session_state.reactants = reactants
@@ -90,75 +110,70 @@ with st.sidebar:
90
  "Load an Example Reaction",
91
  options=list(example_reactions.keys()),
92
  key="example_select",
93
- on_change=on_example_change
94
  )
95
 
96
- # Prediction Parameters
97
  st.markdown("---")
98
  st.subheader("Prediction Parameters")
99
- num_predictions = st.slider("Number of Predictions", 1, 5, 1)
100
  st.markdown("---")
101
 
102
- # About Section
103
  st.subheader("About")
104
  st.info(
105
- "This app uses the sagawa/ReactionT5v2-forward-USPTO_MIT model to predict chemical reaction products. "
106
- "Draw molecules or input SMILES strings, then click 'Predict Product'."
107
  )
108
  st.markdown("[View Model on Hugging Face](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT)")
109
 
110
  # --- Main Application UI ---
111
  st.title("Chemical Reaction Predictor")
 
112
 
113
- # Load Model
114
- model, tokenizer = load_model()
 
115
 
 
116
  if model and tokenizer:
117
  st.success("Model loaded successfully!")
118
 
119
  # Input Section
120
- st.header("1. Input Reactants and Reagents")
121
  input_tab1, input_tab2 = st.tabs(["✍️ Chemical Drawing Tool", "⌨️ SMILES Text Input"])
122
 
123
- # Callback functions to update session state from text inputs
124
- def on_reactant_text_change():
125
- st.session_state.reactants = st.session_state.reactant_text
126
-
127
- def on_reagent_text_change():
128
- st.session_state.reagents = st.session_state.reagent_text
129
-
130
  with input_tab1:
131
  col1, col2 = st.columns(2)
132
  with col1:
133
  st.subheader("Reactants")
134
- # The ketcher component's value is controlled by session state
135
- reactant_smiles_drawing = st_ketcher(value=st.session_state.reactants, key="ketcher_reactants")
136
- # If the drawing changes, update the session state
137
  if reactant_smiles_drawing != st.session_state.reactants:
138
  st.session_state.reactants = reactant_smiles_drawing
139
- st.experimental_rerun()
140
 
141
  with col2:
142
- st.subheader("Reagents")
143
- reagent_smiles_drawing = st_ketcher(value=st.session_state.reagents, key="ketcher_reagents")
144
  if reagent_smiles_drawing != st.session_state.reagents:
145
  st.session_state.reagents = reagent_smiles_drawing
146
- st.experimental_rerun()
147
 
148
  with input_tab2:
149
  st.subheader("Enter SMILES Strings")
150
- st.text_input("Reactants SMILES", key="reactant_text", on_change=on_reactant_text_change, value=st.session_state.reactants)
151
- st.text_input("Reagents SMILES (optional)", key="reagent_text", on_change=on_reagent_text_change, value=st.session_state.reagents)
 
152
 
 
153
  st.info(f"**Current Reactants:** `{st.session_state.reactants}`")
154
- st.info(f"**Current Reagents:** `{st.session_state.reagents}`")
155
 
 
156
  st.header("2. Generate Prediction")
157
  if st.button("Predict Product", type="primary", use_container_width=True):
158
- if not st.session_state.reactants:
159
- st.error("Error: Reactants cannot be empty. Please draw a molecule or provide a SMILES string.")
160
  else:
161
- with st.spinner("Running prediction... This may take a moment."):
162
  predictions = predict_product(
163
  st.session_state.reactants,
164
  st.session_state.reagents,
@@ -167,9 +182,13 @@ if model and tokenizer:
167
  num_predictions
168
  )
169
  st.header("3. Predicted Products")
170
- for i, product_smiles in enumerate(predictions):
171
- st.subheader(f"Top Prediction #{i+1}")
172
- st.code(product_smiles, language="smiles")
173
- display_molecule(product_smiles, f"Predicted Product #{i+1}")
174
- else:
175
- st.error("Application could not start. Please check the logs on Hugging Face Spaces.")
 
 
 
 
 
1
  import streamlit as st
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw
5
  from streamlit_ketcher import st_ketcher
 
14
  )
15
 
16
  # --- Model Loading ---
17
+ # Use st.cache_resource to load the model only once
18
  @st.cache_resource
19
  def load_model():
20
+ """
21
+ Loads the T5 model and tokenizer from Hugging Face.
22
+ Uses AutoModel for better compatibility.
23
+ """
24
  model_name = "sagawa/ReactionT5v2-forward-USPTO_MIT"
25
  try:
26
+ # Use Auto* classes for robustness
27
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
28
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
29
  return model, tokenizer
30
  except Exception as e:
31
+ # Provide more detailed error information
32
+ st.error("An error occurred while loading the model.")
33
+ st.error(f"Error Type: {type(e).__name__}")
34
+ st.error(f"Error Details: {e}")
35
+ # Add a hint about potential memory issues on Hugging Face Spaces
36
+ st.info("Hint: Free tiers on Hugging Face Spaces have limited memory (RAM). "
37
+ "If the app fails to load the model, it might be due to an Out-of-Memory error. "
38
+ "Consider upgrading your Space for more resources.")
39
  return None, None
40
 
41
  # --- Core Functions ---
42
  def predict_product(reactants, reagents, model, tokenizer, num_predictions):
43
  """Predicts the reaction product using the T5 model."""
44
  # Format the input string as required by the model
45
+ # Handle the case where reagents might be empty
46
+ if reagents and reagents.strip():
47
+ input_text = f"reactants>{reactants}.reagents>{reagents}>products>"
48
+ else:
49
+ input_text = f"reactants>{reactants}>products>"
50
 
51
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
52
 
53
+ # Generate predictions using beam search
54
  outputs = model.generate(
55
  input_ids,
56
  max_length=512,
57
+ num_beams=num_predictions * 2, # Generate more beams for better diversity
58
  num_return_sequences=num_predictions,
59
  early_stopping=True,
60
  )
 
65
 
66
  def display_molecule(smiles_string, legend):
67
  """Generates and displays a molecule image from a SMILES string."""
68
+ if not smiles_string:
69
+ st.warning("Received an empty SMILES string.")
70
+ return
71
  mol = Chem.MolFromSmiles(smiles_string)
72
  if mol:
73
  try:
74
+ img = Draw.MolToImage(mol, size=(300, 300), legend=legend)
75
  st.image(img, use_column_width='auto')
76
  except Exception as e:
77
  st.warning(f"Could not generate image for SMILES: {smiles_string}. Error: {e}")
78
  else:
79
  st.warning(f"Invalid SMILES string provided: {smiles_string}")
80
 
 
81
  # --- Initialize Session State ---
82
+ # This ensures that the state is preserved across reruns
83
  if 'reactants' not in st.session_state:
84
+ st.session_state.reactants = "CCO.O=C(O)C" # Start with a default example
85
  if 'reagents' not in st.session_state:
86
  st.session_state.reagents = ""
87
 
 
93
 
94
  # Example Reactions
95
  example_reactions = {
96
+ "Esterification": ("CCO.O=C(O)C", ""),
 
97
  "Amide Formation": ("CCN.O=C(Cl)C", ""),
98
  "Suzuki Coupling": ("[B-](C1=CC=CC=C1)(F)(F)F.[K+].CC1=CC=C(Br)C=C1", "c1ccc(B(O)O)cc1"),
99
+ "Clear Inputs": ("", "")
100
  }
101
 
102
+ def load_example():
103
+ # Callback to load selected example into session state
104
  example_key = st.session_state.example_select
105
  reactants, reagents = example_reactions[example_key]
106
  st.session_state.reactants = reactants
 
110
  "Load an Example Reaction",
111
  options=list(example_reactions.keys()),
112
  key="example_select",
113
+ on_change=load_example
114
  )
115
 
 
116
  st.markdown("---")
117
  st.subheader("Prediction Parameters")
118
+ num_predictions = st.slider("Number of Predictions to Generate", 1, 5, 1, help="How many potential products should the model suggest?")
119
  st.markdown("---")
120
 
 
121
  st.subheader("About")
122
  st.info(
123
+ "This app uses the sagawa/ReactionT5v2-forward-USPTO_MIT model to predict chemical reaction products."
 
124
  )
125
  st.markdown("[View Model on Hugging Face](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT)")
126
 
127
  # --- Main Application UI ---
128
  st.title("Chemical Reaction Predictor")
129
+ st.markdown("A tool to predict chemical reactions using a state-of-the-art Transformer model.")
130
 
131
+ # --- Model Loading and Main Logic ---
132
+ with st.spinner("Loading the prediction model... This may take a moment on first startup."):
133
+ model, tokenizer = load_model()
134
 
135
+ # Only proceed if the model loaded successfully
136
  if model and tokenizer:
137
  st.success("Model loaded successfully!")
138
 
139
  # Input Section
140
+ st.header("1. Provide Reactants and Reagents")
141
  input_tab1, input_tab2 = st.tabs(["✍️ Chemical Drawing Tool", "⌨️ SMILES Text Input"])
142
 
 
 
 
 
 
 
 
143
  with input_tab1:
144
  col1, col2 = st.columns(2)
145
  with col1:
146
  st.subheader("Reactants")
147
+ # This component's value is now directly tied to the session state
148
+ reactant_smiles_drawing = st_ketcher(st.session_state.reactants, key="ketcher_reactants")
 
149
  if reactant_smiles_drawing != st.session_state.reactants:
150
  st.session_state.reactants = reactant_smiles_drawing
151
+ st.rerun() # Use the modern rerun command
152
 
153
  with col2:
154
+ st.subheader("Reagents (Optional)")
155
+ reagent_smiles_drawing = st_ketcher(st.session_state.reagents, key="ketcher_reagents")
156
  if reagent_smiles_drawing != st.session_state.reagents:
157
  st.session_state.reagents = reagent_smiles_drawing
158
+ st.rerun()
159
 
160
  with input_tab2:
161
  st.subheader("Enter SMILES Strings")
162
+ # Text inputs now also directly update the session state on change
163
+ st.text_input("Reactants SMILES", key="reactant_text", value=st.session_state.reactants, on_change=lambda: setattr(st.session_state, 'reactants', st.session_state.reactant_text))
164
+ st.text_input("Reagents SMILES", key="reagent_text", value=st.session_state.reagents, on_change=lambda: setattr(st.session_state, 'reagents', st.session_state.reagent_text))
165
 
166
+ # Display the current state clearly
167
  st.info(f"**Current Reactants:** `{st.session_state.reactants}`")
168
+ st.info(f"**Current Reagents:** `{st.session_state.reagents or 'None'}`")
169
 
170
+ # Prediction Button
171
  st.header("2. Generate Prediction")
172
  if st.button("Predict Product", type="primary", use_container_width=True):
173
+ if not st.session_state.reactants or not st.session_state.reactants.strip():
174
+ st.error("Error: Reactants field cannot be empty. Please provide a molecule.")
175
  else:
176
+ with st.spinner("Running prediction..."):
177
  predictions = predict_product(
178
  st.session_state.reactants,
179
  st.session_state.reagents,
 
182
  num_predictions
183
  )
184
  st.header("3. Predicted Products")
185
+ if not predictions:
186
+ st.warning("The model did not return any predictions.")
187
+ else:
188
+ for i, product_smiles in enumerate(predictions):
189
+ st.subheader(f"Top Prediction #{i + 1}")
190
+ st.code(product_smiles, language="smiles")
191
+ display_molecule(product_smiles, f"Predicted Product #{i + 1}")
192
+
193
+ elif not model or not tokenizer:
194
+ st.error("Application could not start because the model failed to load. Please check the error messages above.")