Spaces:
Running
Running
# ─── monkey-patch gradio_client so bool schemas don't crash json_schema_to_python_type ───
import gradio_client.utils as _gc_utils | |
# Back up the original gradio_client.utils implementations so the patched
# wrappers defined below can delegate to them.
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type
def _patched_get_type(schema):
    """Call the original ``get_type``, passing ``{}`` in place of a boolean schema."""
    normalized = {} if isinstance(schema, bool) else schema
    return _orig_get_type(normalized)
def _patched_json_schema_to_python_type(schema, defs=None):
    """Call the original ``_json_schema_to_python_type``, passing ``{}`` in place of a boolean schema.

    *defs* is forwarded to the original unchanged.
    """
    normalized = {} if isinstance(schema, bool) else schema
    return _orig_json2py(normalized, defs)
# Install the wrappers on the gradio_client.utils module in place of the
# originals (this runs before gradio itself is imported further down).
_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
# ─── now it's safe to import Gradio and build your interface ───────────────────────────
import gradio as gr | |
import os | |
import sys | |
import argparse | |
import tempfile | |
import shutil | |
import base64 | |
import io | |
import torch | |
import selfies | |
from rdkit import Chem | |
import matplotlib | |
matplotlib.use("Agg") | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
from typing import Optional | |
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel | |
from torch.utils.data import DataLoader | |
from Bio.PDB import PDBParser, MMCIFParser | |
from Bio.Data import IUPACData | |
from utils.drug_tokenizer import DrugTokenizer | |
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI | |
from utils.foldseek_util import get_struc_seq | |
# ───── Helpers ─────────────────────────────────────────────────
# Three-letter -> one-letter residue codes with upper-cased keys, taken from
# Biopython's table; the MSE, SEC and PYL entries are set explicitly afterwards.
three2one = {
    **{code.upper(): letter for code, letter in IUPACData.protein_letters_3to1.items()},
    "MSE": "M",
    "SEC": "C",
    "PYL": "K",
}
def simple_seq_from_structure(path: str) -> str:
    """Return a one-letter sequence for the chain with the most residues.

    Paths ending in ``.cif`` are parsed with ``MMCIFParser``; every other path
    goes through ``PDBParser``. Each residue of the chosen chain is looked up
    in ``three2one`` by its upper-cased name, falling back to ``"X"``.
    Returns ``""`` when the structure contains no chains.
    """
    parser_cls = MMCIFParser if path.endswith(".cif") else PDBParser
    structure = parser_cls(QUIET=True).get_structure("P", path)

    longest_chain = None
    longest_len = -1
    for candidate in structure.get_chains():
        residue_count = sum(1 for _ in candidate.get_residues())
        # Strict ">" keeps the first chain on ties, like max() does.
        if residue_count > longest_len:
            longest_chain, longest_len = candidate, residue_count

    if longest_chain is None:
        return ""

    letters = [
        three2one.get(residue.get_resname().upper(), "X")
        for residue in longest_chain
    ]
    return "".join(letters)
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES, or return ``None`` on failure.

    The string is first parsed with ``Chem.MolFromSmiles``; if that yields
    ``None`` the input is rejected. Otherwise the result of
    ``selfies.encoder(smiles)`` is returned.

    Any ordinary exception raised by either library is treated as a failed
    conversion and reported as ``None``.
    """
    try:
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Deliberately best-effort, but no longer a bare ``except:`` so that
        # KeyboardInterrupt / SystemExit still propagate.
        return None
def parse_config():
    """Parse the command-line options and return the resulting namespace.

    The ``--device`` default is ``"cuda"`` when ``torch.cuda.is_available()``
    is true and ``"cpu"`` otherwise; all other defaults are fixed strings or
    integers listed below.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    option_table = (
        ("--prot_encoder_path", {"default": "westlake-repl/SaProt_650M_AF2"}),
        ("--drug_encoder_path", {"default": "HUBioDataLab/SELFormer"}),
        ("--agg_mode", {"type": str, "default": "mean_all_tok"}),
        ("--group_size", {"type": int, "default": 1}),
        ("--fusion", {"default": "CAN"}),
        ("--device", {"default": default_device}),
        ("--save_path_prefix", {"default": "save_model_ckp/"}),
        ("--dataset", {"default": "Human"}),
    )

    parser = argparse.ArgumentParser()
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser.parse_args()
# Parse the CLI options once at import time; DEVICE is the device string that
# the tensors and models below are moved to.
args = parse_config()
DEVICE = args.device
# ───── Load models & tokenizers ─────────────────────────────────
# Protein tokenizer and masked-LM weights both come from --prot_encoder_path.
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
# The drug tokenizer is the project-local DrugTokenizer; the drug weights come
# from --drug_encoder_path.
drug_tokenizer = DrugTokenizer()
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
# Wrap both encoders in Pre_encoded and move the result to DEVICE.
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
def collate_fn(batch):
    """Tokenise a batch of ``(protein, drug, score)`` triples for the DataLoader.

    Both sequences are tokenised with identical settings (padded/truncated to
    512 tokens, special tokens added, PyTorch tensors returned).

    Returns a 5-tuple: protein input ids, protein attention mask (bool),
    drug input ids, drug attention mask (bool), and the scores as a tensor.
    """
    proteins, drugs, labels = zip(*batch)

    tokenizer_kwargs = dict(
        max_length=512,
        padding="max_length",
        truncation=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    prot_enc = prot_tokenizer.batch_encode_plus(list(proteins), **tokenizer_kwargs)
    drug_enc = drug_tokenizer.batch_encode_plus(list(drugs), **tokenizer_kwargs)

    label_tensor = torch.tensor(list(labels))
    return (
        prot_enc["input_ids"],
        prot_enc["attention_mask"].bool(),
        drug_enc["input_ids"],
        drug_enc["attention_mask"].bool(),
        label_tensor,
    )
def get_case_feature(model, loader):
    """Encode the first batch yielded by *loader* and return its features on the CPU.

    The model is put in eval mode and ``model.encoding`` is called under
    ``torch.no_grad()``. The result is a one-element list holding the tuple
    ``(p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, None)``. If the loader
    yields nothing the function falls through and returns ``None``.
    """
    model.eval()
    with torch.no_grad():
        for batch in loader:
            prot_ids, prot_mask, drug_ids, drug_mask, _ = batch
            prot_ids = prot_ids.to(DEVICE)
            prot_mask = prot_mask.to(DEVICE)
            drug_ids = drug_ids.to(DEVICE)
            drug_mask = drug_mask.to(DEVICE)

            prot_emb, drug_emb = model.encoding(prot_ids, prot_mask, drug_ids, drug_mask)

            on_cpu = tuple(
                tensor.cpu()
                for tensor in (prot_emb, drug_emb, prot_ids, drug_ids, prot_mask, drug_mask)
            )
            # Only the first batch is used.
            return [on_cpu + (None,)]
# ─────────────── visualisation ───────────────────────────────────────────
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
    """
    Render a Protein → Drug cross-attention heat-map and, optionally, a
    Top-30 protein-residue table for a chosen drug-token index.

    The token index shown on the x-axis (and accepted via *drug_idx*) is the
    position of that token in the decoded drug sequence, after special tokens
    have been dropped but before any pruning or column capping (1-based in the
    labels, 0-based for the function argument).

    Parameters
    ----------
    model
        Called as ``model(p_emb, d_emb, p_mask, d_mask)``; the second return
        value is used as the attention map.
    feats
        Sequence whose first item is the 7-tuple
        ``(p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _)``; only ``feats[0]``
        is read.
    drug_idx
        Optional 0-based drug-token index. When it is ``None`` or out of
        range, no table is produced.

    Returns
    -------
    html : str
        Optional Top-30 HTML table followed by the heat-map as a
        base64-embedded PNG with a "Download PDF" link.
    """
    # ── forward pass ─────────────────────────────────────────────────────────
    model.eval()
    with torch.no_grad():
        p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
        p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
        p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
        _, att_pd = model(p_emb, d_emb, p_mask, d_mask)
        # Used as a 2-D (protein rows, drug columns) matrix from here on.
        attn = att_pd.squeeze(0).cpu()

    # ── decode tokens (skip special symbols) ─────────────────────────────────
    def clean_ids(ids, tokenizer):
        # Look the special tokens up once instead of once per token.
        specials = set(tokenizer.all_special_tokens)
        return [
            tok
            for tok in tokenizer.convert_ids_to_tokens(ids.tolist())
            if tok not in specials
        ]

    # Token lists are cut to the attention-matrix size and vice versa, so the
    # "full" lists and ``orig_attn`` always describe the same rows/columns.
    p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)[: attn.size(0)]
    d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)[: attn.size(1)]
    p_indices_full = list(range(1, len(p_tokens_full) + 1))
    d_indices_full = list(range(1, len(d_tokens_full) + 1))
    attn = attn[: len(p_tokens_full), : len(d_tokens_full)]
    if attn.numel() == 0:
        # ``attn.max()`` below would raise on an empty matrix.
        return "<p style='color:red'>Nothing to visualise: no tokens were decoded.</p>"
    orig_attn = attn.clone()

    # ── adaptive sparsity pruning ────────────────────────────────────────────
    # Keep rows/columns whose peak exceeds 5 % of the global maximum; if fewer
    # than three survive on an axis, keep that axis whole.
    thr = attn.max().item() * 0.05
    row_keep = attn.max(dim=1).values > thr
    col_keep = attn.max(dim=0).values > thr
    if row_keep.sum() < 3:
        row_keep[:] = True
    if col_keep.sum() < 3:
        col_keep[:] = True

    attn = attn[row_keep][:, col_keep]
    row_mask, col_mask = row_keep.tolist(), col_keep.tolist()
    p_tokens = [tok for keep, tok in zip(row_mask, p_tokens_full) if keep]
    p_indices = [idx for keep, idx in zip(row_mask, p_indices_full) if keep]
    d_tokens = [tok for keep, tok in zip(col_mask, d_tokens_full) if keep]
    d_indices = [idx for keep, idx in zip(col_mask, d_indices_full) if keep]

    # ── cap column count at 150 for readability ──────────────────────────────
    if attn.size(1) > 150:
        topc = torch.topk(attn.sum(0), k=150).indices.tolist()
        attn = attn[:, topc]
        d_tokens = [d_tokens[i] for i in topc]
        d_indices = [d_indices[i] for i in topc]

    # ── draw heat-map ────────────────────────────────────────────────────────
    x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
    y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]

    fig_w = min(22, max(8, len(x_labels) * 0.6))  # ~0.6″ per column
    fig_h = min(24, max(6, len(p_tokens) * 0.8))
    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
    try:
        im = ax.imshow(attn.numpy(), aspect="auto",
                       cmap=cm.viridis, interpolation="nearest")
        ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)

        ax.set_xticks(range(len(x_labels)))
        ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
                           ha="center", va="center")
        ax.tick_params(axis="x", top=True, bottom=False,
                       labeltop=True, labelbottom=False, pad=27)

        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=7)
        # The top/bottom/labeltop/labelbottom keywords the original passed
        # here are discarded by matplotlib for axis="y"; only the padding
        # has an effect.
        ax.tick_params(axis="y", pad=10)

        fig.colorbar(im, fraction=0.026, pad=0.01)
        fig.tight_layout()

        # Render each format once, before the figure is closed.
        buf_png = io.BytesIO()
        fig.savefig(buf_png, format="png", dpi=140)
        buf_pdf = io.BytesIO()
        fig.savefig(buf_pdf, format="pdf")
    finally:
        # Always release the figure, even if rendering fails.
        plt.close(fig)

    # ── Top-30 table ─────────────────────────────────────────────────────────
    table_html = ""
    if drug_idx is not None and 0 <= drug_idx < orig_attn.size(1):
        # Read the requested column from the unpruned matrix so the table is
        # correct even when that column was pruned or dropped by the 150-column
        # cap. Ranking is restricted to the rows shown in the heat-map, so the
        # top-k positions index ``p_tokens`` / ``p_indices`` (pruned lists).
        col_vec = orig_attn[:, drug_idx][row_keep]
        topk = torch.topk(col_vec, k=min(30, len(col_vec))).indices.tolist()

        drug_tok_text = d_tokens_full[drug_idx]
        orig_idx = d_indices_full[drug_idx]

        lead_th = ("<th style='border:1px solid #ccc; padding:6px; "
                   "background:#f7f7f7; text-align:center;'>")
        rank_th = ("<th style='border:1px solid #ccc; padding:6px; "
                   "background:#f7f7f7; text-align:center'>")
        cell_td = ("<td style='border:1px solid #ccc; padding:6px; "
                   "text-align:center'>")

        # 1) header row: leading "Rank", then 1…k
        header_cells = f"{lead_th}Rank</th>" + "".join(
            f"{rank_th}{rank}</th>" for rank in range(1, len(topk) + 1)
        )
        # 2) residue row: leading "Residue", then the residue tokens
        residue_cells = f"{lead_th}Residue</th>" + "".join(
            f"{cell_td}{p_tokens[i]}</td>" for i in topk
        )
        # 3) position row: leading "Position", then the 1-based residue positions
        position_cells = f"{lead_th}Position</th>" + "".join(
            f"{cell_td}{p_indices[i]}</td>" for i in topk
        )

        table_html = (
            f"<h4 style='margin-bottom:12px'>"
            f"Drug atom #{orig_idx} <code>{drug_tok_text}</code> → Top-30 Protein residues"
            f"</h4>"
            f"<table style='border-collapse:collapse; margin:0 auto 24px;'>"
            f"<tr>{header_cells}</tr>"
            f"<tr>{residue_cells}</tr>"
            f"<tr>{position_cells}</tr>"
            f"</table>"
        )

    # ── embed PNG + PDF ──────────────────────────────────────────────────────
    png_b64 = base64.b64encode(buf_png.getvalue()).decode()
    pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
    html_heat = (
        f"<div style='position: relative; width: 100%;'>"
        # the PDF button, absolutely positioned
        f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
        "style='position: absolute; top: 12px; right: 12px; "
        "background: var(--primary); color: #fff; "
        "padding: 8px 16px; border-radius: 6px; "
        "font-size: 0.9rem; font-weight: 500; "
        "text-decoration: none;'>"
        "Download PDF"
        "</a>"
        # the clickable heat-map image
        f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
        f"<img src='data:image/png;base64,{png_b64}' "
        "style='display: block; width: 100%; height: auto; cursor: zoom-in;'/>"
        "</a>"
        "</div>"
    )
    return table_html + html_heat
