alpata committed on
Commit
89588fc
·
verified ·
1 Parent(s): b9e2435

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +29 -0
  2. app.py +155 -0
  3. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Chemical Reaction Predictor
3
+ emoji: 🧪
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.25.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Chemical Reaction Predictor
13
+
14
+ This application predicts the products of chemical reactions using a state-of-the-art T5-based model.
15
+
16
+ ## How to Use the App
17
+
18
+ 1. **Input Molecules**: You can either:
19
+ * Use the **Chemical Drawing Tool** to draw the reactant and reagent molecules.
20
+ * Go to the **SMILES Text Input** tab and paste the SMILES strings directly.
21
+ 2. **Set Parameters**: In the sidebar, you can select the number of predictions you want to generate.
22
+ 3. **Predict**: Click the "Predict Product" button to see the results.
23
+ 4. **Load Examples**: Use the dropdown in the sidebar to load pre-defined example reactions to see how the app works.
24
+
25
+ ## About the Model
26
+
27
+ This application uses the `sagawa/ReactionT5v2-forward-USPTO_MIT` model, which has been fine-tuned for forward reaction prediction. It achieves a high accuracy of over 97% on the USPTO_MIT dataset.
28
+
29
+ For more details about the model, please visit its page on the [Hugging Face Hub](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT).
app.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import T5ForConditionalGeneration, T5Tokenizer
3
+ import torch
4
+ from rdkit import Chem
5
+ from rdkit.Chem import Draw
6
+ from streamlit_ketcher import st_ketcher
7
+
8
+ # Set page configuration
9
# Configure the browser tab title and switch the page to Streamlit's wide layout.
st.set_page_config(
    layout="wide",
    page_title="Chemical Reaction Predictor",
)
10
+
11
+ # Function to load the model and tokenizer
12
@st.cache_resource
def load_model():
    """Load the reaction-prediction model and its tokenizer from Hugging Face.

    The function is wrapped in ``st.cache_resource`` so Streamlit reuses the
    loaded objects instead of re-creating them on every script rerun.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by the two
        ``from_pretrained`` calls.
    """
    # Function-scope import: the file-level import only exposes T5Tokenizer.
    from transformers import AutoTokenizer

    model_name = "sagawa/ReactionT5v2-forward-USPTO_MIT"
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # The slow T5Tokenizer depends on the `sentencepiece` package, which
    # requirements.txt does not list. AutoTokenizer selects whichever
    # tokenizer class the checkpoint declares.
    # NOTE(review): the upstream model card appears to load the tokenizer
    # this way as well -- confirm against the card.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer
19
+
20
+ # Function to predict the product
21
def predict_product(reactants, reagents, model, tokenizer, num_predictions):
    """Predict reaction products with a ReactionT5-style seq2seq model.

    Args:
        reactants: Reactant SMILES (multiple molecules joined with ``.``).
        reagents: Reagent SMILES; may be empty or ``None``.
        model: Object exposing ``generate(input_ids, **kwargs)``.
        tokenizer: Object exposing ``encode(text, return_tensors=...)`` and
            ``decode(ids, skip_special_tokens=...)``.
        num_predictions: Number of candidate products to return.

    Returns:
        list[str]: One cleaned SMILES string per generated sequence, in the
        order the model returned them.
    """
    reactants = (reactants or "").strip()
    # NOTE(review): the upstream ReactionT5 pipeline appears to substitute a
    # single space for a missing reagent -- confirm against the model repo.
    reagents = (reagents or "").strip() or " "

    # Prompt layout used on the ReactionT5v2 model card; the previous
    # "reactants>...reagents>...>products>" layout does not match it.
    input_text = f"REACTANT:{reactants}REAGENT:{reagents}"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Beam search cannot return more sequences than it has beams, so widen
    # the beam when the caller asks for more than the default of five.
    outputs = model.generate(
        input_ids,
        max_length=512,
        num_beams=max(5, num_predictions),
        num_return_sequences=num_predictions,
        early_stopping=True,
    )

    # Decoded text carries spaces between tokens (and possibly a dangling
    # "." separator); drop both so the result is a plain SMILES string.
    return [
        tokenizer.decode(output, skip_special_tokens=True)
        .replace(" ", "")
        .rstrip(".")
        for output in outputs
    ]
38
+
39
+ # Function to display molecules
40
def display_molecule(smiles_string, legend):
    """Render a SMILES string as a 300x300 image in the Streamlit page.

    Args:
        smiles_string: SMILES text to draw.
        legend: Caption drawn underneath the molecule image.

    Shows a Streamlit warning instead of an image when the SMILES is empty
    or RDKit cannot parse it.
    """
    smiles = (smiles_string or "").strip()
    # An empty SMILES parses into a zero-atom molecule rather than None,
    # which would be drawn as a blank picture; treat it as unparseable.
    mol = Chem.MolFromSmiles(smiles) if smiles else None

    if mol is None:
        st.warning(f"Could not generate molecule for SMILES: {smiles_string}")
        return

    img = Draw.MolToImage(mol, size=(300, 300), legend=legend)
    st.image(img, use_column_width='auto')
48
+
49
+ # --- UI Layout ---
50
+
51
+ # Header
52
# Page heading and one-line description.
st.title("Chemical Reaction Predictor")
st.markdown("Predict the products of chemical reactions using the `sagawa/ReactionT5v2-forward-USPTO_MIT` model.")

