alpata committed on
Commit
4afe335
·
verified ·
1 Parent(s): 89588fc

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +8 -9
  2. app.py +121 -101
  3. requirements.txt +2 -1
README.md CHANGED
@@ -4,26 +4,25 @@ emoji: 🧪
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: streamlit
7
- sdk_version: 1.25.0
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- # Chemical Reaction Predictor
13
 
14
  This application predicts the products of chemical reactions using a state-of-the-art T5-based model.
15
 
16
  ## How to Use the App
17
 
18
- 1. **Input Molecules**: You can either:
19
- * Use the **Chemical Drawing Tool** to draw the reactant and reagent molecules.
20
- * Go to the **SMILES Text Input** tab and paste the SMILES strings directly.
21
- 2. **Set Parameters**: In the sidebar, you can select the number of predictions you want to generate.
22
- 3. **Predict**: Click the "Predict Product" button to see the results.
23
- 4. **Load Examples**: Use the dropdown in the sidebar to load pre-defined example reactions to see how the app works.
24
 
25
  ## About the Model
26
 
27
- This application uses the `sagawa/ReactionT5v2-forward-USPTO_MIT` model, which has been fine-tuned for forward reaction prediction. It achieves a high accuracy of over 97% on the USPTO_MIT dataset.
28
 
29
  For more details about the model, please visit its page on the [Hugging Face Hub](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT).
 
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: streamlit
 
7
  app_file: app.py
8
  pinned: false
9
  ---
10
 
11
+ # 🧪 Chemical Reaction Predictor
12
 
13
  This application predicts the products of chemical reactions using a state-of-the-art T5-based model.
14
 
15
  ## How to Use the App
16
 
17
+ 1. **Input Molecules**: You have two options:
18
+ * Use the **✍️ Chemical Drawing Tool** to draw the reactant and reagent molecules.
19
+ * Switch to the **⌨️ SMILES Text Input** tab and paste the SMILES strings directly.
20
+ 2. **Load Examples (Optional)**: Use the dropdown in the sidebar to load pre-defined example reactions to see how the app works.
21
+ 3. **Set Parameters**: In the sidebar, you can select the number of predictions you want to generate.
22
+ 4. **Predict**: Click the "Predict Product" button to see the results.
23
 
24
  ## About the Model
25
 
26
+ This application uses the `sagawa/ReactionT5v2-forward-USPTO_MIT` model, which has been fine-tuned for forward reaction prediction.
27
 
28
  For more details about the model, please visit its page on the [Hugging Face Hub](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT).
app.py CHANGED
@@ -1,155 +1,175 @@
1
  import streamlit as st
2
  from transformers import T5ForConditionalGeneration, T5Tokenizer
3
- import torch
4
  from rdkit import Chem
5
  from rdkit.Chem import Draw
6
  from streamlit_ketcher import st_ketcher
 
7
 
8
- # Set page configuration
9
- st.set_page_config(page_title="Chemical Reaction Predictor", layout="wide")
 
 
 
 
 
10
 
11
- # Function to load the model and tokenizer
12
  @st.cache_resource
13
  def load_model():
14
  """Loads the T5 model and tokenizer from Hugging Face."""
15
  model_name = "sagawa/ReactionT5v2-forward-USPTO_MIT"
16
- model = T5ForConditionalGeneration.from_pretrained(model_name)
17
- tokenizer = T5Tokenizer.from_pretrained(model_name)
18
- return model, tokenizer
19
-
20
- # Function to predict the product
 
 
 
 
21
  def predict_product(reactants, reagents, model, tokenizer, num_predictions):
22
  """Predicts the reaction product using the T5 model."""
 
23
  input_text = f"reactants>{reactants}.reagents>{reagents}>products>"
 
24
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
25
 
26
  # Generate predictions
27
  outputs = model.generate(
28
  input_ids,
29
  max_length=512,
30
- num_beams=5,
31
  num_return_sequences=num_predictions,
32
- early_stopping=True
33
  )
34
 
35
- # Decode the predictions
36
  predictions = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
37
  return predictions
38
 
39
- # Function to display molecules
40
  def display_molecule(smiles_string, legend):
41
- """Displays a molecule from a SMILES string."""
42
  mol = Chem.MolFromSmiles(smiles_string)
43
  if mol:
44
- img = Draw.MolToImage(mol, size=(300, 300), legend=legend)
45
- st.image(img, use_column_width='auto')
 
 
 
