import torch
from transformers import AutoTokenizer, AutoModel
import gradio as gr

# Load the Bio_ClinicalBERT tokenizer and encoder, and put the model in
# inference mode so dropout layers are disabled.
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model.eval()


def mean_pooling(model_output, attention_mask):
    """Average the token embeddings, ignoring padding positions."""
    token_embeddings = model_output.last_hidden_state
    # Broadcast the attention mask over the hidden dimension so that
    # padding tokens contribute nothing to the sum.
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    # Divide the masked sum by the number of real tokens; the clamp
    # guards against division by zero for all-padding inputs.
    pooled = torch.sum(token_embeddings * input_mask_expanded, dim=1) / \
        torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
    return pooled
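
# A toy sanity check for mean_pooling (hypothetical values, not part of
# the app): with real token embeddings [1.0] and [3.0] and a third
# masked-out position, the pooled result should be their mean.
#
#   from types import SimpleNamespace
#   toy_output = SimpleNamespace(last_hidden_state=torch.tensor([[[1.0], [3.0], [99.0]]]))
#   toy_mask = torch.tensor([[1, 1, 0]])
#   mean_pooling(toy_output, toy_mask)  # -> tensor([[2.]])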


def embed_text(text):
    """Tokenize the input, run the encoder, and return the pooled embedding."""
    # Truncate to BERT's 512-token limit; padding only matters for batched
    # inputs but is harmless for a single string.
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    # Convert to a plain Python list so Gradio can serialize it as JSON.
    embedding = mean_pooling(outputs, inputs["attention_mask"]).squeeze().tolist()
    return embedding
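
# Example direct call (assumes the model weights downloaded successfully):
#
#   vec = embed_text("Patient denies chest pain or shortness of breath.")
#   len(vec)  # 768, the hidden size of BERT-base encoders like Bio_ClinicalBERT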


# Expose the embedding function as a Gradio app; the JSON output doubles
# as a simple HTTP API.
iface = gr.Interface(
    fn=embed_text,
    inputs=gr.Textbox(lines=5, label="Enter clinical text"),
    outputs="json",
    title="Clinical Text Embedding API",
    description="Generate dense vector embeddings using Bio_ClinicalBERT with attention-masked mean pooling",
)

iface.launch()
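
# A minimal client-side sketch (an assumption, not part of the app: it
# presumes the server is running locally on Gradio's default port 7860,
# and that "/predict" is the auto-generated endpoint name for a
# single-function Interface):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   vector = client.predict("Patient presents with acute chest pain.", api_name="/predict")
#   print(len(vector))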