# Obtain the model and tokenizer, showing a spinner while load_model() runs.
with st.spinner("Loading the prediction model..."):
    model, tokenizer = load_model()

# Sidebar: example loader, prediction settings and an "about" box.
with st.sidebar:
    st.header("Controls and Information")

    # Each example maps a display name to (reactants SMILES, reagents SMILES).
    st.subheader("Example Reactions")
    example_reactions = {
        "Esterification": ("CCO.O=C(O)C", "C(C)(=O)O"),
        "Amide Formation": ("CCN.O=C(Cl)C", ""),
        "Suzuki Coupling": ("[B-](C1=CC=CC=C1)(F)(F)F.[K+].CC1=CC=C(Br)C=C1", "c1ccc(B(O)O)cc1"),
    }
    selected_example = st.selectbox("Choose an example:", list(example_reactions))

    if st.button("Load Example"):
        example_reactants, example_reagents = example_reactions[selected_example]
        # Store the example both in the shared SMILES slots and under the
        # keys used by the two drawing widgets.
        for state_key, smiles in (
            ("reactants_smiles", example_reactants),
            ("reagents_smiles", example_reagents),
            ("ketcher_reactants", example_reactants),
            ("ketcher_reagents", example_reagents),
        ):
            st.session_state[state_key] = smiles

    # How many candidate products to request (1 to 5, default 1).
    st.subheader("Prediction Parameters")
    num_predictions = st.slider(
        "Number of Predictions to Generate",
        min_value=1,
        max_value=5,
        value=1,
    )

    # Short description plus a link to the model page.
    st.subheader("About")
    about_text = (
        "This app uses the `sagawa/ReactionT5v2-forward-USPTO_MIT` model to predict chemical reaction products. "
        "Draw or input the SMILES strings for reactants and reagents, then click 'Predict Product'."
    )
    st.info(about_text)
    st.markdown("[Model on Hugging Face](https://huggingface.co/sagawa/ReactionT5v2-forward-USPTO_MIT)")
91
+
92
+
93
+ # Main Content
94
st.header("Input Reactants and Reagents")

# Make sure both shared SMILES slots exist before any widget reads them.
for _slot in ("reactants_smiles", "reagents_smiles"):
    if _slot not in st.session_state:
        st.session_state[_slot] = ""

# Two alternative ways to enter molecules: drawing or typing SMILES.
input_tab1, input_tab2 = st.tabs(["Chemical Drawing Tool", "SMILES Text Input"])

with input_tab1:
    st.subheader("Draw Molecules")
    col1, col2 = st.columns(2)

    with col1:
        st.write("Reactants")
        # Seed the editor with the stored value when one exists, else empty.
        reactant_smiles_from_drawing = st_ketcher(
            st.session_state.get("ketcher_reactants", ""),
            key="ketcher_reactants",
        )

    with col2:
        st.write("Reagents")
        reagent_smiles_from_drawing = st_ketcher(
            st.session_state.get("ketcher_reagents", ""),
            key="ketcher_reagents",
        )

    # Copy a drawing into its shared SMILES slot only when it differs from
    # the drawing value recorded on the previous run.
    for _drawn, _smiles_slot, _seen_slot in (
        (reactant_smiles_from_drawing, "reactants_smiles", "ketcher_reactants_val"),
        (reagent_smiles_from_drawing, "reagents_smiles", "ketcher_reagents_val"),
    ):
        if st.session_state.get(_seen_slot) != _drawn:
            st.session_state[_smiles_slot] = _drawn
            st.session_state[_seen_slot] = _drawn

with input_tab2:
    st.subheader("Enter SMILES Strings")
    reactants_smiles = st.text_input(
        "Reactants SMILES",
        value=st.session_state.reactants_smiles,
        key="reactants_text_input",
    )
    reagents_smiles = st.text_input(
        "Reagents SMILES",
        value=st.session_state.reagents_smiles,
        key="reagents_text_input",
    )
    # These assignments run on every rerun, after the drawing sync above, so
    # the shared slots end up holding whatever the text inputs returned.
    st.session_state.reactants_smiles = reactants_smiles
    st.session_state.reagents_smiles = reagents_smiles
138
+
139
+
140
+ # Prediction Button
141
# Run a prediction when the primary button is pressed.
if st.button("Predict Product", type="primary"):
    # Inputs come from the shared session-state slots, not from the widgets.
    reactants_to_use = st.session_state.reactants_smiles
    reagents_to_use = st.session_state.reagents_smiles

    if reactants_to_use:
        with st.spinner("Predicting reaction..."):
            predictions = predict_product(
                reactants_to_use,
                reagents_to_use,
                model,
                tokenizer,
                num_predictions,
            )

        # Show each candidate as text and as a drawn structure, numbered from 1.
        st.header("Predicted Products")
        for rank, product_smiles in enumerate(predictions, start=1):
            st.subheader(f"Prediction #{rank}")
            st.code(product_smiles, language="smiles")
            display_molecule(product_smiles, f"Predicted Product {rank}")
    else:
        # Reactants are mandatory; reagents may be left empty.
        st.error("Please provide reactants.")
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ streamlit
2
+ transformers
3
+ torch
4
+ rdkit
5
+ streamlit-ketcher