File size: 2,762 Bytes
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f01494b
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
import argparse
import selfies as sf
from tqdm import tqdm
from transformers import T5EncoderModel
from transformers import set_seed
from src.scripts.mytokenizers import Tokenizer
from src.improved_diffusion import gaussian_diffusion as gd
from src.improved_diffusion import dist_util, logger
from src.improved_diffusion.respace import SpacedDiffusion
from src.improved_diffusion.transformer_model import TransformerNetModel
from src.improved_diffusion.script_util import (
    model_and_diffusion_defaults,
    add_dict_to_argparser,
)
from src.scripts.mydatasets import Lang2molDataset_submission
import streamlit as st
import os


@st.cache_resource
def get_encoder():
    """Return the BioT5 text encoder, switched to evaluation mode.

    ``st.cache_resource`` keeps the returned object across Streamlit reruns,
    so the pretrained weights are loaded only once per server process.
    """
    checkpoint_name = "QizhiPei/biot5-base-text2mol"
    # nn.Module.eval() puts the module in inference behaviour (e.g. disables
    # dropout) and returns the module itself, so it can be returned directly.
    return T5EncoderModel.from_pretrained(checkpoint_name).eval()


@st.cache_resource
def get_tokenizer():
    """Return one shared project ``Tokenizer``, cached across Streamlit reruns."""
    shared_tokenizer = Tokenizer()
    return shared_tokenizer


@st.cache_resource
def get_model():
    """Build the denoising transformer and load its EMA checkpoint.

    The hyper-parameters below must describe the same architecture that the
    checkpoint ``checkpoints/PLAIN_ema_0.9999_360000.pt`` was saved from:
    ``load_state_dict`` is strict by default and raises on missing/unexpected
    keys or shape mismatches.

    Returns:
        TransformerNetModel: the model in eval mode with CPU-resident weights,
        cached by ``st.cache_resource`` across Streamlit reruns.
    """
    model = TransformerNetModel(
        in_channels=32,
        model_channels=128,
        dropout=0.1,
        vocab_size=35073,
        hidden_size=1024,
        num_attention_heads=16,
        num_hidden_layers=12,
    )
    checkpoint_path = os.path.join("checkpoints", "PLAIN_ema_0.9999_360000.pt")
    # Training checkpoints are often saved from GPU tensors; without
    # map_location, torch.load tries to restore them onto CUDA and fails on
    # CPU-only hosts. This app never moves anything to a GPU, so load on CPU.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


@st.cache_resource
def get_diffusion():
    """Return the cached respaced diffusion process used for sampling.

    The base process has 2000 steps with the "sqrt" beta schedule. It is
    respaced to keep only every 10th timestep (0, 10, ..., 1990, i.e. 200
    steps). The model predicts the clean sample (``START_X``) and uses
    fixed-large variances.
    """
    total_steps = 2000
    stride = 10
    kept_timesteps = list(range(0, total_steps, stride))
    beta_schedule = gd.get_named_beta_schedule("sqrt", total_steps)
    return SpacedDiffusion(
        use_timesteps=kept_timesteps,
        betas=beta_schedule,
        model_mean_type=gd.ModelMeanType.START_X,
        model_var_type=gd.ModelVarType.FIXED_LARGE,
        loss_type=gd.LossType.E2E_MSE,
        rescale_timesteps=True,
        model_arch="transformer",
        training_mode="e2e",
    )


tokenizer = get_tokenizer()
encoder = get_encoder()
model = get_model()
diffusion = get_diffusion()

sample_fn = diffusion.ddim_sample_loop

text_input = st.text_area("Enter molecule description")

# Streamlit re-executes this whole script on every widget interaction, and the
# text area starts out empty. Skip the expensive diffusion sampling until the
# user has actually entered a description.
if text_input.strip():
    # Pure inference: disable autograd so the T5 encoder pass and the logits
    # projection do not record a computation graph (saves memory and time).
    with torch.no_grad():
        # Tokenize the description into a fixed 256-token window
        # (truncated or padded as needed) plus its attention mask.
        output = tokenizer(
            text_input,
            max_length=256,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        caption_state = encoder(
            input_ids=output["input_ids"],
            attention_mask=output["attention_mask"],
        ).last_hidden_state
        caption_mask = output["attention_mask"]

        # DDIM sampling of a single (1, 256, 32) sample conditioned on the
        # encoded caption; 32 matches the model's in_channels.
        outputs = sample_fn(
            model,
            (1, 256, 32),
            clip_denoised=False,
            denoised_fn=None,
            model_kwargs={},
            top_p=1.0,
            progress=True,
            caption=(caption_state, caption_mask),
        )
        # get_logits presumably maps the denoised continuous sample back to
        # per-position vocabulary logits; top-1 over the last axis is a greedy
        # (argmax) token choice at each position.
        logits = model.get_logits(torch.tensor(outputs))
        cands = torch.topk(logits, k=1, dim=-1)
        outputs = cands.indices
        outputs = outputs.squeeze(-1)

    outputs = tokenizer.decode(outputs)
    # Drop padding / end-of-sequence markers and tabs, then convert the
    # SELFIES string to SMILES with the selfies decoder.
    result = sf.decoder(
        outputs[0].replace("<pad>", "").replace("</s>", "").replace("\t", "")
    ).replace("\t", "")

    st.write(result)