# ───── Gradio Callbacks ─────────────────────────────────────────
# Absolute directory of this script, and the path <ROOT>/bin/foldseek that
# extract_sequence_cb passes to get_struc_seq as its first argument.
ROOT = os.path.dirname(os.path.abspath(__file__))
FOLDSEEK_BIN = os.path.join(ROOT, "bin", "foldseek")
def extract_sequence_cb(structure_file):
    """Return the sequence ``get_struc_seq`` extracts for the first chain of an upload.

    Parameters
    ----------
    structure_file
        The value of the file-upload component: ``None``, an object exposing
        the file path as ``.name``, or a plain path (``str`` / ``os.PathLike``).

    Returns
    -------
    str
        The third element of the first chain's entry in the ``get_struc_seq``
        result, or ``""`` when nothing was uploaded, the file does not exist,
        or no chain was parsed.
    """
    if structure_file is None:
        return ""

    # A pathlib-style object also has ``.name`` (the basename only), so
    # resolve PathLike values explicitly before falling back to ``.name``.
    if isinstance(structure_file, os.PathLike):
        path = os.fspath(structure_file)
    else:
        path = getattr(structure_file, "name", structure_file)
    if not isinstance(path, str) or not os.path.exists(path):
        return ""

    parsed = get_struc_seq(FOLDSEEK_BIN, path, None, plddt_mask=False)
    if not parsed:
        # Avoid StopIteration from next(iter(...)) on an empty result.
        return ""
    first_chain = next(iter(parsed))
    _, _, struct_seq = parsed[first_chain]
    return struct_seq
