sagawa commited on
Commit
06428dd
·
verified ·
1 Parent(s): 0da02e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -8
app.py CHANGED
@@ -7,19 +7,15 @@ from types import SimpleNamespace
7
  import pandas as pd
8
  import streamlit as st
9
  import torch
10
- from torch.utils.data import DataLoader
11
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
12
 
13
  # Local imports
14
- sys.path.append(
15
- os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
16
- )
17
  from generation_utils import (
18
  ReactionT5Dataset,
19
  decode_output,
20
  save_multiple_predictions,
21
  )
22
- from train import preprocess_df
 
23
  from utils import seed_everything
24
 
25
  warnings.filterwarnings("ignore")
@@ -101,6 +97,10 @@ with st.sidebar:
101
  model_help = "Recommended models for product prediction."
102
  input_max_length_default = 400
103
  output_max_length_default = 300
 
 
 
 
104
  elif task == "retrosynthesis prediction":
105
  model_options = [
106
  "sagawa/ReactionT5v2-retrosynthesis",
@@ -109,6 +109,12 @@ with st.sidebar:
109
  model_help = "Recommended models for retrosynthesis prediction."
110
  input_max_length_default = 100
111
  output_max_length_default = 400
 
 
 
 
 
 
112
  else: # yield prediction
113
  model_options = ["sagawa/ReactionT5v2-yield"] # default as requested
114
  model_help = "Default model for yield prediction."
@@ -143,8 +149,8 @@ with st.sidebar:
143
  input_max_length = st.number_input(
144
  "Input max length",
145
  min_value=8,
146
- max_value=input_max_length_default,
147
- value=400,
148
  step=8,
149
  )
150
  if task != "yield prediction":
 
7
  import pandas as pd
8
  import streamlit as st
9
  import torch
 
 
10
 
11
  # Local imports
 
 
 
12
  from generation_utils import (
13
  ReactionT5Dataset,
14
  decode_output,
15
  save_multiple_predictions,
16
  )
17
+ from torch.utils.data import DataLoader
18
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
19
  from utils import seed_everything
20
 
21
  warnings.filterwarnings("ignore")
 
97
  model_help = "Recommended models for product prediction."
98
  input_max_length_default = 400
99
  output_max_length_default = 300
100
+ sys.path.append(
101
+ os.path.abspath(os.path.join(os.path.dirname(__file__), "task_forward"))
102
+ )
103
+ from train import preprocess_df
104
  elif task == "retrosynthesis prediction":
105
  model_options = [
106
  "sagawa/ReactionT5v2-retrosynthesis",
 
109
  model_help = "Recommended models for retrosynthesis prediction."
110
  input_max_length_default = 100
111
  output_max_length_default = 400
112
+ sys.path.append(
113
+ os.path.abspath(
114
+ os.path.join(os.path.dirname(__file__), "task_retrosynthesis")
115
+ )
116
+ )
117
+ from train import preprocess_df
118
  else: # yield prediction
119
  model_options = ["sagawa/ReactionT5v2-yield"] # default as requested
120
  model_help = "Default model for yield prediction."
 
149
  input_max_length = st.number_input(
150
  "Input max length",
151
  min_value=8,
152
+ max_value=1024,
153
+ value=input_max_length_default,
154
  step=8,
155
  )
156
  if task != "yield prediction":