s1ome123 committed
Commit cf6b570 · verified · 1 parent: 97e5541

Update app.py

Files changed (1): app.py (+18 -7)
app.py CHANGED
@@ -1,24 +1,35 @@
-from transformers import AutoTokenizer, AutoModel
 import torch
+from transformers import AutoTokenizer, AutoModel
 import gradio as gr
 
-# Load Bio_ClinicalBERT
+# Load Bio_ClinicalBERT model and tokenizer
 tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
 model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
+model.eval()  # Disable dropout for inference
+
+# Attention-masked mean pooling function
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output.last_hidden_state  # (batch_size, seq_len, hidden_dim)
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    pooled = torch.sum(token_embeddings * input_mask_expanded, dim=1) / \
+             torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
+    return pooled
 
+# Embedding function
 def embed_text(text):
-    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
+    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
     with torch.no_grad():
         outputs = model(**inputs)
-    # Mean pooling
-    embedding = outputs.last_hidden_state.mean(dim=1).squeeze().tolist()
+    embedding = mean_pooling(outputs, inputs['attention_mask']).squeeze().tolist()
     return embedding
 
+# Gradio interface
 iface = gr.Interface(
     fn=embed_text,
-    inputs=gr.Textbox(lines=5, label="Enter patient text"),
+    inputs=gr.Textbox(lines=5, label="Enter clinical text"),
     outputs="json",
-    title="Clinical Text Embedding API (Bio_ClinicalBERT)"
+    title="Clinical Text Embedding API",
+    description="Generate dense vector embeddings using Bio_ClinicalBERT with attention-masked mean pooling"
 )
 
 iface.launch()
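
Why the pooling change matters: the old code averaged last_hidden_state over every token position, so any [PAD] tokens introduced by padding would be folded into the embedding. For a single request the tokenizer adds no padding and the two pools coincide, but once inputs are padded (for example when batching several texts of different lengths) only the masked mean stays correct. A minimal sketch with synthetic tensors, not part of the commit, with shapes chosen only for illustration:

import torch

torch.manual_seed(0)
hidden = torch.randn(1, 6, 8)              # stand-in for last_hidden_state: (batch, seq_len, hidden_dim)
mask = torch.tensor([[1, 1, 1, 1, 0, 0]])  # last two positions are padding

plain = hidden.mean(dim=1)                 # old approach: padding positions included
expanded = mask.unsqueeze(-1).expand(hidden.size()).float()
masked = (hidden * expanded).sum(dim=1) / expanded.sum(dim=1).clamp(min=1e-9)

print(torch.allclose(plain, masked))       # False: padding shifts the plain mean

Once the app is running via iface.launch(), the endpoint can also be queried programmatically. The snippet below is a hypothetical client call that assumes Gradio's default local URL and endpoint name; adjust both if the app is hosted as a Space:

from gradio_client import Client

client = Client("http://127.0.0.1:7860")   # default local address; a Space URL also works
vector = client.predict("Patient presents with chest pain.", api_name="/predict")
print(len(vector))                         # 768-dimensional embedding for Bio_ClinicalBERT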