46
  else:
47
- st.warning(f"Could not generate molecule for SMILES: {smiles_string}")
48
 
49
- # --- UI Layout ---
50
 
51
- # Header
52
- st.title("Chemical Reaction Predictor")
53
- st.markdown("Predict the products of chemical reactions using the `sagawa/ReactionT5v2-forward-USPTO_MIT` model.")
54
-
55
- # Load Model
56
- with st.spinner("Loading the prediction model..."):
57
- model, tokenizer = load_model()
58
 
59
- # Sidebar
60
  with st.sidebar:
 
 
61
  st.header("Controls and Information")
62
 
63
  # Example Reactions
64
- st.subheader("Example Reactions")
65
  example_reactions = {
 
66
  "Esterification": ("CCO.O=C(O)C", "C(C)(=O)O"),
67
  "Amide Formation": ("CCN.O=C(Cl)C", ""),
68
  "Suzuki Coupling": ("[B-](C1=CC=CC=C1)(F)(F)F.[K+].CC1=CC=C(Br)C=C1", "c1ccc(B(O)O)cc1"),
69
  }
70
- selected_example = st.selectbox("Choose an example:", list(example_reactions.keys()))
71
-
72
- if st.button("Load Example"):
73
- reactants_smiles_example, reagents_smiles_example = example_reactions[selected_example]
74
- st.session_state.reactants_smiles = reactants_smiles_example
75
- st.session_state.reagents_smiles = reagents_smiles_example
76
- st.session_state.ketcher_reactants = reactants_smiles_example
77
- st.session_state.ketcher_reagents = reagents_smiles_example
78
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  # Prediction Parameters
 
81
  st.subheader("Prediction Parameters")
82
- num_predictions = st.slider("Number of Predictions to Generate", 1, 5, 1)
 
83
 
84
  # About Section
85
  st.subheader("About")
86
  st.info(
87
- "This app uses the `sagawa/ReactionT5v2-forward-USPTO_MIT` model to predict chemical reaction products. "
88
- "Draw or input the SMILES strings for reactants and reagents, then click 'Predict Product'."
89
  )
90
- st.markdown("[Model on Hugging Face](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT)")
91
-
92
-
93
- # Main Content
94
- st.header("Input Reactants and Reagents")
95
-
96
- # Initialize session state for SMILES
97
- if 'reactants_smiles' not in st.session_state:
98
- st.session_state.reactants_smiles = ""
99
- if 'reagents_smiles' not in st.session_state:
100
- st.session_state.reagents_smiles = ""
101
-
102
- # Input Tabs
103
- input_tab1, input_tab2 = st.tabs(["Chemical Drawing Tool", "SMILES Text Input"])
104
-
105
- with input_tab1:
106
- st.subheader("Draw Molecules")
107
- col1, col2 = st.columns(2)
108
- with col1:
109
- st.write("Reactants")
110
- if 'ketcher_reactants' in st.session_state:
111
- reactant_smiles_from_drawing = st_ketcher(st.session_state.ketcher_reactants, key="ketcher_reactants")
112
- else:
113
- reactant_smiles_from_drawing = st_ketcher("", key="ketcher_reactants")
114
 
 
 
115
 
116
- with col2:
117
- st.write("Reagents")
118
- if 'ketcher_reagents' in st.session_state:
119
- reagent_smiles_from_drawing = st_ketcher(st.session_state.ketcher_reagents, key="ketcher_reagents")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  else:
121
- reagent_smiles_from_drawing = st_ketcher("", key="ketcher_reagents")
122
-
123
-
124
- if reactant_smiles_from_drawing != st.session_state.get('ketcher_reactants_val'):
125
- st.session_state.reactants_smiles = reactant_smiles_from_drawing
126
- st.session_state.ketcher_reactants_val = reactant_smiles_from_drawing
127
-
128
- if reagent_smiles_from_drawing != st.session_state.get('ketcher_reagents_val'):
129
- st.session_state.reagents_smiles = reagent_smiles_from_drawing
130
- st.session_state.ketcher_reagents_val = reagent_smiles_from_drawing
131
-
132
- with input_tab2:
133
- st.subheader("Enter SMILES Strings")
134
- reactants_smiles = st.text_input("Reactants SMILES", st.session_state.reactants_smiles, key="reactants_text_input")
135
- reagents_smiles = st.text_input("Reagents SMILES", st.session_state.reagents_smiles, key="reagents_text_input")
136
- st.session_state.reactants_smiles = reactants_smiles
137
- st.session_state.reagents_smiles = reagents_smiles
138
-
139
-
140
- # Prediction Button
141
- if st.button("Predict Product", type="primary"):
142
- reactants_to_use = st.session_state.reactants_smiles
143
- reagents_to_use = st.session_state.reagents_smiles
144
-
145
- if not reactants_to_use:
146
- st.error("Please provide reactants.")
147
- else:
148
- with st.spinner("Predicting reaction..."):
149
- predictions = predict_product(reactants_to_use, reagents_to_use, model, tokenizer, num_predictions)
150
-
151
- st.header("Predicted Products")
152
- for i, product_smiles in enumerate(predictions):
153
- st.subheader(f"Prediction #{i+1}")
154
- st.code(product_smiles, language="smiles")
155
- display_molecule(product_smiles, f"Predicted Product {i+1}")
 
