sagawa committed on
Commit
c8117fb
·
1 Parent(s): 3969a3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -30
app.py CHANGED
@@ -62,23 +62,38 @@ if st.button('predict'):
62
  min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
63
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
64
  output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
65
- scores = output['sequences_scores'].tolist()
66
- output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
67
- for ith, out in enumerate(output):
68
- mol = Chem.MolFromSmiles(out.rstrip('.'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  if type(mol) == rdkit.Chem.rdchem.Mol:
70
- output.append(out.rstrip('.'))
71
- scores.append(scores[ith])
72
- break
73
- if type(mol) == None:
74
- output.append(None)
75
- scores.append(None)
76
- output += scores
77
- output = [input_compound] + output
78
- outputs.append(output)
79
-
80
- output_df = pd.DataFrame(outputs, columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
81
-
82
  @st.cache
83
  def convert_df(df):
84
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
@@ -98,21 +113,34 @@ if st.button('predict'):
98
  min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
99
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
100
  output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
101
- scores = output['sequences_scores'].tolist()
102
- output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
103
- for ith, out in enumerate(output):
104
- mol = Chem.MolFromSmiles(out.rstrip('.'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  if type(mol) == rdkit.Chem.rdchem.Mol:
106
- output.append(out.rstrip('.'))
107
- scores.append(scores[ith])
108
- break
109
- if type(mol) == None:
110
- output.append(None)
111
- scores.append(None)
112
- output += scores
113
- output = [input_compound] + output
114
- try:
115
- output_df = pd.DataFrame(np.array(output).reshape(1, -1), columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
116
  st.table(output_df)
117
 
118
  @st.cache
 
62
  min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
63
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
64
  output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
65
+ if CFG.num_beams > 1:
66
+ scores = output['sequences_scores'].tolist()
67
+ output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
68
+ for ith, out in enumerate(output):
69
+ mol = Chem.MolFromSmiles(out.rstrip('.'))
70
+ if type(mol) == rdkit.Chem.rdchem.Mol:
71
+ output.append(out.rstrip('.'))
72
+ scores.append(scores[ith])
73
+ break
74
+ if type(mol) == None:
75
+ output.append(None)
76
+ scores.append(None)
77
+ output += scores
78
+ output = [input_compound] + output
79
+ outputs.append(output)
80
+
81
+ else:
82
+ output = [tokenizer.decode(output['sequences'][0], skip_special_tokens=True).replace('. ', '.').rstrip('.')]
83
+ mol = Chem.MolFromSmiles(output[0])
84
  if type(mol) == rdkit.Chem.rdchem.Mol:
85
+ output.append(output[0])
86
+ else:
87
+ output.append(None)
88
+ output = [input_compound] + output
89
+ outputs.append(output)
90
+
91
+ if CFG.num_beams > 1:
92
+ output_df = pd.DataFrame(outputs, columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
93
+ else:
94
+ output_df = pd.DataFrame(outputs, columns=['input', '0th', 'valid compound'])
95
+
96
+
97
  @st.cache
98
  def convert_df(df):
99
  # IMPORTANT: Cache the conversion to prevent computation on every rerun
 
113
  min_length = min(input_compound.find('CATALYST') - input_compound.find(':') - 10, 0)
114
  inp = tokenizer(input_compound, return_tensors='pt').to(device)
115
  output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=CFG.num_beams, num_return_sequences=CFG.num_return_sequences, return_dict_in_generate=True, output_scores=True)
116
+ if CFG.num_beams > 1:
117
+ scores = output['sequences_scores'].tolist()
118
+ output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output['sequences']]
119
+ for ith, out in enumerate(output):
120
+ mol = Chem.MolFromSmiles(out.rstrip('.'))
121
+ if type(mol) == rdkit.Chem.rdchem.Mol:
122
+ output.append(out.rstrip('.'))
123
+ scores.append(scores[ith])
124
+ break
125
+ if type(mol) == None:
126
+ output.append(None)
127
+ scores.append(None)
128
+ output += scores
129
+ output = [input_compound] + output
130
+
131
+ else:
132
+ output = [tokenizer.decode(output['sequences'][0], skip_special_tokens=True).replace('. ', '.').rstrip('.')]
133
+ mol = Chem.MolFromSmiles(output[0])
134
  if type(mol) == rdkit.Chem.rdchem.Mol:
135
+ output.append(output[0])
136
+ else:
137
+ output.append(None)
138
+
139
+
140
+ if CFG.num_beams > 1:
141
+ output_df = pd.DataFrame(np.array(output).reshape(1, -1), columns=['input'] + [f'{i}th' for i in range(CFG.num_beams)] + ['valid compound'] + [f'{i}th score' for i in range(CFG.num_beams)] + ['valid compound score'])
142
+ else:
143
+ output_df = pd.DataFrame(np.array([input_compound]+output).reshape(1, -1), columns=['input', '0th', 'valid compound'])
 
144
  st.table(output_df)
145
 
146
  @st.cache