Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -7,19 +7,15 @@ from types import SimpleNamespace
|
|
7 |
import pandas as pd
|
8 |
import streamlit as st
|
9 |
import torch
|
10 |
-
from torch.utils.data import DataLoader
|
11 |
-
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
12 |
|
13 |
# Local imports
|
14 |
-
sys.path.append(
|
15 |
-
os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
|
16 |
-
)
|
17 |
from generation_utils import (
|
18 |
ReactionT5Dataset,
|
19 |
decode_output,
|
20 |
save_multiple_predictions,
|
21 |
)
|
22 |
-
from
|
|
|
23 |
from utils import seed_everything
|
24 |
|
25 |
warnings.filterwarnings("ignore")
|
@@ -101,6 +97,10 @@ with st.sidebar:
|
|
101 |
model_help = "Recommended models for product prediction."
|
102 |
input_max_length_default = 400
|
103 |
output_max_length_default = 300
|
|
|
|
|
|
|
|
|
104 |
elif task == "retrosynthesis prediction":
|
105 |
model_options = [
|
106 |
"sagawa/ReactionT5v2-retrosynthesis",
|
@@ -109,6 +109,12 @@ with st.sidebar:
|
|
109 |
model_help = "Recommended models for retrosynthesis prediction."
|
110 |
input_max_length_default = 100
|
111 |
output_max_length_default = 400
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
else: # yield prediction
|
113 |
model_options = ["sagawa/ReactionT5v2-yield"] # default as requested
|
114 |
model_help = "Default model for yield prediction."
|
@@ -143,8 +149,8 @@ with st.sidebar:
|
|
143 |
input_max_length = st.number_input(
|
144 |
"Input max length",
|
145 |
min_value=8,
|
146 |
-
max_value=
|
147 |
-
value=
|
148 |
step=8,
|
149 |
)
|
150 |
if task != "yield prediction":
|
|
|
7 |
import pandas as pd
|
8 |
import streamlit as st
|
9 |
import torch
|
|
|
|
|
10 |
|
11 |
# Local imports
|
|
|
|
|
|
|
12 |
from generation_utils import (
|
13 |
ReactionT5Dataset,
|
14 |
decode_output,
|
15 |
save_multiple_predictions,
|
16 |
)
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
19 |
from utils import seed_everything
|
20 |
|
21 |
warnings.filterwarnings("ignore")
|
|
|
97 |
model_help = "Recommended models for product prediction."
|
98 |
input_max_length_default = 400
|
99 |
output_max_length_default = 300
|
100 |
+
sys.path.append(
|
101 |
+
os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
|
102 |
+
)
|
103 |
+
from train import preprocess_df
|
104 |
elif task == "retrosynthesis prediction":
|
105 |
model_options = [
|
106 |
"sagawa/ReactionT5v2-retrosynthesis",
|
|
|
109 |
model_help = "Recommended models for retrosynthesis prediction."
|
110 |
input_max_length_default = 100
|
111 |
output_max_length_default = 400
|
112 |
+
sys.path.append(
|
113 |
+
os.path.abspath(
|
114 |
+
os.path.join(os.path.dirname(__file__), "task_retrosynthesis")
|
115 |
+
)
|
116 |
+
)
|
117 |
+
from train import preprocess_df
|
118 |
else: # yield prediction
|
119 |
model_options = ["sagawa/ReactionT5v2-yield"] # default as requested
|
120 |
model_help = "Default model for yield prediction."
|
|
|
149 |
input_max_length = st.number_input(
|
150 |
"Input max length",
|
151 |
min_value=8,
|
152 |
+
max_value=1024,
|
153 |
+
value=input_max_length_default,
|
154 |
step=8,
|
155 |
)
|
156 |
if task != "yield prediction":
|