import json
import os
from typing import Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import DocumentProcessor

class HealthcareRAG:
    """Retrieval-augmented question answering over a clinical guidelines PDF."""

    FALLBACK = "This information is not available in the current guidelines."

    def __init__(self,
                 model_name: str = "google/flan-t5-base",
                 index_path: str = "faiss_index.bin",
                 pdf_path: str = "clinical_guidelines.pdf"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)

        self.doc_processor = DocumentProcessor()

        # Retrieval needs both the FAISS index and the chunk texts its vectors
        # point to. Loading only the index while leaving self.chunks empty would
        # make every query fall through to the fallback answer, so persist the
        # chunks next to the index (the sidecar file name is our own convention).
        chunks_path = index_path + ".chunks.json"
        if os.path.exists(index_path) and os.path.exists(chunks_path):
            self.index = self.doc_processor.load_index(index_path)
            with open(chunks_path, "r", encoding="utf-8") as f:
                self.chunks = json.load(f)
        else:
            self.chunks, self.index = self.doc_processor.process_document(pdf_path, index_path)
            with open(chunks_path, "w", encoding="utf-8") as f:
                json.dump(self.chunks, f)

    def generate_response(self, query: str, retrieved_chunks: List[str]) -> str:
        """Generate a response grounded in the retrieved guideline chunks."""
        if not retrieved_chunks:
            return self.FALLBACK

        context = "\n".join(retrieved_chunks)
        prompt = f"""Based on the following clinical guidelines, answer the question.
If the information is not explicitly stated in the guidelines, respond with "{self.FALLBACK}"

Guidelines:
{context}

Question: {query}

Answer:"""

        # Truncate long contexts to 1024 tokens; T5's relative position
        # embeddings tolerate inputs longer than the 512 seen in training.
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=1024,
                                truncation=True).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_new_tokens=200,  # cap the answer length (clearer than max_length)
            num_beams=4,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,  # beam sampling: stochastic candidates, beam-pruned
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Append the disclaimer only to substantive answers, not the fallback.
        if response and response != self.FALLBACK:
            response += "\n\nDISCLAIMER: This information is for educational purposes only and not medical advice."

        return response

    def query(self, user_query: str) -> Dict[str, object]:
        """Process a user query and return the response with retrieved chunks."""
        retrieved_chunks = self.doc_processor.retrieve_chunks(
            user_query,
            self.index,
            self.chunks,
        )
        response = self.generate_response(user_query, retrieved_chunks)
        # The dict mixes a str and a List[str], hence Dict[str, object].
        return {
            "response": response,
            "retrieved_chunks": retrieved_chunks,
        }
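

# Minimal usage sketch. Assumes clinical_guidelines.pdf sits next to this file
# and that utils.DocumentProcessor behaves as the constructor above expects;
# the question below is illustrative only.
if __name__ == "__main__":
    rag = HealthcareRAG()
    result = rag.query("What is the recommended first-line treatment for hypertension?")
    print(result["response"])
    print("\n--- Retrieved chunks ---")
    for i, chunk in enumerate(result["retrieved_chunks"], start=1):
        print(f"[{i}] {chunk[:120]}")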