Zhaohan Meng committed on
Commit
2881d33
·
verified ·
1 Parent(s): b6c28a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +339 -262
app.py CHANGED
@@ -1,56 +1,74 @@
1
- import os, sys, argparse, tempfile, shutil, base64, io
2
- from flask import Flask, request, render_template_string
3
- from werkzeug.utils import secure_filename
4
- from torch.utils.data import DataLoader
5
- import selfies
6
- from rdkit import Chem
 
7
 
8
  import torch
 
 
 
9
  import matplotlib
10
  matplotlib.use("Agg")
11
  import matplotlib.pyplot as plt
12
  from matplotlib import cm
13
  from typing import Optional
14
 
15
- from utils.drug_tokenizer import DrugTokenizer
16
  from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
 
 
 
 
 
17
  from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
18
  from utils.foldseek_util import get_struc_seq
19
 
20
- # ───── global paths / args ──────────────────────────────────────
21
- FOLDSEEK_BIN = shutil.which("foldseek")
22
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
- sys.path.append("..")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def parse_config():
26
  p = argparse.ArgumentParser()
27
- p.add_argument("-f")
28
  p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
29
  p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
30
- p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
31
  p.add_argument("--group_size", type=int, default=1)
32
- p.add_argument("--lr", type=float, default=1e-4)
33
  p.add_argument("--fusion", default="CAN")
34
  p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
35
  p.add_argument("--save_path_prefix", default="save_model_ckp/")
36
- p.add_argument("--dataset", default="BindingDB",
37
- help="Name of the dataset to use (e.g., 'BindingDB', 'Human', 'Biosnap')")
38
-
39
  return p.parse_args()
40
 
41
  args = parse_config()
42
  DEVICE = args.device
43
 
44
- # ───── tokenisers & encoders ────────────────────────────────────
45
  prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
46
  prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
47
-
48
- drug_tokenizer = DrugTokenizer() # SELFIES
49
  drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
 
50
 
51
- encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
52
-
53
- # ─── collate fn ────────────────────────────────────────────────
54
  def collate_fn(batch):
55
  query1, query2, scores = zip(*batch)
56
 
@@ -76,20 +94,8 @@ def collate_fn(batch):
76
  attention_mask2 = query_encodings2["attention_mask"].bool()
77
 
78
  return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
79
- # def collate_fn_batch_encoding(batch):
80
 
81
- def smiles_to_selfies(smiles: str) -> Optional[str]:
82
- try:
83
- mol = Chem.MolFromSmiles(smiles)
84
- if mol is None:
85
- return None
86
- selfies_str = selfies.encoder(smiles)
87
- return selfies_str
88
- except Exception:
89
- return None
90
 
91
-
92
- # ───── single-case embedding ───────────────────────────────────
93
  def get_case_feature(model, loader):
94
  model.eval()
95
  with torch.no_grad():
@@ -101,17 +107,12 @@ def get_case_feature(model, loader):
101
  p_ids.cpu(), d_ids.cpu(),
102
  p_mask.cpu(), d_mask.cpu(), None)]
103
 