# Lazily built FusionDTI instance, reused by every inference request.
_FUSION_MODEL = None


def _get_fusion_model():
    """Build the FusionDTI model once, load its checkpoint if present, and cache it."""
    global _FUSION_MODEL
    if _FUSION_MODEL is None:
        model = FusionDTI(446, 768, args).to(DEVICE)
        ckpt = os.path.join(
            f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt"
        )
        if os.path.isfile(ckpt):
            model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
        else:
            # Keep the original best-effort behaviour (run without a
            # checkpoint) but make it visible instead of silent.
            import warnings

            warnings.warn(
                f"Checkpoint {ckpt!r} not found; FusionDTI runs with freshly "
                "initialised weights.",
                stacklevel=2,
            )
        _FUSION_MODEL = model
    return _FUSION_MODEL


def inference_cb(prot_seq, drug_seq, atom_idx):
    """Gradio callback: run the model on one protein/drug pair and return result HTML.

    Parameters
    ----------
    prot_seq
        Protein sequence text; ``None`` or blank yields an error message.
    drug_seq
        Drug string. If, after stripping, it does not start with ``"["`` it is
        converted with ``smiles_to_selfies`` first.
    atom_idx
        Optional 1-based drug-token index; a falsy value (``None`` or 0)
        disables the Top-30 table.

    Returns
    -------
    str
        A red ``<p>`` error message for invalid input, otherwise the HTML
        produced by ``visualize_attention``.
    """
    # Tolerate None from cleared components and ignore surrounding whitespace.
    prot_seq = (prot_seq or "").strip()
    drug_seq = (drug_seq or "").strip()

    if not prot_seq:
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
    if not drug_seq:
        return "<p style='color:red'>Please enter a drug sequence.</p>"

    if not drug_seq.startswith("["):
        conv = smiles_to_selfies(drug_seq)
        if conv is None:
            return "<p style='color:red'>SMILES→SELFIES conversion failed.</p>"
        drug_seq = conv

    loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)
    model = _get_fusion_model()
    # The UI index is 1-based; visualize_attention expects 0-based.
    return visualize_attention(model, feats, int(atom_idx) - 1 if atom_idx else None)
