import torch
from transformers import AutoTokenizer, AutoModel
import gradio as gr

# Load Bio_ClinicalBERT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model.eval()  # Disable dropout for inference

# Attention-masked mean pooling function
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state  # (batch_size, seq_len, hidden_dim)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    pooled = torch.sum(token_embeddings * input_mask_expanded, dim=1) / torch.clamp(
        input_mask_expanded.sum(dim=1), min=1e-9
    )
    return pooled
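
# A toy illustration of the masked mean (hypothetical numbers, for intuition only):
# with token embeddings [[1.0], [3.0], [5.0]] and attention_mask [1, 1, 0],
# the padded position is zeroed out, so pooled = (1.0 + 3.0) / 2 = 2.0,
# not a naive mean over all three positions.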

# Embedding function
def embed_text(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    embedding = mean_pooling(outputs, inputs['attention_mask']).squeeze().tolist()  # (1, hidden_dim) -> flat Python list
    return embedding
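
# Quick sanity check (a sketch to run locally before serving; example text is hypothetical):
#   vec = embed_text("Patient denies chest pain or shortness of breath.")
#   print(len(vec))  # 768, the hidden size of this BERT-base-sized model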

# Gradio interface
iface = gr.Interface(
    fn=embed_text,
    inputs=gr.Textbox(lines=5, label="Enter clinical text"),
    outputs="json",
    title="Clinical Text Embedding API",
    description="Generate dense vector embeddings using Bio_ClinicalBERT with attention-masked mean pooling"
)

iface.launch()
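
# Example client usage (a minimal sketch; assumes the app is running on
# Gradio's default local port and that the gradio_client package is installed):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   embedding = client.predict(
#       "Patient presents with acute chest pain.",  # hypothetical input text
#       api_name="/predict",  # default endpoint name for a gr.Interface
#   )
#   print(len(embedding))  # 768-dimensional embedding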