embedding / app.py
s1ome123's picture
Update app.py
cf6b570 verified
raw
history blame
1.37 kB
import torch
from transformers import AutoTokenizer, AutoModel
import gradio as gr
# Bio_ClinicalBERT checkpoint shared by the tokenizer and the encoder.
MODEL_NAME = "emilyalsentzer/Bio_ClinicalBERT"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# eval() returns the module itself, so `model` is the loaded network with
# dropout disabled for inference.
model = AutoModel.from_pretrained(MODEL_NAME).eval()
def mean_pooling(model_output, attention_mask):
    """Average each sequence's token embeddings, ignoring masked positions.

    Args:
        model_output: Object exposing ``last_hidden_state``, used here as a
            (batch_size, seq_len, hidden_dim) tensor.
        attention_mask: (batch_size, seq_len) tensor used as per-token
            weights: 1 keeps a token, 0 excludes it (e.g. padding).

    Returns:
        A (batch_size, hidden_dim) float tensor holding the mask-weighted
        mean of the token embeddings for every sequence.
    """
    token_embeddings = model_output.last_hidden_state
    # Keep the mask at (batch_size, seq_len, 1); broadcasting applies it to
    # every hidden unit without building a full-size float copy of the mask.
    mask = attention_mask.unsqueeze(-1).float()
    summed = torch.sum(token_embeddings * mask, dim=1)
    # Clamp so a fully masked sequence divides by a tiny number, not zero.
    token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / token_counts
def embed_text(text):
    """Embed text with the module-level model and masked mean pooling.

    Args:
        text: Text to embed. It is tokenized with padding enabled and
            truncated to at most 512 tokens.

    Returns:
        The pooled embedding converted to plain Python lists via
        ``tolist()``; for a single input string this is a flat list of
        floats.
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():  # inference only, so skip autograd bookkeeping
        outputs = model(**inputs)
        pooled = mean_pooling(outputs, inputs["attention_mask"])
    # Drop only the batch axis; a bare squeeze() would collapse every size-1
    # axis, not just the leading one.
    return pooled.squeeze(0).tolist()
# Gradio UI: one multi-line text box in, the embedding rendered as JSON out.
clinical_text_box = gr.Textbox(lines=5, label="Enter clinical text")
interface_options = {
    "fn": embed_text,
    "inputs": clinical_text_box,
    "outputs": "json",
    "title": "Clinical Text Embedding API",
    "description": "Generate dense vector embeddings using Bio_ClinicalBERT with attention-masked mean pooling",
}
iface = gr.Interface(**interface_options)

# Start serving as soon as the script runs.
iface.launch()