def clear_cb():
    """Return the fixed 5-tuple of reset values ``(None, "", "", None, "")``."""
    reset_values = (None, "", "", None, "")
    return reset_values
# ───── Gradio Interface Definition ───────────────────────────────
# Custom stylesheet handed to gr.Blocks(css=...): colour variables, base
# typography, the link buttons, action-button sizing and the ".card" container.
css = """
:root {
--bg: #f3f4f6;
--card: #ffffff;
--border: #e5e7eb;
--primary: #6366f1;
--primary-dark: #4f46e5;
--text: #111827;
}
* { box-sizing: border-box; margin: 0; padding: 0; }
body { background: var(--bg); color: var(--text); font-family: Inter,system-ui,Arial,sans-serif; }
h1 { font-family: Poppins,Inter,sans-serif; font-weight: 600; font-size: 2rem; text-align: center; margin: 24px 0; }
button, .gr-button { font-family: Inter,sans-serif; font-weight: 600; }
#project-links { text-align: center; margin-bottom: 32px; }
#project-links .gr-button { margin: 0 8px; min-width: 160px; }
#project-links .gr-button:nth-child(1) { background: #10b981; }
#project-links .gr-button:nth-child(2) { background: #ef4444; }
#project-links .gr-button:nth-child(3) { background: #3b82f6; }
#project-links .gr-button:hover { opacity: 0.9; }
.link-btn{display:inline-block;margin:0 8px;padding:10px 20px;border-radius:8px;
color:white;font-weight:600;text-decoration:none;box-shadow:0 2px 6px rgba(0,0,0,0.12);
transition:all .2s ease-in-out;}
.link-btn:hover{opacity:.9;}
.link-btn.project{background:linear-gradient(to right,#10b981,#059669);}
.link-btn.arxiv {background:linear-gradient(to right,#ef4444,#dc2626);}
.link-btn.github {background:linear-gradient(to right,#3b82f6,#2563eb);}
/* make *all* gradio buttons a bit taller */
.gr-button { min-height: 10px !important; }
/* now target just our two big action buttons */
#extract-btn, #inference-btn {
width: 5px !important;
min-height: 36px !important;
margin-top: 12px !important;
}
/* and make clear button full width but shorter */
#clear-btn {
width: 10px !important;
min-height: 36px !important;
margin-top: 12px !important;
}
#input-card label {
font-weight: 600 !important; /* make the text bold */
color: var(--text) !important; /* use your standard text color */
}
.card {
background: var(--card);
border: 1px solid var(--border);
border-radius: 12px;
padding: 24px;
max-width: 1000px;
margin: 0 auto 32px;
box-shadow: 0 2px 6px rgba(0,0,0,0.05);
}
#guidelines-card h2 {
font-size: 1.4rem;
margin-bottom: 16px;
text-align: center;
}
#guidelines-card ol {
margin-left: 20px;
line-height: 1.6;
font-size: 1rem;
}
#input-card .gr-row, #input-card .gr-cols {
gap: 16px;
}
#input-card .gr-button {
flex: 1;
}
#output-card {
padding-top: 0;
}
"""
# Assemble the Gradio UI: title, project links, usage guidelines, the input
# card with its action buttons, the output area, and the event wiring.
with gr.Blocks(css=css) as demo:
    # ───────────── Title ─────────────
    gr.Markdown(
        "<h1 style='text-align: center;'>Token-level Visualiser for Drug-Target Interaction</h1>"
    )
    # ───────────── Project Links ─────────────
    # NOTE(review): the glyphs in front of the three link labels look
    # mis-encoded — confirm the intended characters.
    gr.Markdown("""
<div style="text-align:center;margin-bottom:32px;">
<a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank">π Project Page</a>
<a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank">π ArXiv: 2406.01651</a>
<a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank">π» GitHub Repo</a>
</div>
""")
    # ───────────── Guidelines Card ─────────────
    # NOTE(review): the quote characters around "Download PDF" in the last
    # list item also look mis-encoded — confirm.
    gr.HTML(
        """
<div class="card" style="margin-bottom:24px">
<h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for User</h2>
<ul style="font-size:1rem; margin-left:18px;line-height:1.55;list-style:decimal;">
<li><strong>Convert protein structure into a structure-aware sequence:</strong>
Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
sequence will be generated using
<a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
based on 3D structures from
<a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a> or the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
<li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
you must first visit the
<a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold DB</a>
to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
<li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
You can enter a SELFIES string directly, or paste a SMILES string.
SMILES will be automatically converted to SELFIES using
<a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
If conversion fails, a red error message will be displayed.</li>
<li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
to highlight the Top-30 interacting protein residues.</li>
<li>After inference, you can use the
βDownload PDFβ link to export a high-resolution vector version.</li>
</ul>
</div>
""")
    # ───────────── Input Card ─────────────
    # NOTE(review): the buttons are placed inside this column because the
    # stylesheet targets "#input-card .gr-button" — confirm the nesting.
    with gr.Column(elem_id="input-card", elem_classes="card"):
        protein_seq = gr.Textbox(
            label="Protein Structure-aware Sequence",
            lines=3,
            elem_id="protein-seq"
        )
        drug_seq = gr.Textbox(
            label="Drug Sequence (SELFIES/SMILES)",
            lines=3,
            elem_id="drug-seq"
        )
        structure_file = gr.File(
            label="Upload Protein Structure (.pdb/.cif)",
            file_types=[".pdb", ".cif"],
            elem_id="structure-file"
        )
        drug_idx = gr.Number(
            label="Drug atom/substructure index (1-based)",
            value=None,
            precision=0,
            elem_id="drug-idx"
        )
        # ───────────── Action Buttons ─────────────
        with gr.Row(elem_id="action-buttons", equal_height=True):
            btn_extract = gr.Button(
                "Extract sequence",
                variant="primary",
                elem_id="extract-btn"
            )
            btn_infer = gr.Button(
                "Inference",
                variant="primary",
                elem_id="inference-btn"
            )
        with gr.Row():
            clear_btn = gr.Button(
                "Clear",
                variant="secondary",
                elem_classes="full-width",
                elem_id="clear-btn"
            )
    # ───────────── Output Visualization ─────────────
    output_html = gr.HTML(elem_id="result-html")
    # ───────────── Event Wiring ─────────────
    # Extract: uploaded structure file -> protein sequence textbox.
    btn_extract.click(
        fn=extract_sequence_cb,
        inputs=[structure_file],
        outputs=[protein_seq]
    )
    # Inference: (protein sequence, drug sequence, index) -> result HTML.
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, drug_idx],
        outputs=[output_html]
    )
    # Clear: the tuple order matches the ``outputs`` list below.
    clear_btn.click(
        fn=lambda: ("", "", None, "", None),
        inputs=[],
        outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
    )
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 and request a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)