# ─── monkey-patch gradio_client so bool schemas don’t crash json_schema_to_python_type ───
import gradio_client.utils as _gc_utils
# back up originals
_orig_get_type = _gc_utils.get_type
_orig_json2py = _gc_utils._json_schema_to_python_type
def _patched_get_type(schema):
# treat any boolean schema as if it were an empty dict
if isinstance(schema, bool):
schema = {}
return _orig_get_type(schema)
def _patched_json_schema_to_python_type(schema, defs=None):
# treat any boolean schema as if it were an empty dict
if isinstance(schema, bool):
schema = {}
return _orig_json2py(schema, defs)
_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
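# Note: JSON Schema allows a bare boolean (`true`/`false`) wherever a schema is expected
# (e.g. `additionalProperties: false`); some gradio_client releases choke on these,
# which is what the shim above works around.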
# ─── now it’s safe to import Gradio and build your interface ───────────────────────────
import gradio as gr
import os
import sys
import argparse
import tempfile
import shutil
import base64
import io
import torch
import selfies
from rdkit import Chem
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib import cm
from typing import Optional
from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
from torch.utils.data import DataLoader
from Bio.PDB import PDBParser, MMCIFParser
from Bio.Data import IUPACData
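# project-local utilities: SELFIES drug tokeniser, model definitions (Pre_encoded, FusionDTI)
# and the Foldseek structure-to-sequence wrapper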
from utils.drug_tokenizer import DrugTokenizer
from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
from utils.foldseek_util import get_struc_seq
# ───── Helpers ─────────────────────────────────────────────────
three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
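# map common non-standard residues to their closest standard one-letter codes:
# MSE (selenomethionine) → M, SEC (selenocysteine) → C, PYL (pyrrolysine) → K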
three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
def simple_seq_from_structure(path: str) -> str:
parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
structure = parser.get_structure("P", path)
chains = list(structure.get_chains())
if not chains:
return ""
chain = max(chains, key=lambda c: len(list(c.get_residues())))
return "".join(three2one.get(res.get_resname().upper(), "X") for res in chain)
def smiles_to_selfies(smiles: str) -> Optional[str]:
try:
mol = Chem.MolFromSmiles(smiles)
if mol is None:
return None
return selfies.encoder(smiles)
    except Exception:
        return None
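# e.g. smiles_to_selfies("CCO") → "[C][C][O]"; returns None if RDKit cannot parse the SMILES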
def parse_config():
p = argparse.ArgumentParser()
p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
p.add_argument("--agg_mode", type=str, default="mean_all_tok")
p.add_argument("--group_size", type=int, default=1)
p.add_argument("--fusion", default="CAN")
p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
p.add_argument("--save_path_prefix", default="save_model_ckp/")
p.add_argument("--dataset", default="Human")
return p.parse_args()
args = parse_config()
DEVICE = args.device
# ───── Load models & tokenizers ─────────────────────────────────
prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
drug_tokenizer = DrugTokenizer()
drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
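# Pre_encoded wraps the protein and drug encoders; its .encoding() method (used in
# get_case_feature below) returns the token-level embeddings for both inputs.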
def collate_fn(batch):
query1, query2, scores = zip(*batch)
query_encodings1 = prot_tokenizer.batch_encode_plus(
list(query1),
max_length=512,
padding="max_length",
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
query_encodings2 = drug_tokenizer.batch_encode_plus(
list(query2),
max_length=512,
padding="max_length",
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
scores = torch.tensor(list(scores))
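    # labels are placeholders here: the demo always passes a score of 1 and ignores it at inference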
attention_mask1 = query_encodings1["attention_mask"].bool()
attention_mask2 = query_encodings2["attention_mask"].bool()
return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
def get_case_feature(model, loader):
model.eval()
with torch.no_grad():
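        # the demo DataLoader yields exactly one batch; encode it and return the tensors on CPU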
for p_ids, p_mask, d_ids, d_mask, _ in loader:
p_ids, p_mask = p_ids.to(DEVICE), p_mask.to(DEVICE)
d_ids, d_mask = d_ids.to(DEVICE), d_mask.to(DEVICE)
p_emb, d_emb = model.encoding(p_ids, p_mask, d_ids, d_mask)
return [(p_emb.cpu(), d_emb.cpu(),
p_ids.cpu(), d_ids.cpu(),
p_mask.cpu(), d_mask.cpu(), None)]
# ─────────────── visualisation ───────────────────────────────────────────
def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
"""
Render a Protein → Drug cross-attention heat-map and, optionally, a
Top-30 protein-residue table for a chosen drug-token index.
The token index shown on the x-axis (and accepted via *drug_idx*) is **the
position of that token in the *original* drug sequence**, *after* the
tokeniser but *before* any pruning or truncation (1-based in the labels,
0-based for the function argument).
Returns
-------
html : str
Base64-embedded PNG heat-map (+ optional HTML table).
"""
model.eval()
with torch.no_grad():
# ── unpack single-case tensors ───────────────────────────────────────────
p_emb, d_emb, p_ids, d_ids, p_mask, d_mask, _ = feats[0]
p_emb, d_emb = p_emb.to(DEVICE), d_emb.to(DEVICE)
p_mask, d_mask = p_mask.to(DEVICE), d_mask.to(DEVICE)
# ── forward pass: Protein → Drug attention (B, n_p, n_d) ───────────────
_, att_pd = model(p_emb, d_emb, p_mask, d_mask)
attn = att_pd.squeeze(0).cpu() # (n_p, n_d)
# ── decode tokens (skip special symbols) ────────────────────────────────
def clean_ids(ids, tokenizer):
toks = tokenizer.convert_ids_to_tokens(ids.tolist())
return [t for t in toks if t not in tokenizer.all_special_tokens]
# ── decode full sequences + record 1-based indices ──────────────────
p_tokens_full = clean_ids(p_ids[0], prot_tokenizer)
p_indices_full = list(range(1, len(p_tokens_full) + 1))
d_tokens_full = clean_ids(d_ids[0], drug_tokenizer)
d_indices_full = list(range(1, len(d_tokens_full) + 1))
# ── safety cut-off to match attn mat size ───────────────────────────────
p_tokens = p_tokens_full[: attn.size(0)]
p_indices_full = p_indices_full[: attn.size(0)]
d_tokens_full = d_tokens_full[: attn.size(1)]
d_indices_full = d_indices_full[: attn.size(1)]
attn = attn[: len(p_tokens_full), : len(d_tokens_full)]
orig_attn = attn.clone()
# ── adaptive sparsity pruning ───────────────────────────────────────────
thr = attn.max().item() * 0.05
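        # keep only rows/columns whose peak attention exceeds 5% of the global maximum,
        # falling back to keeping everything if fewer than 3 rows or columns survive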
row_keep = (attn.max(dim=1).values > thr)
col_keep = (attn.max(dim=0).values > thr)
if row_keep.sum() < 3:
row_keep[:] = True
if col_keep.sum() < 3:
col_keep[:] = True
attn = attn[row_keep][:, col_keep]
p_tokens = [tok for keep, tok in zip(row_keep, p_tokens) if keep]
p_indices = [idx for keep, idx in zip(row_keep, p_indices_full) if keep]
d_tokens = [tok for keep, tok in zip(col_keep, d_tokens_full) if keep]
d_indices = [idx for keep, idx in zip(col_keep, d_indices_full) if keep]
# ── cap column count at 150 for readability ─────────────────────────────
if attn.size(1) > 150:
topc = torch.topk(attn.sum(0), k=150).indices
attn = attn[:, topc]
d_tokens = [d_tokens [i] for i in topc]
d_indices = [d_indices[i] for i in topc]
# ── draw heat-map ───────────────────────────────────────────────────────
x_labels = [f"{idx}:{tok}" for idx, tok in zip(d_indices, d_tokens)]
y_labels = [f"{idx}:{tok}" for idx, tok in zip(p_indices, p_tokens)]
fig_w = min(22, max(8, len(x_labels) * 0.6)) # ~0.6″ per column
fig_h = min(24, max(6, len(p_tokens) * 0.8))
fig, ax = plt.subplots(figsize=(fig_w, fig_h))
im = ax.imshow(attn.numpy(), aspect="auto",
cmap=cm.viridis, interpolation="nearest")
ax.set_title("Protein → Drug Attention", pad=8, fontsize=10)
ax.set_xticks(range(len(x_labels)))
ax.set_xticklabels(x_labels, rotation=90, fontsize=8,
ha="center", va="center")
ax.tick_params(axis="x", top=True, bottom=False,
labeltop=True, labelbottom=False, pad=27)
ax.set_yticks(range(len(y_labels)))
ax.set_yticklabels(y_labels, fontsize=7)
ax.tick_params(axis="y", top=True, bottom=False,
labeltop=True, labelbottom=False,
pad=10)
fig.colorbar(im, fraction=0.026, pad=0.01)
fig.tight_layout()
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=140)
plt.close(fig)
    # ── embed the heat-map as a base64 PNG ─────────────────────────────────
    img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    html = (f'<img src="data:image/png;base64,{img_b64}" '
            'alt="Protein → Drug attention heat-map" style="max-width:100%;"/>')

    # ───────────────────── Top-30 table ─────────────────────
    table_html = ""
    if drug_idx is not None and 0 <= drug_idx < orig_attn.size(1):
        # map original 0-based drug_idx → current column position
        if (drug_idx + 1) in d_indices:
            col_pos = d_indices.index(drug_idx + 1)
        elif 0 <= drug_idx < len(d_tokens):
            col_pos = drug_idx
        else:
            col_pos = None
        if col_pos is not None:
            col_vec = attn[:, col_pos]
            topk = torch.topk(col_vec, k=min(30, len(col_vec))).indices.tolist()
            drug_tok_text = f"{d_indices[col_pos]}:{d_tokens[col_pos]}"
            rows = "".join(
                f"<tr><td>{rank + 1}</td>"
                f"<td>{p_indices[i]}:{p_tokens[i]}</td>"
                f"<td>{col_vec[i].item():.4f}</td></tr>"
                for rank, i in enumerate(topk)
            )
            table_html = (
                f"<h4>{drug_tok_text} → Top-30 Protein residues</h4>"
                "<table><tr><th>Rank</th><th>Protein residue</th><th>Attention</th></tr>"
                f"{rows}</table>"
            )
    return html + table_html


# ───── Gradio callbacks ──────────────────────────────────────────
def inference_cb(prot_seq: str, drug_seq: str, atom_idx):
    if not prot_seq.strip():
        return "Please extract or enter a protein sequence first."
    if not drug_seq.strip():
        return "Please enter a drug sequence."
    if not drug_seq.strip().startswith("["):
        conv = smiles_to_selfies(drug_seq.strip())
        if conv is None:
            return "SMILES→SELFIES conversion failed."
        drug_seq = conv
    loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)
    model = FusionDTI(446, 768, args).to(DEVICE)
    ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt")
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    return visualize_attention(model, feats, int(atom_idx) - 1 if atom_idx else None)


def clear_cb():
    return None, "", "", None, ""


# ───── Gradio Interface Definition ───────────────────────────────
css = """
:root {
  --bg: #f3f4f6;
  --card: #ffffff;
  --border: #e5e7eb;
  --primary: #6366f1;
  --primary-dark: #4f46e5;
  --text: #111827;
}
* { box-sizing: border-box; margin: 0; padding: 0; }
body { background: var(--bg); color: var(--text); font-family: Inter,system-ui,Arial,sans-serif; }
h1 { font-family: Poppins,Inter,sans-serif; font-weight: 600; font-size: 2rem; text-align: center; margin: 24px 0; }
button, .gr-button { font-family: Inter,sans-serif; font-weight: 600; }
#project-links { text-align: center; margin-bottom: 32px; }
#project-links .gr-button { margin: 0 8px; min-width: 160px; }
#project-links .gr-button:nth-child(1) { background: #10b981; }
#project-links .gr-button:nth-child(2) { background: #ef4444; }
#project-links .gr-button:nth-child(3) { background: #3b82f6; }
#project-links .gr-button:hover { opacity: 0.9; }
.link-btn{display:inline-block;margin:0 8px;padding:10px 20px;border-radius:8px;
  color:white;font-weight:600;text-decoration:none;box-shadow:0 2px 6px rgba(0,0,0,0.12);
  transition:all .2s ease-in-out;}
.link-btn:hover{opacity:.9;}
.link-btn.project{background:linear-gradient(to right,#10b981,#059669);}
.link-btn.arxiv {background:linear-gradient(to right,#ef4444,#dc2626);}
.link-btn.github {background:linear-gradient(to right,#3b82f6,#2563eb);}
/* make *all* gradio buttons a bit taller */
.gr-button { min-height: 10px !important; }
/* now target just our two big action buttons */
#extract-btn, #inference-btn {
  width: 5px !important;
  min-height: 36px !important;
  margin-top: 12px !important;
}
/* and make clear button full width but shorter */
#clear-btn {
  width: 10px !important;
  min-height: 36px !important;
  margin-top: 12px !important;
}
#input-card label {
  font-weight: 600 !important;   /* make the text bold */
  color: var(--text) !important; /* use your standard text color */
}
.card {
  background: var(--card);
  border: 1px solid var(--border);
  border-radius: 12px;
  padding: 24px;
  max-width: 1000px;
  margin: 0 auto 32px;
  box-shadow: 0 2px 6px rgba(0,0,0,0.05);
}
#guidelines-card h2 { font-size: 1.4rem; margin-bottom: 16px; text-align: center; }
#guidelines-card ol { margin-left: 20px; line-height: 1.6; font-size: 1rem; }
#input-card .gr-row, #input-card .gr-cols { gap: 16px; }
#input-card .gr-button { flex: 1; }
#output-card { padding-top: 0; }
"""

with gr.Blocks(css=css) as demo:
    # ───────────── Title ─────────────
    gr.Markdown("<h1>FusionDTI</h1>")

    # ───────────── Guidelines ─────────────
    gr.Markdown(
        "Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware "
        "sequence will be generated using Foldseek, based on 3D structures from "
        "AlphaFold DB or the Protein Data Bank (PDB). Alternatively, paste a "
        "structure-aware sequence directly instead of uploading a <code>.cif</code> "
        "or <code>.pdb</code> file."
    )