"""Retrieval-augmented generation (RAG) pipeline over clinical guideline PDFs.

Retrieves relevant guideline chunks from a FAISS index and generates answers
with a seq2seq language model (FLAN-T5 by default).
"""

import json
import os
from typing import Dict, List

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from utils import DocumentProcessor

class HealthcareRAG:
    def __init__(self, 
                 model_name: str = "google/flan-t5-base",
                 index_path: str = "faiss_index.bin",
                 pdf_path: str = "clinical_guidelines.pdf"):
        
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)
        
        self.doc_processor = DocumentProcessor()
        
        # Initialize or load the FAISS index. The chunks are persisted next to
        # the index so retrieval still works after a restart; without them a
        # loaded index has nothing to return.
        chunks_path = os.path.splitext(index_path)[0] + "_chunks.json"
        if os.path.exists(index_path) and os.path.exists(chunks_path):
            self.index = self.doc_processor.load_index(index_path)
            with open(chunks_path, "r", encoding="utf-8") as f:
                self.chunks = json.load(f)
        else:
            self.chunks, self.index = self.doc_processor.process_document(pdf_path, index_path)
            with open(chunks_path, "w", encoding="utf-8") as f:
                json.dump(self.chunks, f)
    
    def generate_response(self, query: str, retrieved_chunks: List[str]) -> str:
        """Generate response using the LLM."""
        if not retrieved_chunks:
            return "This information is not available in the current guidelines."
        
        # Prepare context and prompt. Built without indentation so the model
        # does not see stray leading whitespace from the source file.
        context = "\n".join(retrieved_chunks)
        prompt = (
            "Based on the following clinical guidelines, answer the question. "
            "If the information is not explicitly stated in the guidelines, respond with "
            '"This information is not available in the current guidelines."\n\n'
            f"Guidelines:\n{context}\n\n"
            f"Question: {query}\n\n"
            "Answer:"
        )
        
        # Generate the answer (beam search with sampling; for a seq2seq model,
        # max_length here caps the generated answer, not the prompt).
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).to(self.device)
        outputs = self.model.generate(
            **inputs,
            max_length=200,
            num_beams=4,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        
        # Add disclaimer
        if response and response != "This information is not available in the current guidelines.":
            response += "\n\nDISCLAIMER: This information is for educational purposes only and not medical advice."
        
        return response
    
    def query(self, user_query: str) -> Dict[str, str]:
        """Process user query and return response with retrieved chunks."""
        # Retrieve relevant chunks
        retrieved_chunks = self.doc_processor.retrieve_chunks(
            user_query, 
            self.index, 
            self.chunks
        )
        
        # Generate response
        response = self.generate_response(user_query, retrieved_chunks)
        
        return {
            "response": response,
            "retrieved_chunks": retrieved_chunks
        }
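

# A minimal usage sketch, assuming "clinical_guidelines.pdf" is present and
# utils.DocumentProcessor exposes process_document / load_index / retrieve_chunks
# with the signatures used above.
if __name__ == "__main__":
    rag = HealthcareRAG()
    result = rag.query("What is the recommended first-line treatment for stage 1 hypertension?")
    print("Answer:\n", result["response"])
    print("\nRetrieved chunks:")
    for i, chunk in enumerate(result["retrieved_chunks"], 1):
        print(f"[{i}] {chunk[:200]}")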