import torch
from transformers import AutoTokenizer, AutoModel
import gradio as gr

# Load Bio_ClinicalBERT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model.eval() # Disable dropout for inference

# Attention-masked mean pooling: average token embeddings, ignoring padding
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output.last_hidden_state  # (batch_size, seq_len, hidden_dim)
    # Expand the mask to the embedding dimension so padding tokens contribute zero
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Sum unmasked embeddings, then divide by the token count (clamped to avoid division by zero)
    pooled = torch.sum(token_embeddings * input_mask_expanded, dim=1) / \
        torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
    return pooled
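
# Worked example (illustrative values, not executed by the app): with token
# vectors [[1., 1.], [3., 3.], [9., 9.]] and attention mask [1, 1, 0], the
# masked mean is [2., 2.]; a plain mean over all rows would be skewed to
# ~[4.33, 4.33] by the padding row.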

# Embed a clinical text snippet as a single dense vector
def embed_text(text):
    # Tokenize with truncation to BERT's 512-token limit
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():  # no gradients needed at inference time
        outputs = model(**inputs)
    # Pool, drop the batch dimension, and convert to a JSON-serializable list
    embedding = mean_pooling(outputs, inputs['attention_mask']).squeeze().tolist()
    return embedding
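
# Example usage (illustrative): Bio_ClinicalBERT is a BERT-base model, so each
# embedding is a 768-dimensional list of floats, e.g.:
#   vec = embed_text("Patient presents with acute chest pain.")
#   assert len(vec) == 768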

# Gradio interface
iface = gr.Interface(
    fn=embed_text,
    inputs=gr.Textbox(lines=5, label="Enter clinical text"),
    outputs="json",
    title="Clinical Text Embedding API",
    description="Generate dense vector embeddings using Bio_ClinicalBERT with attention-masked mean pooling"
)
iface.launch()
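
# Minimal client sketch (assumes the default local launch at http://127.0.0.1:7860
# and gr.Interface's default "/predict" endpoint):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   embedding = client.predict("Patient denies fever or chills.", api_name="/predict")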