from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
from utils import DocumentProcessor
import os
from typing import List, Dict


class HealthcareRAG:
    def __init__(self,
                 model_name: str = "google/flan-t5-base",
                 index_path: str = "faiss_index.bin",
                 pdf_path: str = "clinical_guidelines.pdf"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        self.doc_processor = DocumentProcessor()

        # Initialize or load FAISS index
        if os.path.exists(index_path):
            self.index = self.doc_processor.load_index(index_path)
            # Load chunks from a saved file (you might want to save/load chunks separately)
            self.chunks = []  # Load your chunks here
        else:
            self.chunks, self.index = self.doc_processor.process_document(pdf_path, index_path)

    def generate_response(self, query: str, retrieved_chunks: List[str]) -> str:
        """Generate response using the LLM."""
        if not retrieved_chunks:
            return "This information is not available in the current guidelines."

        # Prepare context and prompt
        context = "\n".join(retrieved_chunks)
        prompt = f"""Based on the following clinical guidelines, answer the question.
If the information is not explicitly stated in the guidelines, respond with
"This information is not available in the current guidelines."

Guidelines:
{context}

Question: {query}

Answer:"""

        # Generate response
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=1024,
                                truncation=True).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=200,
            num_beams=4,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Add disclaimer
        if response and response != "This information is not available in the current guidelines.":
            response += "\n\nDISCLAIMER: This information is for educational purposes only and not medical advice."

        return response

    def query(self, user_query: str) -> Dict[str, str]:
        """Process user query and return response with retrieved chunks."""
        # Retrieve relevant chunks
        retrieved_chunks = self.doc_processor.retrieve_chunks(
            user_query, self.index, self.chunks
        )

        # Generate response
        response = self.generate_response(user_query, retrieved_chunks)

        return {
            "response": response,
            "retrieved_chunks": retrieved_chunks
        }
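

# --- Example usage (a minimal sketch, not part of the class above) ---
# Assumes clinical_guidelines.pdf exists alongside this file and that
# utils.DocumentProcessor provides load_index, process_document, and
# retrieve_chunks with the signatures used above. The sample question
# is illustrative only.
if __name__ == "__main__":
    rag = HealthcareRAG()
    result = rag.query("What is the recommended first-line treatment for hypertension?")
    print(result["response"])
    # Inspect which guideline chunks grounded the answer
    for i, chunk in enumerate(result["retrieved_chunks"], start=1):
        print(f"--- Chunk {i} ---\n{chunk}")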