thechaiexperiment committed on
Commit
059b54f
·
verified ·
1 Parent(s): 4ebeee2

Create medical_rag.py

Browse files
Files changed (1) hide show
  1. medical_rag.py +159 -0
medical_rag.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List, Dict, Optional

import nltk
from fastapi import HTTPException
from pydantic import BaseModel
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    pipeline
)

# NOTE(review): chat_endpoint uses ChatQuery, translate_text, embed_query_text,
# load_embeddings, query_embeddings and retrieve_document_texts, none of which
# were defined or imported in this module. They are presumably provided by
# general_rag alongside app/models/get_completion -- confirm.
from general_rag import (
    app,
    models,
    data,
    get_completion,
    ChatQuery,
    translate_text,
    embed_query_text,
    load_embeddings,
    query_embeddings,
    retrieve_document_texts,
)
11
+
12
# Initialize NLTK: fetch the sentence-tokenizer data that nltk.sent_tokenize
# needs. Newer NLTK releases load the "punkt_tab" resource instead of the
# pickled "punkt" models, so both are requested to work across versions.
for _nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(_nltk_resource)
14
+
15
class MedicalProfile(BaseModel):
    """Pydantic model with three required fields describing a medical profile.

    The model is declared here but not referenced by any other code in this
    module.
    """

    # NOTE(review): presumably free-text descriptions -- confirm the expected
    # format with whoever consumes this model.
    conditions: str
    daily_symptoms: str
    # NOTE(review): the meaning of `count` is not evident from this module.
    count: int
19
+
20
def load_medical_models():
    """Load the medical NER components into the shared ``models`` dict.

    Stores three entries, all built from the ``blaze999/Medical-NER``
    checkpoint: ``bio_tokenizer``, ``bio_model`` and ``ner_pipeline``.

    Returns:
        True when every component was loaded; False if any step raised.
        The error is printed rather than re-raised.
    """
    checkpoint = "blaze999/Medical-NER"
    try:
        print("Loading medical domain models...")

        # Medical-specific models (only NER, no LLM). Each component is
        # stored as soon as it is loaded, so a later failure leaves the
        # earlier entries in place.
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        models['bio_tokenizer'] = tokenizer

        token_classifier = AutoModelForTokenClassification.from_pretrained(checkpoint)
        models['bio_model'] = token_classifier

        models['ner_pipeline'] = pipeline(
            "ner",
            model=token_classifier,
            tokenizer=tokenizer,
        )

        print("Medical domain models loaded successfully")
        return True
    except Exception as err:
        print(f"Error loading medical models: {err}")
        return False
34
+
35
def extract_entities(text):
    """Return the distinct entity words the medical NER pipeline finds in ``text``.

    Only results whose ``entity`` tag starts with ``B-`` are kept. Duplicates
    are removed while preserving first-occurrence order, so the returned list
    is reproducible between runs (iterating a ``set`` of strings depends on
    per-process hash randomization).

    Args:
        text: Text passed to ``models['ner_pipeline']``.

    Returns:
        A list of unique entity words. An empty list is returned when the
        pipeline is missing or raises; the error is printed.
    """
    try:
        ner_pipeline = models['ner_pipeline']
        ner_results = ner_pipeline(text)
        begin_words = (
            result['word']
            for result in ner_results
            if result['entity'].startswith("B-")
        )
        # dict.fromkeys de-duplicates while keeping insertion order.
        return list(dict.fromkeys(begin_words))
    except Exception as e:
        print(f"Error extracting entities: {e}")
        return []
44
+
45
def match_entities(query_entities, sentence_entities):
    """Count the distinct entities that appear in both collections.

    Args:
        query_entities: Iterable of hashable entities taken from the query.
        sentence_entities: Iterable of hashable entities taken from a sentence.

    Returns:
        The number of distinct shared entities, or 0 if the inputs cannot be
        turned into sets (the error is printed).
    """
    try:
        shared = set(query_entities) & set(sentence_entities)
    except Exception as err:
        print(f"Error matching entities: {err}")
        return 0
    return len(shared)
53
+
54
def extract_relevant_portions(document_texts, query, max_portions=3, portion_size=1, min_query_words=2):
    """Pick the passages of each document that share entities with the query.

    Each document is split into sentences. A sentence is selected when at
    least ``min_query_words`` of its entities also occur in the query; the
    selected portion is a window of roughly ``portion_size`` sentences
    centred on it. If no sentence qualifies but the document as a whole
    contains entities, the ``max_portions`` sentences with the most entities
    are used instead.

    Args:
        document_texts: Iterable of document strings.
        query: Query text whose entities drive the matching.
        max_portions: Maximum number of portions kept per document.
        portion_size: Number of sentences per portion window.
        min_query_words: Minimum number of shared entities for a sentence to
            be selected.

    Returns:
        A dict mapping ``"Document_<index>"`` to that document's list of
        selected portions (possibly empty).
    """
    relevant_portions = {}
    query_entities = extract_entities(query)
    print(f"Extracted Query Entities: {query_entities}")
    half_window = portion_size // 2

    for doc_id, doc_text in enumerate(document_texts):
        sentences = nltk.sent_tokenize(doc_text)
        doc_relevant_portions = []
        doc_entities = extract_entities(doc_text)
        print(f"Document {doc_id} Entities: {doc_entities}")

        # Entity count per sentence, remembered so the fallback below can
        # rank sentences without running NER on each of them a second time.
        sentence_entity_counts = []

        for i, sentence in enumerate(sentences):
            sentence_entities = extract_entities(sentence)
            sentence_entity_counts.append(len(sentence_entities))
            relevance_score = match_entities(query_entities, sentence_entities)
            if relevance_score >= min_query_words:
                start_idx = max(0, i - half_window)
                end_idx = min(len(sentences), i + half_window + 1)
                doc_relevant_portions.append(" ".join(sentences[start_idx:end_idx]))
                if len(doc_relevant_portions) >= max_portions:
                    break

        if not doc_relevant_portions and doc_entities:
            print(f"Fallback: Selecting sentences with most entities for Document {doc_id}")
            # The loop above only breaks after appending a portion, so when
            # this branch runs every sentence has a cached count.
            ranked = sorted(
                zip(sentences, sentence_entity_counts),
                key=lambda pair: pair[1],
                reverse=True,
            )
            doc_relevant_portions.extend(
                fallback_sentence for fallback_sentence, _ in ranked[:max_portions]
            )

        relevant_portions[f"Document_{doc_id}"] = doc_relevant_portions
    return relevant_portions