1
  import streamlit as st
2
  from transformers import T5ForConditionalGeneration, T5Tokenizer
 
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw
5
  from streamlit_ketcher import st_ketcher
6
+ import torch
7
 
8
+ # --- Page Configuration ---
9
+ st.set_page_config(
10
+ page_title="Chemical Reaction Predictor",
11
+ page_icon="🧪",
12
+ layout="wide",
13
+ initial_sidebar_state="expanded"
14
+ )
15
 
16
+ # --- Model Loading ---
17
  @st.cache_resource
18
  def load_model():
19
  """Loads the T5 model and tokenizer from Hugging Face."""
20
  model_name = "sagawa/ReactionT5v2-forward-USPTO_MIT"
21
+ try:
22
+ model = T5ForConditionalGeneration.from_pretrained(model_name)
23
+ tokenizer = T5Tokenizer.from_pretrained(model_name)
24
+ return model, tokenizer
25
+ except Exception as e:
26
+ st.error(f"Error loading model: {e}")
27
+ return None, None
28
+
29
+ # --- Core Functions ---
30
  def predict_product(reactants, reagents, model, tokenizer, num_predictions):
31
  """Predicts the reaction product using the T5 model."""
32
+ # Format the input string as required by the model
33
  input_text = f"reactants>{reactants}.reagents>{reagents}>products>"
34
+
35
  input_ids = tokenizer.encode(input_text, return_tensors='pt')
36
 
37
  # Generate predictions
38
  outputs = model.generate(
39
  input_ids,
40
  max_length=512,
41
+ num_beams=num_predictions * 2, # Generate more beams for better results
42
  num_return_sequences=num_predictions,
43
+ early_stopping=True,
44
  )
45
 
46
+ # Decode predictions
47
  predictions = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
48
  return predictions
49
 
 
50
  def display_molecule(smiles_string, legend):
51
+ """Generates and displays a molecule image from a SMILES string."""
52
  mol = Chem.MolFromSmiles(smiles_string)
53
  if mol:
54
+ try:
55
+ img = Draw.MolToImage(mol, size=(350, 350), legend=legend)
56
+ st.image(img, use_column_width='auto')
57
+ except Exception as e:
58
+ st.warning(f"Could not generate image for SMILES: {smiles_string}. Error: {e}")
59
  else:
60
+ st.warning(f"Invalid SMILES string provided: {smiles_string}")
61
 
 
62
 
63
+ # --- Initialize Session State ---
64
+ if 'reactants' not in st.session_state:
65
+ st.session_state.reactants = ""
66
+ if 'reagents' not in st.session_state:
67
+ st.session_state.reagents = ""
 
 
68
 
69
+ # --- Sidebar UI ---
70
  with st.sidebar:
71
+ st.title("🧪 Reaction Predictor")
72
+ st.markdown("---")
73
  st.header("Controls and Information")
74
 
75
  # Example Reactions
 
76
  example_reactions = {
77
+ "Select an example...": ("", ""),
78
  "Esterification": ("CCO.O=C(O)C", "C(C)(=O)O"),
79
  "Amide Formation": ("CCN.O=C(Cl)C", ""),
80
  "Suzuki Coupling": ("[B-](C1=CC=CC=C1)(F)(F)F.[K+].CC1=CC=C(Br)C=C1", "c1ccc(B(O)O)cc1"),
81
  }
 
 
 
 
 
 
 
 
82
 
