# Drug_assisstant/src/models/biomedical_llm.py
# Uploaded by oussamaor ("Upload 12 files", commit f368eec, verified)
"""Biomedical Language Model for drug interaction analysis."""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json
import re
class BiomedicalLLM:
    """Run a causal biomedical language model with static fallbacks.

    The model and tokenizer are loaded once in ``__init__``.  If loading
    fails, ``self.model`` and ``self.tokenizer`` are ``None`` and every
    public method returns a static "information not available" structure
    instead of raising.

    Each public method follows the same pipeline:
        1. build a prompt that asks for a JSON answer,
        2. sample a completion from the model,
        3. parse the first JSON object of the expected shape out of the
           completion,
        4. if no usable JSON is found, fall back to regex extraction,
        5. if inference itself raises, fall back to the static structure.
    """

    # Sampling settings shared by every generation call.
    MAX_NEW_TOKENS = 512
    TEMPERATURE = 0.7
    TOP_P = 0.9

    _NOT_SPECIFIED = "Not specified"

    # Top-level keys that identify a JSON answer of the requested shape.
    _DDI_KEYS = ("interactions",)
    _CLINICAL_NOTE_KEYS = ("medications", "potential_interactions")
    _DRUG_INFO_KEYS = (
        "drug_name",
        "drug_class",
        "mechanism",
        "indications",
        "side_effects",
        "common_interactions",
        "contraindications",
    )

    # "<name> <amount><unit>" patterns: group 1 is the name, group 2 the dosage.
    _DOSAGE_PATTERNS = (
        r'([A-Za-z]+)\s+(\d+\s*mg)',
        r'([A-Za-z]+)\s+(\d+\s*mcg)',
        r'([A-Za-z]+)\s+(\d+\s*ml)',
        r'([A-Za-z]+)\s+(\d+\s*tablet)',
    )
    # "<verb> <name>" patterns: group 1 is the name, no dosage available.
    # NOTE: "(?:ed)?" is an optional suffix; a character class such as
    # "[ed]?" would only allow a single "e" or "d" and miss "administered".
    _MENTION_PATTERNS = (
        r'prescribe[d]?\s+([A-Za-z]+)',
        r'taking\s+([A-Za-z]+)',
        r'administer(?:ed)?\s+([A-Za-z]+)',
    )

    # One bulleted / numbered list entry inside a section body.
    _LIST_ITEM_PATTERN = r'(?:^|\n)\s*(?:\d+\.|\*|-|•)\s*([^\n]+)'

    def __init__(self, model_name="stanford-crfm/BioMedLM"):
        """
        Initialize the Biomedical Language Model

        Loading errors are reported on stdout and leave the instance in
        fallback mode (``self.model`` and ``self.tokenizer`` set to None).

        Args:
            model_name: The name of the model to use (default: BioMedLM)
                Options include:
                - "stanford-crfm/BioMedLM"
                - "microsoft/biogpt"
        """
        self.model_name = model_name
        # Defined up front so every attribute exists even when loading fails.
        self.tokenizer = None
        self.model = None
        self.device = "cpu"

        try:
            print(f"Loading {model_name}...")
            use_cuda = torch.cuda.is_available()
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Set pad token if it doesn't exist
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
                device_map="auto" if use_cuda else None,
                pad_token_id=tokenizer.pad_token_id,
            )
        except Exception as e:
            # Deliberate best-effort boundary: any loading problem switches
            # the instance to fallback mode instead of propagating.
            print(f"Error loading model: {e}")
            print("Falling back to API-based approach or stub implementation")
            return

        self.tokenizer = tokenizer
        self.model = model
        self.device = "cuda" if use_cuda else "cpu"
        print(f"Model loaded successfully on {self.device}")

    # ------------------------------------------------------------------
    # Shared inference helpers
    # ------------------------------------------------------------------

    def _generate_response(self, prompt):
        """Sample a completion for ``prompt`` and return only the new text.

        The sequence returned by ``generate`` starts with the prompt tokens,
        so they are sliced off before decoding.  Otherwise the JSON template
        and the numbered instructions inside the prompt would be parsed as if
        the model had written them.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        # Ensure attention mask is set properly
        if 'attention_mask' not in inputs:
            inputs['attention_mask'] = torch.ones_like(inputs['input_ids'])

        with torch.no_grad():
            outputs = self.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=self.MAX_NEW_TOKENS,
                temperature=self.TEMPERATURE,
                top_p=self.TOP_P,
                do_sample=True,
            )

        prompt_length = inputs['input_ids'].shape[-1]
        return self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

    @staticmethod
    def _parse_json_object(text, expected_keys):
        """Return the first JSON object in ``text`` that has an expected key.

        Every ``{`` is tried as a possible start of a JSON document, so
        surrounding prose or trailing garbage does not break parsing.
        Objects without any of ``expected_keys`` (for example a nested entry
        of a truncated answer) are skipped.

        Returns:
            The decoded dict, or None when no suitable object is found.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict) and any(
                key in candidate for key in expected_keys
            ):
                return candidate
            start = text.find("{", start + 1)
        return None

    # ------------------------------------------------------------------
    # Drug-drug interactions
    # ------------------------------------------------------------------

    def extract_ddi_from_literature(self, drug1, drug2):
        """
        Extract drug-drug interaction information from biomedical literature

        Args:
            drug1: Name of the first drug
            drug2: Name of the second drug
        Returns:
            A dict with an "interactions" list; each entry has the keys
            "description", "severity", "evidence" and "management".  The
            static fallback additionally carries a "note" key.
        """
        if self.model is None or self.tokenizer is None:
            # Fallback behavior if model failed to load
            return self._fallback_extract_ddi(drug1, drug2)

        try:
            prompt = (
                "\n"
                f"Analyze the scientific literature for interactions between {drug1} and {drug2}.\n"
                "Include the following information:\n"
                "1. Description of the interaction mechanism\n"
                "2. Severity (Mild, Moderate, Severe)\n"
                "3. Clinical significance\n"
                "4. Management recommendations\n"
                "Format the response as JSON with the following structure:\n"
                "{\n"
                '"interactions": [\n'
                "{\n"
                '"description": "Description of mechanism",\n'
                '"severity": "Severity level",\n'
                '"evidence": "Evidence from literature",\n'
                '"management": "Management recommendation"\n'
                "}\n"
                "]\n"
                "}\n"
            )
            response = self._generate_response(prompt)
            parsed = self._parse_json_object(response, self._DDI_KEYS)
            if parsed is not None:
                return parsed
            # No usable JSON: derive a structured answer from the free text.
            return self._extract_structured_info(response, drug1, drug2)
        except Exception as e:
            print(f"Error in LLM inference: {e}")
            return self._fallback_extract_ddi(drug1, drug2)

    def _extract_structured_info(self, text, drug1, drug2):
        """Extract structured information from text if JSON parsing fails"""
        # Only the severity is recovered from the text; the first of
        # mild/moderate/severe mentioned wins.
        severity_match = re.search(r'(mild|moderate|severe)', text.lower())
        severity = severity_match.group(1).capitalize() if severity_match else "Unknown"

        return {
            "interactions": [
                {
                    "description": f"Potential interaction between {drug1} and {drug2} identified in literature",
                    "severity": severity,
                    "evidence": "Based on biomedical literature analysis",
                    "management": "Consult healthcare provider for specific guidance"
                }
            ]
        }

    def _fallback_extract_ddi(self, drug1, drug2):
        """Fallback method when model is not available"""
        return {
            "interactions": [
                {
                    "description": f"Potential interaction between {drug1} and {drug2}",
                    "severity": "Unknown",
                    "evidence": "Please consult literature for evidence",
                    "management": "Consult healthcare provider for guidance"
                }
            ],
            "note": "Biomedical model not available - using fallback information"
        }

    # ------------------------------------------------------------------
    # Clinical notes
    # ------------------------------------------------------------------

    def analyze_clinical_notes(self, clinical_text):
        """
        Extract drug mentions and potential interactions from clinical notes

        Args:
            clinical_text: The clinical notes text to analyze
        Returns:
            A dictionary with "medications" and "potential_interactions"
            lists.  The static fallback additionally carries a "note" key.
        """
        if self.model is None or self.tokenizer is None:
            # Fallback behavior if model failed to load
            return self._fallback_analyze_clinical_notes(clinical_text)

        try:
            prompt = (
                "\n"
                "Extract all medication mentions and potential drug interactions from the following clinical note:\n"
                f"{clinical_text}\n"
                "Format the response as JSON with the following structure:\n"
                "{\n"
                '"medications": [\n'
                "{\n"
                '"name": "Drug name",\n'
                '"dosage": "Dosage if mentioned",\n'
                '"frequency": "Frequency if mentioned"\n'
                "}\n"
                "],\n"
                '"potential_interactions": [\n'
                "{\n"
                '"drug1": "First drug name",\n'
                '"drug2": "Second drug name",\n'
                '"concern": "Description of potential interaction"\n'
                "}\n"
                "]\n"
                "}\n"
            )
            response = self._generate_response(prompt)
            parsed = self._parse_json_object(response, self._CLINICAL_NOTE_KEYS)
            if parsed is not None:
                return parsed
            # No usable JSON: scan the note itself as well as whatever the
            # model wrote, so mentions in the source text are not lost.
            return self._extract_medications_from_text(
                f"{clinical_text}\n{response}"
            )
        except Exception as e:
            print(f"Error in LLM inference: {e}")
            return self._fallback_analyze_clinical_notes(clinical_text)

    def _extract_medications_from_text(self, text):
        """Extract medication mentions from text if JSON parsing fails

        Dosage patterns are applied first, then verb-based mentions.  A
        (name, dosage) pair is reported once, and a bare mention is dropped
        when the same name was already reported.  Names are compared
        case-insensitively.

        Returns:
            {"medications": [...], "potential_interactions": []} where each
            medication has the keys "name", "dosage" and "frequency".
        """
        medications = []
        seen_doses = set()
        seen_names = set()

        for pattern in self._DOSAGE_PATTERNS:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                drug_name, dosage = match.group(1), match.group(2)
                dose_key = (drug_name.casefold(), dosage.casefold())
                if dose_key in seen_doses:
                    continue
                seen_doses.add(dose_key)
                seen_names.add(drug_name.casefold())
                medications.append({
                    "name": drug_name,
                    "dosage": dosage,
                    "frequency": self._NOT_SPECIFIED,
                })

        for pattern in self._MENTION_PATTERNS:
            for match in re.finditer(pattern, text, re.IGNORECASE):
                drug_name = match.group(1)
                name_key = drug_name.casefold()
                if name_key in seen_names:
                    continue
                seen_names.add(name_key)
                medications.append({
                    "name": drug_name,
                    "dosage": self._NOT_SPECIFIED,
                    "frequency": self._NOT_SPECIFIED,
                })

        return {
            "medications": medications,
            "potential_interactions": []
        }

    def _fallback_analyze_clinical_notes(self, clinical_text):
        """Fallback method for clinical note analysis when model is not available"""
        return {
            "medications": [],
            "potential_interactions": [],
            "note": "Biomedical model not available - please review clinical notes manually"
        }

    # ------------------------------------------------------------------
    # Single-drug information
    # ------------------------------------------------------------------

    def get_drug_information(self, drug_name):
        """
        Get detailed information about a specific drug

        Args:
            drug_name: Name of the drug
        Returns:
            A dictionary with the keys "drug_name", "drug_class",
            "mechanism", "indications", "side_effects",
            "common_interactions" and "contraindications".  The static
            fallback additionally carries a "note" key.
        """
        if self.model is None or self.tokenizer is None:
            # Fallback behavior if model failed to load
            return self._fallback_drug_information(drug_name)

        try:
            prompt = (
                "\n"
                f"Provide comprehensive information about the medication {drug_name}, including:\n"
                "1. Drug class\n"
                "2. Mechanism of action\n"
                "3. Common indications\n"
                "4. Common side effects\n"
                "5. Common drug interactions\n"
                "6. Contraindications\n"
                "Format the response as JSON with the following structure:\n"
                "{\n"
                f'"drug_name": "{drug_name}",\n'
                '"drug_class": "Drug class",\n'
                '"mechanism": "Mechanism of action",\n'
                '"indications": ["Indication 1", "Indication 2"],\n'
                '"side_effects": ["Side effect 1", "Side effect 2"],\n'
                '"common_interactions": ["Drug 1", "Drug 2"],\n'
                '"contraindications": ["Contraindication 1", "Contraindication 2"]\n'
                "}\n"
            )
            response = self._generate_response(prompt)
            parsed = self._parse_json_object(response, self._DRUG_INFO_KEYS)
            if parsed is not None:
                return parsed
            # No usable JSON: pull the sections out of the free text.
            return self._extract_drug_info_from_text(response, drug_name)
        except Exception as e:
            print(f"Error in LLM inference: {e}")
            return self._fallback_drug_information(drug_name)

    def _extract_list_section(self, text, header_pattern):
        """Return the entries of the section introduced by ``header_pattern``.

        The section body is the run of non-empty lines after the header.
        Bulleted or numbered entries are preferred; when there are none the
        body is split on commas and newlines.  Returns an empty list when
        the header does not occur in ``text``.
        """
        section = re.search(header_pattern + r':?\s+((?:[^\n]+\n?)+)', text)
        if not section:
            return []
        body = section.group(1)
        entries = re.findall(self._LIST_ITEM_PATTERN, body)
        if not entries:
            entries = re.split(r',|\n', body)
        return [entry.strip() for entry in entries if entry.strip()]

    def _extract_drug_info_from_text(self, text, drug_name):
        """Extract drug information from text if JSON parsing fails

        Args:
            text: Free-text answer to scan for "Class", "Mechanism",
                "Indications", "Side effects", "Interactions" and
                "Contraindications" sections.
            drug_name: Name stored under "drug_name" in the result.
        Returns:
            A dictionary with the same keys as ``get_drug_information``;
            sections that are not found keep their defaults
            ("Not specified" or an empty list).
        """
        drug_info = {
            "drug_name": drug_name,
            "drug_class": self._NOT_SPECIFIED,
            "mechanism": self._NOT_SPECIFIED,
            "indications": [],
            "side_effects": [],
            "common_interactions": [],
            "contraindications": []
        }

        # Single-line fields: everything up to the end of the line or the
        # first full stop.
        class_match = re.search(r'[Cc]lass:?\s+([^\n\.]+)', text)
        if class_match:
            drug_info["drug_class"] = class_match.group(1).strip()

        mechanism_match = re.search(r'[Mm]echanism:?\s+([^\n\.]+)', text)
        if mechanism_match:
            drug_info["mechanism"] = mechanism_match.group(1).strip()

        # List fields.  The word boundary keeps "Indications" from matching
        # inside "Contraindications".
        drug_info["indications"] = self._extract_list_section(
            text, r'\b[Ii]ndications?'
        )
        drug_info["side_effects"] = self._extract_list_section(
            text, r'[Ss]ide [Ee]ffects'
        )
        drug_info["common_interactions"] = self._extract_list_section(
            text, r'[Ii]nteractions'
        )
        drug_info["contraindications"] = self._extract_list_section(
            text, r'[Cc]ontraindications'
        )
        return drug_info

    def _fallback_drug_information(self, drug_name):
        """Fallback method for drug information when model is not available"""
        return {
            "drug_name": drug_name,
            "drug_class": "Information not available",
            "mechanism": "Information not available",
            "indications": ["Information not available"],
            "side_effects": ["Information not available"],
            "common_interactions": ["Information not available"],
            "contraindications": ["Information not available"],
            "note": "Biomedical model not available - using fallback information"
        }