Gla-AI4BioMed-Lab commited on
Commit
b7ce511
·
verified ·
1 Parent(s): 44cf989

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +363 -277
app.py CHANGED
@@ -1,36 +1,82 @@
1
- import os, sys, argparse, tempfile, shutil, base64, io
2
- from flask import Flask, request, render_template_string
3
- from werkzeug.utils import secure_filename
4
- from torch.utils.data import DataLoader
5
- import selfies
6
- from rdkit import Chem
7
- import app as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  import torch
 
 
10
  import matplotlib
11
  matplotlib.use("Agg")
12
  import matplotlib.pyplot as plt
13
  from matplotlib import cm
14
  from typing import Optional
15
 
16
- from utils.drug_tokenizer import DrugTokenizer
17
  from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
 
 
 
 
 
18
  from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
19
  from utils.foldseek_util import get_struc_seq
20
 
21
- # ───── global paths / args ──────────────────────────────────────
22
- FOLDSEEK_BIN = shutil.which("foldseek")
23
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
24
- sys.path.append("..")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def parse_config():
27
  p = argparse.ArgumentParser()
28
- p.add_argument("-f")
29
  p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
30
  p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
31
- p.add_argument("--agg_mode", default="mean_all_tok", type=str, help="{cls|mean|mean_all_tok}")
32
  p.add_argument("--group_size", type=int, default=1)
33
- p.add_argument("--lr", type=float, default=1e-4)
34
  p.add_argument("--fusion", default="CAN")
35
  p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
36
  p.add_argument("--save_path_prefix", default="save_model_ckp/")
@@ -40,16 +86,13 @@ def parse_config():
40
  args = parse_config()
41
  DEVICE = args.device
42
 
43
- # ───── tokenisers & encoders ────────────────────────────────────
44
  prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
45
  prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
46
-
47
- drug_tokenizer = DrugTokenizer() # SELFIES
48
  drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
 
49
 
50
- encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
51
-
52
- # ─── collate fn ────────────────────────────────────────────────
53
  def collate_fn(batch):
54
  query1, query2, scores = zip(*batch)
55
 
@@ -75,20 +118,8 @@ def collate_fn(batch):
75
  attention_mask2 = query_encodings2["attention_mask"].bool()
76
 
77
  return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
78
- # def collate_fn_batch_encoding(batch):
79
-
80
- def smiles_to_selfies(smiles: str) -> Optional[str]:
81
- try:
82
- mol = Chem.MolFromSmiles(smiles)
83
- if mol is None:
84
- return None
85
- selfies_str = selfies.encoder(smiles)
86
- return selfies_str
87
- except Exception:
88
- return None
89
 
90
 
91
- # ───── single-case embedding ───────────────────────────────────
92
  def get_case_feature(model, loader):
93
  model.eval()
94
  with torch.no_grad():
@@ -100,17 +131,12 @@ def get_case_feature(model, loader):
100
  p_ids.cpu(), d_ids.cpu(),
101
  p_mask.cpu(), d_mask.cpu(), None)]
102
 