84
+
85
def enhance_passage_with_entities(passage, entities):
    """Append an ``Entities:`` line listing ``entities`` to ``passage``.

    Args:
        passage: Text placed first in the result.
        entities: Iterable of strings, joined with ``", "``.

    Returns:
        ``passage``, a blank line, then ``"Entities: "`` followed by the
        comma-separated entities.
    """
    suffix = "\n\nEntities: " + ', '.join(entities)
    return f"{passage}{suffix}"
87
+
88
def create_medical_prompt(question, passage):
    """Build the LLM prompt that answers ``question`` from ``passage`` only.

    Args:
        question: The user's question, inserted after ``Question:``.
        passage: Supporting text, inserted after ``Passage:``.

    Returns:
        The instruction template with both values substituted. Braces inside
        the substituted values are left untouched, because only the template
        itself is parsed by ``str.format``.
    """
    # NOTE(review): the leading whitespace of the template lines could not be
    # recovered from the diff view this file was taken from; they are written
    # flush-left here -- confirm against the original file.
    template = """
As a medical expert, you are required to answer the following question based only on the provided passage.
Do not include any information not present in the passage.
Your response should directly reflect the content of the passage.
Maintain accuracy and relevance to the provided information.
Provide a medically reliable answer in no more than 250 words.

Passage: {passage}

Question: {question}

Answer:
"""
    return template.format(question=question, passage=passage)
103
+
104
@app.post("/api/chat")
async def chat_endpoint(chat_query: ChatQuery):
    """Answer a medical question using retrieval-augmented generation.

    Steps: translate the query to English when needed, retrieve candidate
    documents by embedding similarity, rerank them with the cross-encoder,
    extract the entity-matching portions, build the medical prompt, and
    return the completion (translated back when needed).

    Args:
        chat_query: Request body; its ``query`` and ``language_code``
            attributes are read. A ``language_code`` of 0 triggers the
            'ar_to_en' / 'en_to_ar' translation steps.

    Returns:
        A dict with ``response`` (the answer text with a fixed prefix) and
        ``success`` set to True.

    Raises:
        HTTPException: Re-raised unchanged if raised inside the pipeline;
            any other exception is wrapped in a status-500 HTTPException.
    """
    try:
        query_text = chat_query.query
        language_code = chat_query.language_code
        if language_code == 0:
            query_text = translate_text(query_text, 'ar_to_en')

        # Generate embeddings and retrieve relevant documents
        query_embedding = embed_query_text(query_text)
        n_results = 5
        embeddings_data = load_embeddings()
        folder_path = 'downloaded_articles/downloaded_articles'
        initial_results = query_embeddings(query_embedding, embeddings_data, n_results)
        document_ids = [doc_id for doc_id, _ in initial_results]
        document_texts = retrieve_document_texts(document_ids, folder_path)

        # Rerank documents with cross-encoder. Skipped when nothing was
        # retrieved, so the model is never asked to score an empty batch.
        reranked_texts = []
        if len(document_texts) > 0:
            cross_encoder = models['cross_encoder']
            scores = cross_encoder.predict([(query_text, doc) for doc in document_texts])
            ranked = sorted(
                zip(scores, document_texts),
                key=lambda pair: pair[0],
                reverse=True,
            )
            reranked_texts = [doc for _, doc in ranked]

        # Extract relevant portions using the medical-specific function.
        # The reranked order is passed on so the best-scoring documents
        # contribute the first portions of the passage; previously the
        # ranking was computed and then ignored.
        relevant_portions = extract_relevant_portions(reranked_texts, query_text)
        flattened_relevant_portions = []
        for portions in relevant_portions.values():
            flattened_relevant_portions.extend(portions)

        combined_parts = " ".join(flattened_relevant_portions)
        entities = extract_entities(query_text)
        passage = enhance_passage_with_entities(combined_parts, entities)

        # Create medical-specific prompt and get completion from DeepSeek
        prompt = create_medical_prompt(query_text, passage)
        answer = get_completion(prompt)

        # Treat a missing completion like an empty one instead of crashing
        # on None.strip().
        final_answer = (answer or "").strip()
        # Only translate when there is something to translate.
        if language_code == 0 and final_answer:
            final_answer = translate_text(final_answer, 'en_to_ar')

        if not final_answer:
            final_answer = "Sorry, I can't help with that."

        return {
            "response": f"I hope this answers your question: {final_answer}",
            "success": True
        }

    except HTTPException:
        # Bare raise keeps the original traceback.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
157
+
158
# Initialize medical models when this module is imported.
# The boolean result is ignored here: if loading fails, load_medical_models
# prints the error and returns False, and the module still finishes importing.
load_medical_models()