104
- # ───── helper:过滤特殊 token ───────────────────────────────────
105
- def clean_tokens(ids, tokenizer):
106
- toks = tokenizer.convert_ids_to_tokens(ids.tolist())
107
- return [t for t in toks if t not in tokenizer.all_special_tokens]
108
-
109
- # ───── visualisation ───────────────────────────────────────────
110
-
111
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
112
  """
113
  Render a Protein → Drug cross-attention heat-map and, optionally, a
114
- Top-20 protein-residue table for a chosen drug-token index.
115
 
116
  The token index shown on the x-axis (and accepted via *drug_idx*) is **the
117
  position of that token in the *original* drug sequence**, *after* the
@@ -210,8 +211,8 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
210
  plt.close(fig)
211
  html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" />'
212
 
213
- # ───────────────────── 生成 Top-20 表(若需要) ─────────────────────
214
- table_html = "" # 先设空串,方便后面统一拼接
215
  if drug_idx is not None:
216
  # map original 0-based drug_idx → current column position
217
  if (drug_idx + 1) in d_indices:
@@ -223,7 +224,7 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
223
 
224
  if col_pos is not None:
225
  col_vec = attn[:, col_pos]
226
- topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()
227
 
228
  rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
229
  res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
@@ -231,24 +232,58 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
231
 
232
  drug_tok_text = d_tokens[col_pos]
233
  orig_idx = d_indices[col_pos]
234
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  table_html = (
236
- f"<h4 style='margin-bottom:6px'>"
237
- f"Drug token #{orig_idx} <code>{drug_tok_text}</code> "
238
- f"→ Top-20 Protein residues</h4>"
239
- "<table class='tg' style='margin-bottom:8px'>"
240
- f"<tr><th>Rank</th>{rank_hdr}</tr>"
241
- f"<tr><td>Residue</td>{res_row}</tr>"
242
- f"<tr><td>Position</td>{pos_row}</tr>"
243
- "</table>")
244
-
245
- # ────────────────── 生成可放大 + 可下载的热图 ────────────────────
246
  buf_png = io.BytesIO()
247
- fig.savefig(buf_png, format="png", dpi=140) # 预览(光栅)
248
  buf_png.seek(0)
249
 
250
  buf_pdf = io.BytesIO()
251
- fig.savefig(buf_pdf, format="pdf") # 高清下载(矢量)
252
  buf_pdf.seek(0)
253
  plt.close(fig)
254
 
@@ -256,212 +291,254 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
256
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
257
 
258
  html_heat = (
259
- f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
260
- f"title='Click to enlarge'>"
261
- f"<img src='data:image/png;base64,{png_b64}' "
262
- f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
263
- f"<div style='margin-top:6px'>"
264
- f"<a href='data:application/pdf;base64,{pdf_b64}' "
265
- f"download='attention_heatmap.pdf'>Download PDF</a></div>"
 
 
 
 
 
 
 
 
 
266
  )
267
 
268
- # ───────────────────────── 返回最终 HTML ─────────────────────────
269
  return table_html + html_heat
270
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
 
272
- # ───── Flask app ───────────────────────────────────────────────
273
- app = Flask(__name__)
274
-
275
- @app.route("/", methods=["GET", "POST"])
276
- def index():
277
- protein_seq = drug_seq = structure_seq = ""; result_html = None
278
- tmp_structure_path = ""; drug_idx = None
279
-
280
- if request.method == "POST":
281
- drug_idx_raw = request.form.get("drug_idx", "")
282
- drug_idx = int(drug_idx_raw)-1 if drug_idx_raw.isdigit() else None
283
-
284
- struct = request.files.get("structure_file")
285
- if struct and struct.filename:
286
- path = os.path.join(tempfile.gettempdir(), secure_filename(struct.filename))
287
- struct.save(path); tmp_structure_path = path
288
- else:
289
- tmp_structure_path = request.form.get("tmp_structure_path", "")
290
-
291
- if "clear" in request.form:
292
- protein_seq = drug_seq = structure_seq = ""; tmp_structure_path = ""
293
-
294
- elif "confirm_structure" in request.form and tmp_structure_path:
295
- try:
296
- parsed_seqs = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, ["A"], plddt_mask=False)["A"]
297
- seq, foldseek_seq, structure_seq = parsed_seqs
298
- except Exception as e:
299
- result_html = (
300
- "<p style='color:red'><strong>Foldseek failed to extract sequence "
301
- f"from structure: {e}</strong></p>")
302
- structure_seq = ""
303
-
304
- protein_seq = structure_seq
305
- drug_input = request.form.get("drug_sequence", "")
306
- # Heuristically check if input is SMILES (not starting with [) and convert
307
- if not drug_input.strip().startswith("["):
308
- converted = smiles_to_selfies(drug_input.strip())
309
- if converted:
310
- drug_seq = converted
311
- else:
312
- drug_seq = ""
313
- result_html = "<p style='color:red'><strong>Failed to convert SMILES to SELFIES. Please check the input string.</strong></p>"
314
- else:
315
- drug_seq = drug_input
316
-
317
- elif "inference" in request.form:
318
- protein_seq = request.form.get("protein_sequence", "")
319
- drug_seq = request.form.get("drug_sequence", "")
320
- if protein_seq and drug_seq:
321
- loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
322
- collate_fn=collate_fn)
323
- feats = get_case_feature(encoding, loader)
324
- model = FusionDTI(446, 768, args).to(DEVICE)
325
- ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
326
- "best_model.ckpt")
327
- if os.path.isfile(ckpt):
328
- model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
329
- result_html = visualize_attention(model, feats, drug_idx)
330
-
331
- return render_template_string(
332
- # ───────────── HTML (原 UI + 新输入框) ─────────────
333
- """
334
- <!doctype html>
335
- <html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
336
- <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">
337
-
338
- <style>
339
- :root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
340
- *{box-sizing:border-box;margin:0;padding:0}
341
- body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
342
- h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
343
- .card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
344
- border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
345
- label{font-weight:500;margin-bottom:6px;display:block}
346
- textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
347
- border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
348
- textarea{min-height:90px}
349
- .btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
350
- font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
351
- .btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
352
- .btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
353
- .grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
354
- .vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
355
- pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}
356
-
357
- /* ── tidy table for Top-20 list ─────────────────────────────── */
358
- table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
359
- table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
360
- table.tg th{background:var(--bg);font-weight:600}
361
- </style>
362
- </head>
363
- <body>
364
- <h1> Token-level Visualiser for Drug-Target Interaction</h1>
365
-
366
- <!-- ───────────── Project Links (larger + spaced) ───────────── -->
367
- <div style="margin-top:24px; text-align:center;">
368
- <a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
369
- style="display:inline-block;margin:8px 18px;padding:10px 20px;
370
- background:linear-gradient(to right,#10b981,#059669);color:white;
371
- font-weight:600;border-radius:8px;font-size:0.9rem;
372
- font-family:Inter,sans-serif;text-decoration:none;
373
- box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
374
- onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
375
- 🌐 Project Page
376
- </a>
377
-
378
- <a href="https://arxiv.org/abs/2406.01651" target="_blank"
379
- style="display:inline-block;margin:8px 18px;padding:10px 20px;
380
- background:linear-gradient(to right,#ef4444,#dc2626);color:white;
381
- font-weight:600;border-radius:8px;font-size:0.9rem;
382
- font-family:Inter,sans-serif;text-decoration:none;
383
- box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
384
- onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
385
- 📄 ArXiv: 2406.01651
386
- </a>
387
-
388
- <a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
389
- style="display:inline-block;margin:8px 18px;padding:10px 20px;
390
- background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
391
- font-weight:600;border-radius:8px;font-size:0.9rem;
392
- font-family:Inter,sans-serif;text-decoration:none;
393
- box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
394
- onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
395
- 💻 GitHub Repo
396
- </a>
397
- </div>
398
-
399
- <!-- ───────────── Guidelines for Use ───────────── -->
400
- <div class="card" style="margin-bottom:24px">
401
- <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
402
- <ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
403
- <li><strong>Convert protein structure into a structure-aware sequence:</strong>
404
- Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
405
- sequence will be generated using
406
- <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
407
- based on 3D structures from
408
- <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
409
- <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
410
-
411
- <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
412
- you must first visit the
413
- <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
414
- or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
415
- to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
416
-
417
- <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
418
- You can enter a SELFIES string directly, or paste a SMILES string.
419
- SMILES will be automatically converted to SELFIES using
420
- <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
421
- If conversion fails, a red error message will be displayed.</li>
422
-
423
- <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
424
- to highlight the Top-10 interacting protein residues.</li>
425
-
426
- <li>After inference, you can use the
427
- “Download PDF” link to export a high-resolution vector version.</li>
428
- </ul>
429
- </div>
430
-
431
- <div class="card">
432
- <form method="POST" enctype="multipart/form-data" class="grid">
433
-
434
- <div><label>Protein Structure (.pdb / .cif)</label>
435
- <input type="file" name="structure_file">
436
- <input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>
437
-
438
- <div><label>Protein Sequence</label>
439
- <textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>
440
-
441
- <div><label>Drug Sequence (SELFIES/SMILES)</label>
442
- <textarea name="drug_sequence" placeholder="[C][C][O]/cco …">{{ drug_seq }}</textarea></div>
443
-
444
- <label>Drug atom/substructure index (1-based) – show Top-10 related protein residue</label>
445
- <input type="number" name="drug_idx" min="1" style="width:120px">
446
-
447
- <div class="grid grid-2">
448
- <button class="btn btn-primary" type="submit" name="confirm_structure">Confirm</button>
449
- <button class="btn btn-primary" type="submit" name="inference">Inference</button>
450
- </div>
451
- <button class="btn btn-neutral" style="width:100%" type="submit" name="clear">Clear</button>
452
- </form>
453
-
454
- {% if structure_seq %}
455
- <div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
456
- {% endif %}
457
- {% if result_html %}
458
- <div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
459
- {% endif %}
460
- </div></body></html>
461
- """,
462
- protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
463
- result_html=result_html, tmp_structure_path=tmp_structure_path)
464
-
465
- # ───── run ─────────────────────────────────────────────────────
466
  if __name__ == "__main__":
467
- app.run(debug=True, host="0.0.0.0", port=7860)
 
1
+ import os
2
+ import sys
3
+ import argparse
4
+ import tempfile
5
+ import shutil
6
+ import base64
7
+ import io
8
 
9
  import torch
10
+ import selfies
11
+ from rdkit import Chem
12
+ import gradio as gr
13
  import matplotlib
14
  matplotlib.use("Agg")
15
  import matplotlib.pyplot as plt
16
  from matplotlib import cm
17
  from typing import Optional
18
 
 
19
  from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
20
+ from torch.utils.data import DataLoader
21
+ from Bio.PDB import PDBParser, MMCIFParser
22
+ from Bio.Data import IUPACData
23
+
24
+ from utils.drug_tokenizer import DrugTokenizer
25
  from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
26
  from utils.foldseek_util import get_struc_seq
27
 
28
+ # ───── Helpers ─────────────────────────────────────────────────
29
+
30
+ three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
31
+ three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
32
def simple_seq_from_structure(path: str) -> str:
    """Return the one-letter sequence of the longest chain in a structure file.

    Args:
        path: Path to a structure file. Files whose name ends in ``.cif``
            (case-insensitive) are read with ``MMCIFParser``; everything else
            is read with ``PDBParser``.

    Returns:
        The amino-acid sequence of the chain that yields the longest sequence,
        or ``""`` when the structure contains no chains. Residues without a
        known one-letter code are written as ``"X"``.
    """
    is_cif = path.lower().endswith(".cif")
    parser = MMCIFParser(QUIET=True) if is_cif else PDBParser(QUIET=True)
    structure = parser.get_structure("P", path)

    def chain_sequence(chain) -> str:
        letters = []
        for res in chain:
            name = res.get_resname().upper()
            # res.id[0] is Biopython's hetero flag: " " for standard residues,
            # "W" for water, "H_xxx" for other hetero groups. Waters and
            # ligands are not part of the sequence; modified residues that
            # three2one knows about (e.g. MSE) are kept.
            if res.id[0] != " " and name not in three2one:
                continue
            letters.append(three2one.get(name, "X"))
        return "".join(letters)

    sequences = [chain_sequence(chain) for chain in structure.get_chains()]
    # max() keeps the first chain on ties; default covers a chain-less file.
    return max(sequences, key=len, default="")
40
+
41
def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES.

    Args:
        smiles: The SMILES string to convert.

    Returns:
        The SELFIES encoding, or ``None`` when RDKit cannot parse the input
        or the conversion raises an error.
    """
    try:
        # RDKit is used purely as a validity check before encoding.
        if Chem.MolFromSmiles(smiles) is None:
            return None
        return selfies.encoder(smiles)
    except Exception:
        # Best-effort conversion: report failure as None, but let
        # KeyboardInterrupt / SystemExit propagate (a bare except would not).
        return None
49
 
50
def parse_config():
    """Parse the demo's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding encoder paths, aggregation mode,
        group size, fusion type, device, checkpoint prefix and dataset name.
    """
    parser = argparse.ArgumentParser()
    # Prefer the GPU when one is visible to torch.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    option_specs = (
        ("--prot_encoder_path", {"default": "westlake-repl/SaProt_650M_AF2"}),
        ("--drug_encoder_path", {"default": "HUBioDataLab/SELFormer"}),
        ("--agg_mode", {"type": str, "default": "mean_all_tok"}),
        ("--group_size", {"type": int, "default": 1}),
        ("--fusion", {"default": "CAN"}),
        ("--device", {"default": default_device}),
        ("--save_path_prefix", {"default": "save_model_ckp/"}),
        ("--dataset", {"default": "Human"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
61
 
62
  args = parse_config()
63
  DEVICE = args.device
64
 
65
+ # ───── Load models & tokenizers ─────────────────────────────────
66
  prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
67
  prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
68
+ drug_tokenizer = DrugTokenizer()
 
69
  drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
70
+ encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
71
 
 
 
 
72
  def collate_fn(batch):
73
  query1, query2, scores = zip(*batch)
74
 
 
94
  attention_mask2 = query_encodings2["attention_mask"].bool()
95
 
96
  return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
 
97
 
 
 
 
 
 
 
 
 
 
98
 
 
 
99
  def get_case_feature(model, loader):
100
  model.eval()
101
  with torch.no_grad():
 
107
  p_ids.cpu(), d_ids.cpu(),
108
  p_mask.cpu(), d_mask.cpu(), None)]
109
 
110
+
111
+ # ─────────────── visualisation ───────────────────────────────────────────
 
 
 
 
 
112
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
113
  """
114
  Render a Protein → Drug cross-attention heat-map and, optionally, a
115
+ Top-30 protein-residue table for a chosen drug-token index.
116
 
117
  The token index shown on the x-axis (and accepted via *drug_idx*) is **the
118
  position of that token in the *original* drug sequence**, *after* the
 
211
  plt.close(fig)
212
  html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" />'
213
 
214
+ # ───────────────────── Top-30 table ─────────────────────
215
+ table_html = ""
216
  if drug_idx is not None:
217
  # map original 0-based drug_idx → current column position
218
  if (drug_idx + 1) in d_indices:
 
224
 
225
  if col_pos is not None:
226
  col_vec = attn[:, col_pos]
227
+ topk = torch.topk(col_vec, k=min(30, len(col_vec))).indices.tolist()
228
 
229
  rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
230
  res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
 
232
 
233
  drug_tok_text = d_tokens[col_pos]
234
  orig_idx = d_indices[col_pos]
235
+
236
+ # 1) build the header row: leading “Rank”, then 1…30
237
+ header_cells = (
238
+ "<th style='border:1px solid #ccc; padding:6px; "
239
+ "background:#f7f7f7; text-align:center;'>Rank</th>"
240
+ + "".join(
241
+ f"<th style='border:1px solid #ccc; padding:6px; "
242
+ f"background:#f7f7f7; text-align:center'>{r+1}</th>"
243
+ for r in range(len(topk))
244
+ )
245
+ )
246
+
247
+ # 2) build the residue row: leading “Residue”, then the residue tokens
248
+ residue_cells = (
249
+ "<th style='border:1px solid #ccc; padding:6px; "
250
+ "background:#f7f7f7; text-align:center;'>Residue</th>"
251
+ + "".join(
252
+ f"<td style='border:1px solid #ccc; padding:6px; "
253
+ f"text-align:center'>{p_tokens[i]}</td>"
254
+ for i in topk
255
+ )
256
+ )
257
+
258
+ # 3) build the position row: leading “Position”, then the residue positions
259
+ position_cells = (
260
+ "<th style='border:1px solid #ccc; padding:6px; "
261
+ "background:#f7f7f7; text-align:center;'>Position</th>"
262
+ + "".join(
263
+ f"<td style='border:1px solid #ccc; padding:6px; "
264
+ f"text-align:center'>{p_indices[i]}</td>"
265
+ for i in topk
266
+ )
267
+ )
268
+
269
+ # 4) assemble your table_html
270
  table_html = (
271
+ f"<h4 style='margin-bottom:12px'>"
272
+ f"Drug atom #{orig_idx} <code>{drug_tok_text}</code> → Top-30 Protein residues"
273
+ f"</h4>"
274
+ f"<table style='border-collapse:collapse; margin:0 auto 24px;'>"
275
+ f"<tr>{header_cells}</tr>"
276
+ f"<tr>{residue_cells}</tr>"
277
+ f"<tr>{position_cells}</tr>"
278
+ f"</table>"
279
+ )
280
+
281
  buf_png = io.BytesIO()
282
+ fig.savefig(buf_png, format="png", dpi=140)
283
  buf_png.seek(0)
284
 
285
  buf_pdf = io.BytesIO()
286
+ fig.savefig(buf_pdf, format="pdf")
287
  buf_pdf.seek(0)
288
  plt.close(fig)
289
 
 
291
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
292
 
293
  html_heat = (
294
+ f"<div style='position: relative; width: 100%;'>"
295
+ # the PDF button, absolutely positioned
296
+ f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
297
+ "style='position: absolute; top: 12px; right: 12px; "
298
+ "background: var(--primary); color: #fff; "
299
+ "padding: 8px 16px; border-radius: 6px; "
300
+ "font-size: 0.9rem; font-weight: 500; "
301
+ "text-decoration: none;'>"
302
+ "Download PDF"
303
+ "</a>"
304
+ # the clickable heat‐map image
305
+ f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
306
+ f"<img src='data:image/png;base64,{png_b64}' "
307
+ "style='display: block; width: 100%; height: auto; cursor: zoom-in;'/>"
308
+ "</a>"
309
+ "</div>"
310
  )
311
 
 
312
  return table_html + html_heat
313
 
314
+ # ───── Gradio Callbacks ─────────────────────────────────────────
315
+
316
+ ROOT = os.path.dirname(os.path.abspath(__file__))
317
+ FOLDSEEK_BIN = os.path.join(ROOT, "bin", "foldseek")
318
+
319
def extract_sequence_cb(structure_file):
    """Gradio callback: turn an uploaded structure file into a sequence.

    Args:
        structure_file: The value of the file-upload component. Accepted as
            either a plain path string or an object exposing the path as
            ``.name``; ``None`` when nothing was uploaded.

    Returns:
        The third element of the first chain's entry returned by
        ``get_struc_seq`` (presumably the structure-aware sequence — verify
        against ``utils.foldseek_util``), or ``""`` when there is no readable
        upload or no chain was parsed.
    """
    if structure_file is None:
        return ""

    # Depending on Gradio version/configuration the upload arrives as a path
    # string or as a temp-file wrapper; support both instead of assuming .name.
    if isinstance(structure_file, str):
        path = structure_file
    else:
        path = getattr(structure_file, "name", "")
    if not path or not os.path.exists(path):
        return ""

    parsed = get_struc_seq(FOLDSEEK_BIN, path, None, plddt_mask=False)
    if not parsed:
        # Avoid StopIteration from next(iter(...)) on an empty result.
        return ""
    first_chain = next(iter(parsed))
    _, _, struct_seq = parsed[first_chain]
    return struct_seq
326
+
327
def inference_cb(prot_seq, drug_seq, atom_idx):
    """Gradio callback: run the model on one protein/drug pair and render HTML.

    Args:
        prot_seq: Protein sequence text from the UI (may be ``None`` or blank).
        drug_seq: Drug string from the UI. Input that does not start with
            ``"["`` is treated as SMILES and converted to SELFIES first.
        atom_idx: Optional 1-based drug token index; falsy values (``None``,
            ``0``) mean "no token selected".

    Returns:
        An HTML string: either a red error paragraph for missing/invalid
        input, or the output of ``visualize_attention``.
    """
    # Normalise once so None, blank and whitespace-padded input behave the
    # same for both the SELFIES and the SMILES path.
    prot_seq = (prot_seq or "").strip()
    drug_seq = (drug_seq or "").strip()

    if not prot_seq:
        return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
    if not drug_seq:
        return "<p style='color:red'>Please enter a drug sequence.</p>"

    # Heuristic: SELFIES strings start with "[", anything else is SMILES.
    if not drug_seq.startswith("["):
        conv = smiles_to_selfies(drug_seq)
        if conv is None:
            return "<p style='color:red'>SMILES→SELFIES conversion failed.</p>"
        drug_seq = conv

    loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
    feats = get_case_feature(encoding, loader)

    model = FusionDTI(446, 768, args).to(DEVICE)
    ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt")
    # NOTE(review): when the checkpoint is missing the model keeps its freshly
    # initialised weights and the heat-map is still rendered — confirm intended.
    if os.path.isfile(ckpt):
        model.load_state_dict(torch.load(ckpt, map_location=DEVICE))

    # The UI index is 1-based; visualize_attention expects 0-based.
    drug_idx = int(atom_idx) - 1 if atom_idx else None
    return visualize_attention(model, feats, drug_idx)
344
+
345
def clear_cb():
    """Return five blank values for resetting UI fields.

    Returns:
        A 5-tuple ``(None, "", "", None, "")``. Which component each slot
        maps to depends on the ``outputs`` list of whichever event uses it.
    """
    blank_text = ""
    return (None, blank_text, blank_text, None, blank_text)
347
+
348
+ # ───── Gradio Interface Definition ───────────────────────────────
349
+
350
+ css = """
351
+ :root {
352
+ --bg: #f3f4f6;
353
+ --card: #ffffff;
354
+ --border: #e5e7eb;
355
+ --primary: #6366f1;
356
+ --primary-dark: #4f46e5;
357
+ --text: #111827;
358
+ }
359
+ * { box-sizing: border-box; margin: 0; padding: 0; }
360
+ body { background: var(--bg); color: var(--text); font-family: Inter,system-ui,Arial,sans-serif; }
361
+ h1 { font-family: Poppins,Inter,sans-serif; font-weight: 600; font-size: 2rem; text-align: center; margin: 24px 0; }
362
+ button, .gr-button { font-family: Inter,sans-serif; font-weight: 600; }
363
+ #project-links { text-align: center; margin-bottom: 32px; }
364
+ #project-links .gr-button { margin: 0 8px; min-width: 160px; }
365
+ #project-links .gr-button:nth-child(1) { background: #10b981; }
366
+ #project-links .gr-button:nth-child(2) { background: #ef4444; }
367
+ #project-links .gr-button:nth-child(3) { background: #3b82f6; }
368
+ #project-links .gr-button:hover { opacity: 0.9; }
369
+ .link-btn{display:inline-block;margin:0 8px;padding:10px 20px;border-radius:8px;
370
+ color:white;font-weight:600;text-decoration:none;box-shadow:0 2px 6px rgba(0,0,0,0.12);
371
+ transition:all .2s ease-in-out;}
372
+ .link-btn:hover{opacity:.9;}
373
+ .link-btn.project{background:linear-gradient(to right,#10b981,#059669);}
374
+ .link-btn.arxiv {background:linear-gradient(to right,#ef4444,#dc2626);}
375
+ .link-btn.github {background:linear-gradient(to right,#3b82f6,#2563eb);}
376
+
377
+ /* make *all* gradio buttons a bit taller */
378
+ .gr-button { min-height: 10px !important; }
379
+
380
+ /* now target just our two big action buttons */
381
+ #extract-btn, #inference-btn {
382
+ width: 5px !important;
383
+ min-height: 36px !important;
384
+ margin-top: 12px !important;
385
+ }
386
+
387
+ /* and make clear button full width but shorter */
388
+ #clear-btn {
389
+ width: 10px !important;
390
+ min-height: 36px !important;
391
+ margin-top: 12px !important;
392
+ }
393
+
394
+ #input-card label {
395
+ font-weight: 600 !important; /* make the text bold */
396
+ color: var(--text) !important; /* use your standard text color */
397
+ }
398
+
399
+ .card {
400
+ background: var(--card);
401
+ border: 1px solid var(--border);
402
+ border-radius: 12px;
403
+ padding: 24px;
404
+ max-width: 1000px;
405
+ margin: 0 auto 32px;
406
+ box-shadow: 0 2px 6px rgba(0,0,0,0.05);
407
+ }
408
+
409
+ #guidelines-card h2 {
410
+ font-size: 1.4rem;
411
+ margin-bottom: 16px;
412
+ text-align: center;
413
+ }
414
+ #guidelines-card ol {
415
+ margin-left: 20px;
416
+ line-height: 1.6;
417
+ font-size: 1rem;
418
+ }
419
+ #input-card .gr-row, #input-card .gr-cols {
420
+ gap: 16px;
421
+ }
422
+ #input-card .gr-button {
423
+ flex: 1;
424
+ }
425
+ #output-card {
426
+ padding-top: 0;
427
+ }
428
+ """
429
+
430
# Gradio UI definition: title, project links, usage guidelines, the input
# card, the result area and the button wiring.
# NOTE(review): the nesting of the button rows inside the input card was
# reconstructed from a flattened diff view — confirm against app.py.
with gr.Blocks(css=css) as demo:
    # ───────────── Title ─────────────
    gr.Markdown("<h1>Token-level Visualiser for Drug-Target Interaction</h1>")

    # ───────────── Project Links ─────────────
    gr.Markdown("""
    <div style="text-align:center;margin-bottom:32px;">
    <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank">🌐 Project Page</a>
    <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank">📄 ArXiv: 2406.01651</a>
    <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank">💻 GitHub Repo</a>
    </div>
    """)

    # ───────────── Guidelines Card ─────────────
    gr.HTML(
        """
        <div class="card" style="margin-bottom:24px">
        <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for User</h2>
        <ul style="font-size:1rem; margin-left:18px;line-height:1.55;list-style:decimal;">
        <li><strong>Convert protein structure into a structure-aware sequence:</strong>
        Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
        sequence will be generated using
        <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
        based on 3D structures from
        <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>

        <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
        you must first visit the
        <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
        or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
        to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>

        <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
        You can enter a SELFIES string directly, or paste a SMILES string.
        SMILES will be automatically converted to SELFIES using
        <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
        If conversion fails, a red error message will be displayed.</li>

        <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
        to highlight the Top-30 interacting protein residues.</li>

        <li>After inference, you can use the
        “Download PDF” link to export a high-resolution vector version.</li>
        </ul>
        </div>
        """)

    # ───────────── Input Card ─────────────
    with gr.Column(elem_id="input-card", elem_classes="card"):

        protein_seq = gr.Textbox(
            label="Protein Structure-aware Sequence",
            lines=3,
            elem_id="protein-seq"
        )
        drug_seq = gr.Textbox(
            label="Drug Sequence (SELFIES/SMILES)",
            lines=3,
            elem_id="drug-seq"
        )
        structure_file = gr.File(
            label="Upload Protein Structure (.pdb/.cif)",
            file_types=[".pdb", ".cif"],
            elem_id="structure-file"
        )
        drug_idx = gr.Number(
            label="Drug atom/substructure index (1-based)",
            value=None,
            precision=0,
            elem_id="drug-idx"
        )

        # ───────────── Action Buttons ─────────────
        with gr.Row(elem_id="action-buttons", equal_height=True):
            btn_extract = gr.Button(
                "Extract sequence",
                variant="primary",
                elem_id="extract-btn"
            )
            btn_infer = gr.Button(
                "Inference",
                variant="primary",
                elem_id="inference-btn"
            )
        with gr.Row():
            clear_btn = gr.Button(
                "Clear",
                variant="secondary",
                elem_classes="full-width",
                elem_id="clear-btn"
            )

    # ───────────── Output Visualization ─────────────
    output_html = gr.HTML(elem_id="result-html")

    # ───────────── Event Wiring ─────────────
    btn_extract.click(
        fn=extract_sequence_cb,
        inputs=[structure_file],
        outputs=[protein_seq]
    )
    btn_infer.click(
        fn=inference_cb,
        inputs=[protein_seq, drug_seq, drug_idx],
        outputs=[output_html]
    )
    # One return value per output component, in the same order as `outputs`.
    # The previous handler returned six values for these five components,
    # which makes the Clear click fail.
    clear_btn.click(
        fn=lambda: ("", "", None, None, None),
        inputs=[],
        outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
    )
542
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  if __name__ == "__main__":
544
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)