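"""Retrieval-augmented generation (RAG) pipeline for clinical-guideline Q&A.

Chunks a guideline PDF into a FAISS index via utils.DocumentProcessor,
retrieves the chunks relevant to a question, and answers with a FLAN-T5
model, declining when the answer is not in the retrieved context.
"""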
import os
import pickle
from typing import Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import DocumentProcessor

class HealthcareRAG:
    def __init__(self,
                 model_name: str = "google/flan-t5-base",
                 index_path: str = "faiss_index.bin",
                 pdf_path: str = "clinical_guidelines.pdf"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        self.doc_processor = DocumentProcessor()

        # Initialize or load the FAISS index. The chunk texts must be persisted
        # alongside the index, otherwise retrieval has nothing to return after
        # a reload; here we pickle them next to the index file (the
        # ".chunks.pkl" suffix is a convention chosen here, not something
        # DocumentProcessor requires).
        chunks_path = index_path + ".chunks.pkl"
        if os.path.exists(index_path) and os.path.exists(chunks_path):
            self.index = self.doc_processor.load_index(index_path)
            with open(chunks_path, "rb") as f:
                self.chunks = pickle.load(f)
        else:
            self.chunks, self.index = self.doc_processor.process_document(pdf_path, index_path)
            with open(chunks_path, "wb") as f:
                pickle.dump(self.chunks, f)

    def generate_response(self, query: str, retrieved_chunks: List[str]) -> str:
        """Generate a response from the LLM, grounded in the retrieved chunks."""
        if not retrieved_chunks:
            return "This information is not available in the current guidelines."

        # Prepare context and prompt
        context = "\n".join(retrieved_chunks)
        prompt = f"""Based on the following clinical guidelines, answer the question.
If the information is not explicitly stated in the guidelines, respond with "This information is not available in the current guidelines."

Guidelines:
{context}

Question: {query}

Answer:"""

        # Generate the answer (max_new_tokens caps the generated answer length)
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=200,
            num_beams=4,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Append a disclaimer to any substantive answer
        if response and response != "This information is not available in the current guidelines.":
            response += "\n\nDISCLAIMER: This information is for educational purposes only and not medical advice."
        return response

    def query(self, user_query: str) -> Dict[str, object]:
        """Process a user query and return the response with the retrieved chunks."""
        # Retrieve relevant chunks
        retrieved_chunks = self.doc_processor.retrieve_chunks(
            user_query,
            self.index,
            self.chunks
        )
        # Generate response
        response = self.generate_response(user_query, retrieved_chunks)
        # The dict mixes a string response with a list of chunks, hence the
        # object value type in the annotation.
        return {
            "response": response,
            "retrieved_chunks": retrieved_chunks
        }
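

# A minimal usage sketch: assumes "clinical_guidelines.pdf" sits next to this
# file (the default pdf_path above); the sample question is illustrative only.
if __name__ == "__main__":
    rag = HealthcareRAG()
    result = rag.query("What is the recommended first-line treatment for hypertension?")
    print(result["response"])
    print("\nRetrieved context:")
    for i, chunk in enumerate(result["retrieved_chunks"], 1):
        print(f"[{i}] {chunk[:120]}...")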