sagawa commited on
Commit
0da02e5
·
verified ·
1 Parent(s): edbdd3f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -7
app.py CHANGED
@@ -85,10 +85,40 @@ st.divider()
85
  with st.sidebar:
86
  st.header("Configuration")
87
 
88
- model_name_or_path = st.text_input(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  "Model",
90
- value="sagawa/ReactionT5v2-forward",
91
- help="Hugging Face model repo or a local path.",
 
92
  )
93
 
94
  num_beams = st.slider(
@@ -111,11 +141,20 @@ with st.sidebar:
111
 
112
  with st.expander("Advanced generation", expanded=False):
113
  input_max_length = st.number_input(
114
- "Input max length", min_value=8, max_value=1024, value=400, step=8
115
- )
116
- output_max_length = st.number_input(
117
- "Output max length", min_value=8, max_value=1024, value=300, step=8
 
118
  )
 
 
 
 
 
 
 
 
119
  output_min_length = st.number_input(
120
  "Output min length",
121
  min_value=-1,
 
85
  with st.sidebar:
86
  st.header("Configuration")
87
 
88
+ task = st.selectbox(
89
+ "Task",
90
+ options=["product prediction", "retrosynthesis prediction", "yield prediction"],
91
+ index=0,
92
+ help="Choose the task to run.",
93
+ )
94
+
95
+ # Model options tied to task
96
+ if task == "product prediction":
97
+ model_options = [
98
+ "sagawa/ReactionT5v2-forward",
99
+ "sagawa/ReactionT5v2-forward-USPTO_MIT",
100
+ ]
101
+ model_help = "Recommended models for product prediction."
102
+ input_max_length_default = 400
103
+ output_max_length_default = 300
104
+ elif task == "retrosynthesis prediction":
105
+ model_options = [
106
+ "sagawa/ReactionT5v2-retrosynthesis",
107
+ "sagawa/ReactionT5v2-retrosynthesis-USPTO_50k",
108
+ ]
109
+ model_help = "Recommended models for retrosynthesis prediction."
110
+ input_max_length_default = 100
111
+ output_max_length_default = 400
112
+ else: # yield prediction
113
+ model_options = ["sagawa/ReactionT5v2-yield"] # default as requested
114
+ model_help = "Default model for yield prediction."
115
+ input_max_length_default = 400
116
+
117
+ model_name_or_path = st.selectbox(
118
  "Model",
119
+ options=model_options,
120
+ index=0,
121
+ help=model_help,
122
  )
123
 
124
  num_beams = st.slider(
 
141
 
142
  with st.expander("Advanced generation", expanded=False):
143
  input_max_length = st.number_input(
144
+ "Input max length",
145
+ min_value=8,
146
+ max_value=input_max_length_default,
147
+ value=400,
148
+ step=8,
149
  )
150
+ if task != "yield prediction":
151
+ output_max_length = st.number_input(
152
+ "Output max length",
153
+ min_value=8,
154
+ max_value=1024,
155
+ value=output_max_length_default,
156
+ step=8,
157
+ )
158
  output_min_length = st.number_input(
159
  "Output min length",
160
  min_value=-1,