103
- # ───── helper:过滤特殊 token ───────────────────────────────────
104
- def clean_tokens(ids, tokenizer):
105
- toks = tokenizer.convert_ids_to_tokens(ids.tolist())
106
- return [t for t in toks if t not in tokenizer.all_special_tokens]
107
-
108
- # ───── visualisation ───────────────────────────────────────────
109
-
110
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
111
  """
112
  Render a Protein → Drug cross-attention heat-map and, optionally, a
113
- Top-20 protein-residue table for a chosen drug-token index.
114
 
115
  The token index shown on the x-axis (and accepted via *drug_idx*) is **the
116
  position of that token in the *original* drug sequence**, *after* the
@@ -209,8 +235,8 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
209
  plt.close(fig)
210
  html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" />'
211
 
212
- # ───────────────────── 生成 Top-20 表(若需要) ─────────────────────
213
- table_html = "" # 先设空串,方便后面统一拼接
214
  if drug_idx is not None:
215
  # map original 0-based drug_idx → current column position
216
  if (drug_idx + 1) in d_indices:
@@ -222,7 +248,7 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
222
 
223
  if col_pos is not None:
224
  col_vec = attn[:, col_pos]
225
- topk = torch.topk(col_vec, k=min(20, len(col_vec))).indices.tolist()
226
 
227
  rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
228
  res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
@@ -230,24 +256,58 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
230
 
231
  drug_tok_text = d_tokens[col_pos]
232
  orig_idx = d_indices[col_pos]
233
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  table_html = (
235
- f"<h4 style='margin-bottom:6px'>"
236
- f"Drug token #{orig_idx} <code>{drug_tok_text}</code> "
237
- f"→ Top-20 Protein residues</h4>"
238
- "<table class='tg' style='margin-bottom:8px'>"
239
- f"<tr><th>Rank</th>{rank_hdr}</tr>"
240
- f"<tr><td>Residue</td>{res_row}</tr>"
241
- f"<tr><td>Position</td>{pos_row}</tr>"
242
- "</table>")
243
-
244
- # ────────────────── 生成可放大 + 可下载的热图 ────────────────────
245
  buf_png = io.BytesIO()
246
- fig.savefig(buf_png, format="png", dpi=140) # 预览(光栅)
247
  buf_png.seek(0)
248
 
249
  buf_pdf = io.BytesIO()
250
- fig.savefig(buf_pdf, format="pdf") # 高清下载(矢量)
251
  buf_pdf.seek(0)
252
  plt.close(fig)
253
 
@@ -255,228 +315,254 @@ def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
255
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
256
 
257
  html_heat = (
258
- f"<a href='data:image/png;base64,{png_b64}' target='_blank' "
259
- f"title='Click to enlarge'>"
260
- f"<img src='data:image/png;base64,{png_b64}' "
261
- f"style='max-width:100%;height:auto;cursor:zoom-in' /></a>"
262
- f"<div style='margin-top:6px'>"
263
- f"<a href='data:application/pdf;base64,{pdf_b64}' "
264
- f"download='attention_heatmap.pdf'>Download PDF</a></div>"
 
 
 
 
 
 
 
 
 
265
  )
266
 
267
- # ───────────────────────── 返回最终 HTML ─────────────────────────
268
  return table_html + html_heat
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- def inference(protein_seq, drug_seq, drug_idx, structure_file):
271
- # —— 这一块换成 Gradio 取文件路径 ——
272
- if structure_file is not None and os.path.exists(structure_file.name):
273
- tmp_structure_path = structure_file.name
274
- else:
275
- return "<p style='color:red'>请先上传一个有效的 .pdb .cif 文件。</p>"
276
-
277
- # 调用 foldseek
278
- try:
279
- parsed = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, ["A"], plddt_mask=False)
280
- chain = next(iter(parsed))
281
- protein_seq = parsed[chain][2]
282
- except Exception as e:
283
- return f"<p style='color:red'>Foldseek 提取失败:{e}</p>"
284
-
285
- # ───── Flask app ───────────────────────────────────────────────
286
- app = Flask(__name__)
287
-
288
- @app.route("/", methods=["GET", "POST"])
289
- def index():
290
- protein_seq = drug_seq = structure_seq = ""; result_html = None
291
- tmp_structure_path = ""; drug_idx = None
292
-
293
- if request.method == "POST":
294
- drug_idx_raw = request.form.get("drug_idx", "")
295
- drug_idx = int(drug_idx_raw)-1 if drug_idx_raw.isdigit() else None
296
-
297
- struct = request.files.get("structure_file")
298
- if struct and struct.filename:
299
- tmp_dir = tempfile.mkdtemp(prefix="foldseek_")
300
- safe_name = secure_filename(struct.filename)
301
- tmp_structure_path = os.path.join(tmp_dir, safe_name)
302
- struct.save(tmp_structure_path)
303
- else:
304
- tmp_structure_path = request.form.get("tmp_structure_path", "")
305
-
306
- if "clear" in request.form:
307
- protein_seq = drug_seq = structure_seq = ""; tmp_structure_path = ""
308
-
309
- elif "confirm_structure" in request.form and tmp_structure_path:
310
- try:
311
- parsed_seqs = get_struc_seq(FOLDSEEK_BIN, tmp_structure_path, ["A"], plddt_mask=False)["A"]
312
- seq, foldseek_seq, structure_seq = parsed_seqs # 用完后清除目录
313
- except Exception as e:
314
- result_html = (
315
- "<p style='color:red'><strong>Foldseek failed to extract sequence "
316
- f"from structure: {e}</strong></p>")
317
- structure_seq = ""
318
-
319
- protein_seq = structure_seq
320
- drug_input = request.form.get("drug_sequence", "")
321
- # Heuristically check if input is SMILES (not starting with [) and convert
322
- if not drug_input.strip().startswith("["):
323
- converted = smiles_to_selfies(drug_input.strip())
324
- if converted:
325
- drug_seq = converted
326
- else:
327
- drug_seq = ""
328
- result_html = "<p style='color:red'><strong>Failed to convert SMILES to SELFIES. Please check the input string.</strong></p>"
329
- else:
330
- drug_seq = drug_input
331
-
332
- elif "Inference" in request.form:
333
- protein_seq = request.form.get("protein_sequence", "")
334
- drug_seq = request.form.get("drug_sequence", "")
335
- if protein_seq and drug_seq:
336
- loader = DataLoader([(protein_seq, drug_seq, 1)], batch_size=1,
337
- collate_fn=collate_fn)
338
- feats = get_case_feature(encoding, loader)
339
- model = FusionDTI(446, 768, args).to(DEVICE)
340
- ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}",
341
- "best_model.ckpt")
342
- if os.path.isfile(ckpt):
343
- model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
344
- result_html = visualize_attention(model, feats, drug_idx)
345
-
346
- return render_template_string(
347
- # ───────────── HTML (原 UI + 新输入框) ─────────────
348
- """
349
- <!doctype html>
350
- <html lang="en"><head><meta charset="utf-8"><title>FusionDTI </title>
351
- <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&family=Poppins:wght@500;600&display=swap" rel="stylesheet">
352
-
353
- <style>
354
- :root{--bg:#f3f4f6;--card:#fff;--primary:#6366f1;--primary-dark:#4f46e5;--text:#111827;--border:#e5e7eb;}
355
- *{box-sizing:border-box;margin:0;padding:0}
356
- body{background:var(--bg);color:var(--text);font-family:Inter,system-ui,Arial,sans-serif;line-height:1.5;padding:32px 12px;}
357
- h1{font-family:Poppins,Inter,sans-serif;font-weight:600;font-size:1.7rem;text-align:center;margin-bottom:28px;letter-spacing:-.2px;}
358
- .card{max-width:1000px;margin:0 auto;background:var(--card);border:1px solid var(--border);
359
- border-radius:12px;box-shadow:0 2px 6px rgba(0,0,0,.05);padding:32px 36px;}
360
- label{font-weight:500;margin-bottom:6px;display:block}
361
- textarea,input[type=file]{width:100%;font-size:.9rem;font-family:monospace;padding:10px 12px;
362
- border:1px solid var(--border);border-radius:8px;background:#fff;resize:vertical;}
363
- textarea{min-height:90px}
364
- .btn{appearance:none;border:none;cursor:pointer;padding:12px 22px;border-radius:8px;font-weight:500;
365
- font-family:Inter,sans-serif;transition:all .18s ease;color:#fff;}
366
- .btn-primary{background:var(--primary)}.btn-primary:hover{background:var(--primary-dark)}
367
- .btn-neutral{background:#9ca3af;}.btn-neutral:hover{background:#6b7280}
368
- .grid{display:grid;gap:22px}.grid-2{grid-template-columns:1fr 1fr}
369
- .vis-box{margin-top:28px;border:1px solid var(--border);border-radius:10px;overflow:auto;max-height:72vh;}
370
- pre{white-space:pre-wrap;word-break:break-all;font-family:monospace;margin-top:8px}
371
-
372
- /* ── tidy table for Top-20 list ─────────────────────────────── */
373
- table.tg{border-collapse:collapse;margin-top:4px;font-size:0.83rem}
374
- table.tg th,table.tg td{border:1px solid var(--border);padding:6px 8px;text-align:left}
375
- table.tg th{background:var(--bg);font-weight:600}
376
- </style>
377
- </head>
378
- <body>
379
- <h1> Token-level Visualiser for Drug-Target Interaction</h1>
380
-
381
- <!-- ───────────── Project Links (larger + spaced) ───────────── -->
382
- <div style="margin-top:24px; text-align:center;">
383
- <a href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank"
384
- style="display:inline-block;margin:8px 18px;padding:10px 20px;
385
- background:linear-gradient(to right,#10b981,#059669);color:white;
386
- font-weight:600;border-radius:8px;font-size:0.9rem;
387
- font-family:Inter,sans-serif;text-decoration:none;
388
- box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
389
- onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
390
- 🌐 Project Page
391
- </a>
392
-
393
- <a href="https://arxiv.org/abs/2406.01651" target="_blank"
394
- style="display:inline-block;margin:8px 18px;padding:10px 20px;
395
- background:linear-gradient(to right,#ef4444,#dc2626);color:white;
396
- font-weight:600;border-radius:8px;font-size:0.9rem;
397
- font-family:Inter,sans-serif;text-decoration:none;
398
- box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
399
- onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
400
- 📄 ArXiv: 2406.01651
401
- </a>
402
-
403
- <a href="https://github.com/ZhaohanM/FusionDTI" target="_blank"
404
- style="display:inline-block;margin:8px 18px;padding:10px 20px;
405
- background:linear-gradient(to right,#3b82f6,#2563eb);color:white;
406
- font-weight:600;border-radius:8px;font-size:0.9rem;
407
- font-family:Inter,sans-serif;text-decoration:none;
408
- box-shadow:0 2px 6px rgba(0,0,0,0.12);transition:all 0.2s ease-in-out;"
409
- onmouseover="this.style.opacity='0.9'" onmouseout="this.style.opacity='1'">
410
- 💻 GitHub Repo
411
- </a>
412
- </div>
413
-
414
- <!-- ───────────── Guidelines for Use ───────────── -->
415
- <div class="card" style="margin-bottom:24px">
416
- <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for Use</h2>
417
- <ul style="margin-left:18px;line-height:1.55;list-style:decimal;">
418
- <li><strong>Convert protein structure into a structure-aware sequence:</strong>
419
- Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
420
- sequence will be generated using
421
- <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
422
- based on 3D structures from
423
- <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
424
- <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
425
-
426
- <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
427
- you must first visit the
428
- <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
429
- or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
430
- to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
431
-
432
- <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
433
- You can enter a SELFIES string directly, or paste a SMILES string.
434
- SMILES will be automatically converted to SELFIES using
435
- <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
436
- If conversion fails, a red error message will be displayed.</li>
437
-
438
- <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
439
- to highlight the Top-10 interacting protein residues.</li>
440
-
441
- <li>After inference, you can use the
442
- “Download PDF” link to export a high-resolution vector version.</li>
443
- </ul>
444
- </div>
445
-
446
- <div class="card">
447
- <form method="POST" enctype="multipart/form-data" class="grid">
448
-
449
- <div><label>Protein Structure (.pdb / .cif)</label>
450
- <input type="file" name="structure_file">
451
- <input type="hidden" name="tmp_structure_path" value="{{ tmp_structure_path }}"></div>
452
-
453
- <div><label>Protein Sequence</label>
454
- <textarea name="protein_sequence" placeholder="Confirm / paste sequence…">{{ protein_seq }}</textarea></div>
455
-
456
- <div><label>Drug Sequence (SELFIES/SMILES)</label>
457
- <textarea name="drug_sequence" placeholder="[C][C][O]/cco …">{{ drug_seq }}</textarea></div>
458
-
459
- <label>Drug atom/substructure index (1-based) – show Top-10 related protein residue</label>
460
- <input type="number" name="drug_idx" min="1" style="width:120px">
461
-
462
- <div class="grid grid-2">
463
- <button class="btn btn-primary" type="Inference" name="confirm_structure">Confirm Structure</button>
464
- <button class="btn btn-primary" type="Inference" name="Inference">Inference</button>
465
- </div>
466
- <button class="btn btn-neutral" style="width:100%" type="Inference" name="clear">Clear</button>
467
- </form>
468
-
469
- {% if structure_seq %}
470
- <div style="margin-top:18px"><strong>Structure-aware sequence:</strong><pre>{{ structure_seq }}</pre></div>
471
- {% endif %}
472
- {% if result_html %}
473
- <div class="vis-box" style="margin-top:26px">{{ result_html|safe }}</div>
474
- {% endif %}
475
- </div></body></html>
476
- """,
477
- protein_seq=protein_seq, drug_seq=drug_seq, structure_seq=structure_seq,
478
- result_html=result_html, tmp_structure_path=tmp_structure_path)
479
-
480
- # ───── run ─────────────────────────────────────────────────────
481
  if __name__ == "__main__":
482
- app.run(debug=True, host="0.0.0.0", port=7860)
 
1
+ # ─── monkey-patch gradio_client so bool schemas don’t crash json_schema_to_python_type ───
2
+ import gradio_client.utils as _gc_utils
3
+
4
+ # back up originals
5
+ _orig_get_type = _gc_utils.get_type
6
+ _orig_json2py = _gc_utils._json_schema_to_python_type
7
+
8
+ def _patched_get_type(schema):
9
+ # treat any boolean schema as if it were an empty dict
10
+ if isinstance(schema, bool):
11
+ schema = {}
12
+ return _orig_get_type(schema)
13
+
14
+ def _patched_json_schema_to_python_type(schema, defs=None):
15
+ # treat any boolean schema as if it were an empty dict
16
+ if isinstance(schema, bool):
17
+ schema = {}
18
+ return _orig_json2py(schema, defs)
19
+
20
+ _gc_utils.get_type = _patched_get_type
21
+ _gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
22
+
23
+ # ─── now it’s safe to import Gradio and build your interface ───────────────────────────
24
+ import gradio as gr
25
+
26
+ import os
27
+ import sys
28
+ import argparse
29
+ import tempfile
30
+ import shutil
31
+ import base64
32
+ import io
33
 
34
  import torch
35
+ import selfies
36
+ from rdkit import Chem
37
  import matplotlib
38
  matplotlib.use("Agg")
39
  import matplotlib.pyplot as plt
40
  from matplotlib import cm
41
  from typing import Optional
42
 
 
43
  from transformers import EsmForMaskedLM, EsmTokenizer, AutoModel
44
+ from torch.utils.data import DataLoader
45
+ from Bio.PDB import PDBParser, MMCIFParser
46
+ from Bio.Data import IUPACData
47
+
48
+ from utils.drug_tokenizer import DrugTokenizer
49
  from utils.metric_learning_models_att_maps import Pre_encoded, FusionDTI
50
  from utils.foldseek_util import get_struc_seq
51
 
52
+ # ───── Helpers ─────────────────────────────────────────────────
53
+
54
+ three2one = {k.upper(): v for k, v in IUPACData.protein_letters_3to1.items()}
55
+ three2one.update({"MSE": "M", "SEC": "C", "PYL": "K"})
56
+ def simple_seq_from_structure(path: str) -> str:
57
+ parser = MMCIFParser(QUIET=True) if path.endswith(".cif") else PDBParser(QUIET=True)
58
+ structure = parser.get_structure("P", path)
59
+ chains = list(structure.get_chains())
60
+ if not chains:
61
+ return ""
62
+ chain = max(chains, key=lambda c: len(list(c.get_residues())))
63
+ return "".join(three2one.get(res.get_resname().upper(), "X") for res in chain)
64
+
65
+ def smiles_to_selfies(smiles: str) -> Optional[str]:
66
+ try:
67
+ mol = Chem.MolFromSmiles(smiles)
68
+ if mol is None:
69
+ return None
70
+ return selfies.encoder(smiles)
71
+ except:
72
+ return None
73
 
74
  def parse_config():
75
  p = argparse.ArgumentParser()
 
76
  p.add_argument("--prot_encoder_path", default="westlake-repl/SaProt_650M_AF2")
77
  p.add_argument("--drug_encoder_path", default="HUBioDataLab/SELFormer")
78
+ p.add_argument("--agg_mode", type=str, default="mean_all_tok")
79
  p.add_argument("--group_size", type=int, default=1)
 
80
  p.add_argument("--fusion", default="CAN")
81
  p.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
82
  p.add_argument("--save_path_prefix", default="save_model_ckp/")
 
86
  args = parse_config()
87
  DEVICE = args.device
88
 
89
+ # ───── Load models & tokenizers ─────────────────────────────────
90
  prot_tokenizer = EsmTokenizer.from_pretrained(args.prot_encoder_path)
91
  prot_model = EsmForMaskedLM.from_pretrained(args.prot_encoder_path)
92
+ drug_tokenizer = DrugTokenizer()
 
93
  drug_model = AutoModel.from_pretrained(args.drug_encoder_path)
94
+ encoding = Pre_encoded(prot_model, drug_model, args).to(DEVICE)
95
 
 
 
 
96
  def collate_fn(batch):
97
  query1, query2, scores = zip(*batch)
98
 
 
118
  attention_mask2 = query_encodings2["attention_mask"].bool()
119
 
120
  return query_encodings1["input_ids"], attention_mask1, query_encodings2["input_ids"], attention_mask2, scores
 
 
 
 
 
 
 
 
 
 
 
121
 
122
 
 
123
  def get_case_feature(model, loader):
124
  model.eval()
125
  with torch.no_grad():
 
131
  p_ids.cpu(), d_ids.cpu(),
132
  p_mask.cpu(), d_mask.cpu(), None)]
133
 
134
+
135
+ # ─────────────── visualisation ───────────────────────────────────────────
 
 
 
 
 
136
  def visualize_attention(model, feats, drug_idx: Optional[int] = None) -> str:
137
  """
138
  Render a Protein → Drug cross-attention heat-map and, optionally, a
139
+ Top-30 protein-residue table for a chosen drug-token index.
140
 
141
  The token index shown on the x-axis (and accepted via *drug_idx*) is **the
142
  position of that token in the *original* drug sequence**, *after* the
 
235
  plt.close(fig)
236
  html = f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" />'
237
 
238
+ # ───────────────────── Top-30 tabel ─────────────────────
239
+ table_html = ""
240
  if drug_idx is not None:
241
  # map original 0-based drug_idx → current column position
242
  if (drug_idx + 1) in d_indices:
 
248
 
249
  if col_pos is not None:
250
  col_vec = attn[:, col_pos]
251
+ topk = torch.topk(col_vec, k=min(30, len(col_vec))).indices.tolist()
252
 
253
  rank_hdr = "".join(f"<th>{r+1}</th>" for r in range(len(topk)))
254
  res_row = "".join(f"<td>{p_tokens[i]}</td>" for i in topk)
 
256
 
257
  drug_tok_text = d_tokens[col_pos]
258
  orig_idx = d_indices[col_pos]
259
+
260
+ # 1) build the header row: leading “Rank”, then 1…30
261
+ header_cells = (
262
+ "<th style='border:1px solid #ccc; padding:6px; "
263
+ "background:#f7f7f7; text-align:center;'>Rank</th>"
264
+ + "".join(
265
+ f"<th style='border:1px solid #ccc; padding:6px; "
266
+ f"background:#f7f7f7; text-align:center'>{r+1}</th>"
267
+ for r in range(len(topk))
268
+ )
269
+ )
270
+
271
+ # 2) build the residue row: leading “Residue”, then the residue tokens
272
+ residue_cells = (
273
+ "<th style='border:1px solid #ccc; padding:6px; "
274
+ "background:#f7f7f7; text-align:center;'>Residue</th>"
275
+ + "".join(
276
+ f"<td style='border:1px solid #ccc; padding:6px; "
277
+ f"text-align:center'>{p_tokens[i]}</td>"
278
+ for i in topk
279
+ )
280
+ )
281
+
282
+ # 3) build the position row: leading “Position”, then the residue positions
283
+ position_cells = (
284
+ "<th style='border:1px solid #ccc; padding:6px; "
285
+ "background:#f7f7f7; text-align:center;'>Position</th>"
286
+ + "".join(
287
+ f"<td style='border:1px solid #ccc; padding:6px; "
288
+ f"text-align:center'>{p_indices[i]}</td>"
289
+ for i in topk
290
+ )
291
+ )
292
+
293
+ # 4) assemble your table_html
294
  table_html = (
295
+ f"<h4 style='margin-bottom:12px'>"
296
+ f"Drug atom #{orig_idx} <code>{drug_tok_text}</code> → Top-30 Protein residues"
297
+ f"</h4>"
298
+ f"<table style='border-collapse:collapse; margin:0 auto 24px;'>"
299
+ f"<tr>{header_cells}</tr>"
300
+ f"<tr>{residue_cells}</tr>"
301
+ f"<tr>{position_cells}</tr>"
302
+ f"</table>"
303
+ )
304
+
305
  buf_png = io.BytesIO()
306
+ fig.savefig(buf_png, format="png", dpi=140)
307
  buf_png.seek(0)
308
 
309
  buf_pdf = io.BytesIO()
310
+ fig.savefig(buf_pdf, format="pdf")
311
  buf_pdf.seek(0)
312
  plt.close(fig)
313
 
 
315
  pdf_b64 = base64.b64encode(buf_pdf.getvalue()).decode()
316
 
317
  html_heat = (
318
+ f"<div style='position: relative; width: 100%;'>"
319
+ # the PDF button, absolutely positioned
320
+ f"<a href='data:application/pdf;base64,{pdf_b64}' download='attention_heatmap.pdf' "
321
+ "style='position: absolute; top: 12px; right: 12px; "
322
+ "background: var(--primary); color: #fff; "
323
+ "padding: 8px 16px; border-radius: 6px; "
324
+ "font-size: 0.9rem; font-weight: 500; "
325
+ "text-decoration: none;'>"
326
+ "Download PDF"
327
+ "</a>"
328
+ # the clickable heat‐map image
329
+ f"<a href='data:image/png;base64,{png_b64}' target='_blank' title='Click to enlarge'>"
330
+ f"<img src='data:image/png;base64,{png_b64}' "
331
+ "style='display: block; width: 100%; height: auto; cursor: zoom-in;'/>"
332
+ "</a>"
333
+ "</div>"
334
  )
335
 
 
336
  return table_html + html_heat
337
+
338
+ # ───── Gradio Callbacks ─────────────────────────────────────────
339
+
340
+ ROOT = os.path.dirname(os.path.abspath(__file__))
341
+ FOLDSEEK_BIN = os.path.join(ROOT, "bin", "foldseek")
342
+
343
+ def extract_sequence_cb(structure_file):
344
+ if structure_file is None or not os.path.exists(structure_file.name):
345
+ return ""
346
+ parsed = get_struc_seq(FOLDSEEK_BIN, structure_file.name, None, plddt_mask=False)
347
+ first_chain = next(iter(parsed))
348
+ _, _, struct_seq = parsed[first_chain]
349
+ return struct_seq
350
+
351
+ def inference_cb(prot_seq, drug_seq, atom_idx):
352
+ if not prot_seq:
353
+ return "<p style='color:red'>Please extract or enter a protein sequence first.</p>"
354
+ if not drug_seq.strip():
355
+ return "<p style='color:red'>Please enter a drug sequence.</p>"
356
+ if not drug_seq.strip().startswith("["):
357
+ conv = smiles_to_selfies(drug_seq.strip())
358
+ if conv is None:
359
+ return "<p style='color:red'>SMILES→SELFIES conversion failed.</p>"
360
+ drug_seq = conv
361
+ loader = DataLoader([(prot_seq, drug_seq, 1)], batch_size=1, collate_fn=collate_fn)
362
+ feats = get_case_feature(encoding, loader)
363
+ model = FusionDTI(446, 768, args).to(DEVICE)
364
+ ckpt = os.path.join(f"{args.save_path_prefix}{args.dataset}_{args.fusion}", "best_model.ckpt")
365
+ if os.path.isfile(ckpt):
366
+ model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
367
+ return visualize_attention(model, feats, int(atom_idx)-1 if atom_idx else None)
368
+
369
+ def clear_cb():
370
+ return None, "", "", None, ""
371
+
372
+ # ───── Gradio Interface Definition ────────���──────────────────────
373
+
374
+ css = """
375
+ :root {
376
+ --bg: #f3f4f6;
377
+ --card: #ffffff;
378
+ --border: #e5e7eb;
379
+ --primary: #6366f1;
380
+ --primary-dark: #4f46e5;
381
+ --text: #111827;
382
+ }
383
+ * { box-sizing: border-box; margin: 0; padding: 0; }
384
+ body { background: var(--bg); color: var(--text); font-family: Inter,system-ui,Arial,sans-serif; }
385
+ h1 { font-family: Poppins,Inter,sans-serif; font-weight: 600; font-size: 2rem; text-align: center; margin: 24px 0; }
386
+ button, .gr-button { font-family: Inter,sans-serif; font-weight: 600; }
387
+ #project-links { text-align: center; margin-bottom: 32px; }
388
+ #project-links .gr-button { margin: 0 8px; min-width: 160px; }
389
+ #project-links .gr-button:nth-child(1) { background: #10b981; }
390
+ #project-links .gr-button:nth-child(2) { background: #ef4444; }
391
+ #project-links .gr-button:nth-child(3) { background: #3b82f6; }
392
+ #project-links .gr-button:hover { opacity: 0.9; }
393
+ .link-btn{display:inline-block;margin:0 8px;padding:10px 20px;border-radius:8px;
394
+ color:white;font-weight:600;text-decoration:none;box-shadow:0 2px 6px rgba(0,0,0,0.12);
395
+ transition:all .2s ease-in-out;}
396
+ .link-btn:hover{opacity:.9;}
397
+ .link-btn.project{background:linear-gradient(to right,#10b981,#059669);}
398
+ .link-btn.arxiv {background:linear-gradient(to right,#ef4444,#dc2626);}
399
+ .link-btn.github {background:linear-gradient(to right,#3b82f6,#2563eb);}
400
+
401
+ /* make *all* gradio buttons a bit taller */
402
+ .gr-button { min-height: 10px !important; }
403
+
404
+ /* now target just our two big action buttons */
405
+ #extract-btn, #inference-btn {
406
+ width: 5px !important;
407
+ min-height: 36px !important;
408
+ margin-top: 12px !important;
409
+ }
410
+
411
+ /* and make clear button full width but shorter */
412
+ #clear-btn {
413
+ width: 10px !important;
414
+ min-height: 36px !important;
415
+ margin-top: 12px !important;
416
+ }
417
+
418
+ #input-card label {
419
+ font-weight: 600 !important; /* make the text bold */
420
+ color: var(--text) !important; /* use your standard text color */
421
+ }
422
+
423
+ .card {
424
+ background: var(--card);
425
+ border: 1px solid var(--border);
426
+ border-radius: 12px;
427
+ padding: 24px;
428
+ max-width: 1000px;
429
+ margin: 0 auto 32px;
430
+ box-shadow: 0 2px 6px rgba(0,0,0,0.05);
431
+ }
432
+
433
+ #guidelines-card h2 {
434
+ font-size: 1.4rem;
435
+ margin-bottom: 16px;
436
+ text-align: center;
437
+ }
438
+ #guidelines-card ol {
439
+ margin-left: 20px;
440
+ line-height: 1.6;
441
+ font-size: 1rem;
442
+ }
443
+ #input-card .gr-row, #input-card .gr-cols {
444
+ gap: 16px;
445
+ }
446
+ #input-card .gr-button {
447
+ flex: 1;
448
+ }
449
+ #output-card {
450
+ padding-top: 0;
451
+ }
452
+ """
453
+
454
+ with gr.Blocks(css=css) as demo:
455
+ # ───────────── Title ─────────────
456
+ gr.Markdown("<h1>Token-level Visualiser for Drug-Target Interaction</h1>")
457
+
458
+ # ───────────── Project Links ─────────────
459
+ gr.Markdown("""
460
+ <div style="text-align:center;margin-bottom:32px;">
461
+ <a class="link-btn project" href="https://zhaohanm.github.io/FusionDTI.github.io/" target="_blank">🌐 Project Page</a>
462
+ <a class="link-btn arxiv" href="https://arxiv.org/abs/2406.01651" target="_blank">📄 ArXiv: 2406.01651</a>
463
+ <a class="link-btn github" href="https://github.com/ZhaohanM/FusionDTI" target="_blank">💻 GitHub Repo</a>
464
+ </div>
465
+ """)
466
+ # ───────────── Guidelines Card ─────────────
467
 
468
+ gr.HTML(
469
+ """
470
+ <div class="card" style="margin-bottom:24px">
471
+ <h2 style="font-size:1.2rem;margin-bottom:14px">Guidelines for User</h2>
472
+ <ul style="font-size:1rem; margin-left:18px;line-height:1.55;list-style:decimal;">
473
+ <li><strong>Convert protein structure into a structure-aware sequence:</strong>
474
+ Upload a <code>.pdb</code> or <code>.cif</code> file. A structure-aware
475
+ sequence will be generated using
476
+ <a href="https://github.com/steineggerlab/foldseek" target="_blank">Foldseek</a>,
477
+ based on 3D structures from
478
+ <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a> or the
479
+ <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>.</li>
480
+
481
+ <li><strong>If you only have an amino acid sequence or a UniProt ID,</strong>
482
+ you must first visit the
483
+ <a href="https://www.rcsb.org" target="_blank">Protein Data Bank (PDB)</a>
484
+ or <a href="https://alphafold.ebi.ac.uk" target="_blank">AlphaFold&nbsp;DB</a>
485
+ to search and download the corresponding <code>.cif</code> or <code>.pdb</code> file.</li>
486
+
487
+ <li><strong>Drug input supports both SELFIES and SMILES:</strong><br>
488
+ You can enter a SELFIES string directly, or paste a SMILES string.
489
+ SMILES will be automatically converted to SELFIES using
490
+ <a href="https://github.com/aspuru-guzik-group/selfies" target="_blank">SELFIES encoder</a>.
491
+ If conversion fails, a red error message will be displayed.</li>
492
+
493
+ <li>Optionally enter a <strong>1-based</strong> drug atom or substructure index
494
+ to highlight the Top-10 interacting protein residues.</li>
495
+
496
+ <li>After inference, you can use the
497
+ “Download PDF” link to export a high-resolution vector version.</li>
498
+ </ul>
499
+ </div>
500
+ """)
501
+
502
+ # ───────────── Input Card ─────────────
503
+ with gr.Column(elem_id="input-card", elem_classes="card"):
504
+
505
+ protein_seq = gr.Textbox(
506
+ label="Protein Structure-aware Sequence",
507
+ lines=3,
508
+ elem_id="protein-seq"
509
+ )
510
+ drug_seq = gr.Textbox(
511
+ label="Drug Sequence (SELFIES/SMILES)",
512
+ lines=3,
513
+ elem_id="drug-seq"
514
+ )
515
+ structure_file = gr.File(
516
+ label="Upload Protein Structure (.pdb/.cif)",
517
+ file_types=[".pdb", ".cif"],
518
+ elem_id="structure-file"
519
+ )
520
+ drug_idx = gr.Number(
521
+ label="Drug atom/substructure index (1-based)",
522
+ value=None,
523
+ precision=0,
524
+ elem_id="drug-idx"
525
+ )
526
+
527
+ # ───────────── Action Buttons ─────────────
528
+ with gr.Row(elem_id="action-buttons", equal_height=True):
529
+ btn_extract = gr.Button(
530
+ "Extract sequence",
531
+ variant="primary",
532
+ elem_id="extract-btn"
533
+ )
534
+ btn_infer = gr.Button(
535
+ "Inference",
536
+ variant="primary",
537
+ elem_id="inference-btn"
538
+ )
539
+ with gr.Row():
540
+ clear_btn = gr.Button(
541
+ "Clear",
542
+ variant="secondary",
543
+ elem_classes="full-width",
544
+ elem_id="clear-btn"
545
+ )
546
+
547
+ # ───────────── Output Visualization ─────────────
548
+ output_html = gr.HTML(elem_id="result-html")
549
+
550
+ # ───────────── Event Wiring ─────────────
551
+ btn_extract.click(
552
+ fn=extract_sequence_cb,
553
+ inputs=[structure_file],
554
+ outputs=[protein_seq]
555
+ )
556
+ btn_infer.click(
557
+ fn=inference_cb,
558
+ inputs=[protein_seq, drug_seq, drug_idx],
559
+ outputs=[output_html]
560
+ )
561
+ clear_btn.click(
562
+ fn=lambda: ("", "", None, "", ""),
563
+ inputs=[],
564
+ outputs=[protein_seq, drug_seq, drug_idx, output_html, structure_file]
565
+ )
566
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
567
  if __name__ == "__main__":
568
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)