83
+ def on_example_change():
84
+ example_key = st.session_state.example_select
85
+ reactants, reagents = example_reactions[example_key]
86
+ st.session_state.reactants = reactants
87
+ st.session_state.reagents = reagents
88
+
89
+ st.selectbox(
90
+ "Load an Example Reaction",
91
+ options=list(example_reactions.keys()),
92
+ key="example_select",
93
+ on_change=on_example_change
94
+ )
95
 
96
  # Prediction Parameters
97
+ st.markdown("---")
98
  st.subheader("Prediction Parameters")
99
+ num_predictions = st.slider("Number of Predictions", 1, 5, 1)
100
+ st.markdown("---")
101
 
102
  # About Section
103
  st.subheader("About")
104
  st.info(
105
+ "This app uses the sagawa/ReactionT5v2-forward-USPTO_MIT model to predict chemical reaction products. "
106
+ "Draw molecules or input SMILES strings, then click 'Predict Product'."
107
  )
108
+ st.markdown("[View Model on Hugging Face](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ # --- Main Application UI ---
111
+ st.title("Chemical Reaction Predictor")
112
 
113
+ # Load Model
114
+ model, tokenizer = load_model()
115
+
116
+ if model and tokenizer:
117
+ st.success("Model loaded successfully!")
118
+
119
+ # Input Section
120
+ st.header("1. Input Reactants and Reagents")
121
+ input_tab1, input_tab2 = st.tabs(["✍️ Chemical Drawing Tool", "⌨️ SMILES Text Input"])
122
+
123
+ # Callback functions to update session state from text inputs
124
+ def on_reactant_text_change():
125
+ st.session_state.reactants = st.session_state.reactant_text
126
+
127
+ def on_reagent_text_change():
128
+ st.session_state.reagents = st.session_state.reagent_text
129
+
130
+ with input_tab1:
131
+ col1, col2 = st.columns(2)
132
+ with col1:
133
+ st.subheader("Reactants")
134
+ # The ketcher component's value is controlled by session state
135
+ reactant_smiles_drawing = st_ketcher(value=st.session_state.reactants, key="ketcher_reactants")
136
+ # If the drawing changes, update the session state
137
+ if reactant_smiles_drawing != st.session_state.reactants:
138
+ st.session_state.reactants = reactant_smiles_drawing
139
+ st.experimental_rerun()
140
+
141
+ with col2:
142
+ st.subheader("Reagents")
143
+ reagent_smiles_drawing = st_ketcher(value=st.session_state.reagents, key="ketcher_reagents")
144
+ if reagent_smiles_drawing != st.session_state.reagents:
145
+ st.session_state.reagents = reagent_smiles_drawing
146
+ st.experimental_rerun()
147
+
148
+ with input_tab2:
149
+ st.subheader("Enter SMILES Strings")
150
+ st.text_input("Reactants SMILES", key="reactant_text", on_change=on_reactant_text_change, value=st.session_state.reactants)
151
+ st.text_input("Reagents SMILES (optional)", key="reagent_text", on_change=on_reagent_text_change, value=st.session_state.reagents)
152
+
153
+ st.info(f"**Current Reactants:** `{st.session_state.reactants}`")
154
+ st.info(f"**Current Reagents:** `{st.session_state.reagents}`")
155
+
156
+ st.header("2. Generate Prediction")
157
+ if st.button("Predict Product", type="primary", use_container_width=True):
158
+ if not st.session_state.reactants:
159
+ st.error("Error: Reactants cannot be empty. Please draw a molecule or provide a SMILES string.")
160
  else:
161
+ with st.spinner("Running prediction... This may take a moment."):
162
+ predictions = predict_product(
163
+ st.session_state.reactants,
164
+ st.session_state.reagents,
165
+ model,
166
+ tokenizer,
167
+ num_predictions
168
+ )
169
+ st.header("3. Predicted Products")
170
+ for i, product_smiles in enumerate(predictions):
171
+ st.subheader(f"Top Prediction #{i+1}")
172
+ st.code(product_smiles, language="smiles")
173
+ display_molecule(product_smiles, f"Predicted Product #{i+1}")
174
+ else:
175
+ st.error("Application could not start. Please check the logs on Hugging Face Spaces.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -2,4 +2,5 @@ streamlit
2
  transformers
3
  torch
4
  rdkit
5
- streamlit-ketcher
 
 
2
  transformers
3
  torch
4
  rdkit
5
+ streamlit-ketcher
6
+ sentencepiece