khalil2233 committed on
Commit 8d72b01 · verified · 1 Parent(s): a87a80c
Files changed (1)
  1. app.py +95 -291
app.py CHANGED
@@ -1,43 +1,20 @@
  from datasets import load_dataset
-
- # Load dataset from Hugging Face
- dataset = load_dataset("MedRAG/textbooks")
-
- # Preview dataset
- print(dataset)
-
- import pandas as pd
-
- # Convert to Pandas DataFrame
- df = pd.DataFrame(dataset["train"])
-
- # Display first rows
- print(df.head())
-
- # Check file format
- print(df.dtypes)
-
  import nltk
- import shutil
-
- # Remove existing resources
- nltk.data.path.append('/root/nltk_data')  # Add the nltk_data path
- nltk.data.clear_cache()  # Clear the data cache
-
-
- # Reinstall the 'punkt' package
- nltk.download('all')
-
-
  import re
- import nltk
  from nltk.corpus import stopwords
  from nltk.tokenize import word_tokenize, sent_tokenize
  from nltk.stem import WordNetLemmatizer

- # Download necessary NLTK components
- nltk.download("stopwords")
  nltk.download("punkt")
  nltk.download("wordnet")
  nltk.download("omw-1.4")

@@ -45,294 +22,121 @@ nltk.download("omw-1.4")
  stop_words = set(stopwords.words("english"))
  lemmatizer = WordNetLemmatizer()

- # Step 1: Preprocessing Function
  def preprocess_text(text):
      text = text.lower()  # Convert to lowercase
      text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
      words = word_tokenize(text)  # Tokenization
      words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
      return " ".join(words)

- # Apply preprocessing before chunking
- dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
-
- # Step 2: Chunking Function
  def chunk_text(text, chunk_size=3):
      sentences = sent_tokenize(text)  # Split text into sentences
-     return [" ".join(sentences[i:i+chunk_size]) for i in range(0, len(sentences), chunk_size)]
-
- # Apply chunking on the cleaned text
- dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})
-
- from sentence_transformers import SentenceTransformer
-
- # Load BioBERT or MiniLM for fast embedding
- embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-
- def generate_embedding(row):
-     embedding = embed_model.encode(row["chunks"], convert_to_tensor=False).tolist()

-     # Fix: Ensure embedding is a flat list, not nested
-     row["embedding"] = embedding[0] if isinstance(embedding, list) and len(embedding) == 1 else embedding
-     return row
-
- dataset = dataset.map(generate_embedding)
-
- import numpy as np
-
- # Flatten embeddings (convert [[...]] → [...])
- valid_embeddings = [
-     np.array(row["embedding"]).flatten().tolist()  # Ensure each embedding is 1D
-     for row in dataset["train"]
-     if isinstance(row["embedding"], list) and len(row["embedding"]) == 384
- ]
-
- # Convert to NumPy array
- embeddings_np = np.array(valid_embeddings, dtype=np.float32)
-
- # Check shape
- print("✅ Fixed Embeddings Shape:", embeddings_np.shape)  # Expected: (num_samples, 384)
-
- import numpy as np
-
- # Flatten embeddings (convert [[...]] → [...])
- valid_embeddings = [
-     np.array(row["embedding"]).flatten().tolist()  # Ensure each embedding is 1D
-     for row in dataset["train"]
-     if isinstance(row["embedding"], list) and len(row["embedding"]) == 384
- ]
-
- # Convert to NumPy array
- embeddings_np = np.array(valid_embeddings, dtype=np.float32)
-
- # Check shape
- print("✅ Fixed Embeddings Shape:", embeddings_np.shape)  # Expected: (num_samples, 384)
-
- import faiss
-
- # Check if embeddings are 2D
- if len(embeddings_np.shape) == 1:
-     embeddings_np = embeddings_np.reshape(1, -1)  # Ensure it's (num_samples, embedding_dim)
-
- # Check final shape
- print("Fixed Embeddings Shape:", embeddings_np.shape)

  # Create FAISS index
- index = faiss.IndexFlatL2(embeddings_np.shape[1])
- index.add(embeddings_np)  # Add all embeddings
-
- print("✅ Embeddings successfully stored in FAISS!")
- print("Total embeddings in FAISS:", index.ntotal)
-
- FAISS_INDEX_PATH = "/content/faiss_medical.index"  # Save in Colab's file system
-
- # Save the FAISS index
- faiss.write_index(index, FAISS_INDEX_PATH)
-
- print(f"✅ FAISS index successfully saved at: {FAISS_INDEX_PATH}")
-
- # Load FAISS index from file
- index = faiss.read_index(FAISS_INDEX_PATH)
-
- print(f"✅ FAISS index loaded from: {FAISS_INDEX_PATH}")
- print(f"Total embeddings stored: {index.ntotal}")
-
- print("🔍 Available columns:", dataset.column_names)  # Should include "chunks"
-
- medical_texts = dataset["train"]["chunks"]  # ✅ Correct way to access chunks
- # Use the same text that will be encoded
-
- print("🔍 Dataset structure:", dataset)
- print("🔍 Available columns in train:", dataset["train"].column_names)
- print("✅ First 3 chunked texts:", dataset["train"]["chunks"][:3])
-
- import json
- id_to_text = {idx: text for idx, text in enumerate(medical_texts)}
-
- with open("id_to_text.json", "w") as f:
-     json.dump(id_to_text, f)
-
- import os
-
- # ✅ Check if file exists
- if os.path.exists("id_to_text.json"):
-     print("✅ `id_to_text.json` exists!")
-
-     # ✅ Load the JSON file
-     with open("id_to_text.json", "r") as f:
-         id_to_text = json.load(f)
-
-     # ✅ Compare number of records
-     print(f"📊 Records in `id_to_text.json`: {len(id_to_text)}")
-     print(f"📊 Records in `medical_texts`: {len(medical_texts)}")
-
-     if len(id_to_text) == len(medical_texts):
-         print("✅ JSON file contains the correct number of records!")
-     else:
-         print("❌ Mismatch! FAISS ID mapping and dataset size are different.")
-
- else:
-     print("❌ `id_to_text.json` was not found! Make sure it was saved correctly.")
-
- import random
-
- # ✅ Pick 3 random FAISS IDs
- sample_ids = random.sample(list(id_to_text.keys()), 3)
-
- # ✅ Print their corresponding texts
- for faiss_id in sample_ids:
-     print(f"FAISS ID {faiss_id} → Text: {id_to_text[faiss_id][:100]}...")  # Show only first 100 chars
-
- import faiss
- import numpy as np
- from sentence_transformers import SentenceTransformer
-
- # ✅ Load FAISS
- FAISS_INDEX_PATH = "/content/faiss_medical.index"
- index = faiss.read_index(FAISS_INDEX_PATH)
-
- # ✅ Load Sentence Transformer model
- embed_model = SentenceTransformer("all-MiniLM-L6-v2")
-
- # ✅ Test a retrieval query
- query = "What are the symptoms of pneumonia?"
- query_embedding = embed_model.encode([query])
-
- # ✅ Perform FAISS search
- D, I = index.search(np.array(query_embedding).astype("float32"), 3)  # Retrieve top 3 matches
-
- # ✅ Print the FAISS results & compare with JSON mapping
- print("🔍 FAISS Search Results:", I[0])
- print("📏 FAISS Distances:", D[0])
-
- # ✅ Load `id_to_text.json`
- with open("id_to_text.json", "r") as f:
-     id_to_text = json.load(f)
-
- id_to_text = {int(k): v for k, v in id_to_text.items()}  # Ensure keys are integers
-
- # ✅ Print the matching texts
- for faiss_id in I[0]:
-     print(f"FAISS ID {faiss_id} → Text: {id_to_text[faiss_id][:100]}...")  # Show first 100 characters
-
- import faiss
- import numpy as np
- from sentence_transformers import SentenceTransformer
- import json
-
- # ✅ Load FAISS index
- FAISS_INDEX_PATH = "/content/faiss_medical.index"
- index = faiss.read_index(FAISS_INDEX_PATH)
-
- # ✅ Load embedding model
- embed_model = SentenceTransformer("all-MiniLM-L6-v2")
-
- # ✅ Load FAISS ID → Text Mapping
- with open("id_to_text.json", "r") as f:
-     id_to_text = json.load(f)
-
- # ✅ Convert JSON keys to integers (FAISS returns int IDs)
- id_to_text = {int(k): v for k, v in id_to_text.items()}
-
- def retrieve_medical_summary(query, k=3):
-     """
-     Retrieve the most relevant medical literature from FAISS.
-
-     Args:
-         query (str): The medical question.
-         k (int, optional): Number of closest documents to retrieve. Defaults to 3.
-
-     Returns:
-         str: The most relevant retrieved medical documents.
-     """
-     # Convert query to embedding
      query_embedding = embed_model.encode([query])
-
-     # Perform FAISS search
      D, I = index.search(np.array(query_embedding).astype("float32"), k)
-
-     # Retrieve the closest matching text using FAISS index IDs
      retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
-
-     # ✅ Ensure all retrieved texts are strings (Flatten lists if needed)
      retrieved_docs = [doc if isinstance(doc, str) else " ".join(doc) for doc in retrieved_docs]
-
-     # ✅ Join multiple retrieved documents into one response
      return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."

-
- # Example Test
- query = "What are the symptoms of pneumonia?"
- retrieved_summary = retrieve_medical_summary(query, k=3)
-
- print("📖 Retrieved Medical Summary:\n", retrieved_summary)
-
-
-
- import os
- from groq import Groq
-
- # ✅ Store API Key in Environment Variable
- os.environ["GROQ_API_KEY"] = "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" # Replace with your actual key
-
- # ✅ Initialize Groq client correctly (Retrieve API key properly)
- client = Groq(api_key=os.getenv("GROQ_API_KEY"))
-
- def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tokens=500, temperature=0.3):
-     """
-     Generates a medical response using Groq's API with LLaMA 3.3-70B, after retrieving relevant literature from FAISS.
-
-     Args:
-         query (str): The patient's medical question.
-         model (str, optional): The model to use. Defaults to "llama-3.3-70b-versatile".
-         max_tokens (int, optional): Max number of tokens to generate. Defaults to 200.
-         temperature (float, optional): Sampling temperature (higher = more creative). Defaults to 0.7.
-
-     Returns:
-         str: The AI-generated medical advice.
-     """
-
-     # ✅ Retrieve relevant medical literature from FAISS
-     retrieved_summary = retrieve_medical_summary(query)
-     print("\n🔍 Retrieved Medical Text for Query:", query)
-     print(retrieved_summary, "\n")
-
      if not retrieved_summary or retrieved_summary == "No relevant data found.":
          return "No relevant medical data found. Please consult a healthcare professional."

      try:
-         # ✅ Send request to Groq API
          response = client.chat.completions.create(
-             model=model,
              messages=[
                  {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                  {"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"}
              ],
-             max_tokens=max_tokens,
-             temperature=temperature
          )
-
-         return response.choices[0].message.content.strip()  # Ensure clean output
-
      except Exception as e:
          return f"Error generating response: {str(e)}"

- # Example Usage
- query = "What are the symptoms of pneumonia?"
- print("🩺 AI-Generated Response:", generate_medical_answer_groq(query))
-
- # Gradio Interface
  def ask_medical_question(question):
-     return generate_medical_answer_groq(question)
-
- # Create Gradio Interface
- iface = gr.Interface(
-     fn=ask_medical_question,
-     inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
-     outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
-     title="Medical Question Answering System",
-     description="Ask any medical question, and the AI will provide an answer based on medical literature."
- )
-
- # Launch the Gradio app
- iface.launch()

+ import os
+ import json
+ import numpy as np
+ import faiss
+ import gradio as gr
  from datasets import load_dataset
+ from sentence_transformers import SentenceTransformer
+ from groq import Groq
  import nltk
  import re
  from nltk.corpus import stopwords
  from nltk.tokenize import word_tokenize, sent_tokenize
  from nltk.stem import WordNetLemmatizer

+ # Initialize NLTK
  nltk.download("punkt")
+ nltk.download("stopwords")
  nltk.download("wordnet")
  nltk.download("omw-1.4")

  stop_words = set(stopwords.words("english"))
  lemmatizer = WordNetLemmatizer()

+ # Load dataset
+ def load_and_preprocess_dataset():
+     """Load and preprocess the dataset."""
+     dataset = load_dataset("MedRAG/textbooks")
+     print("Dataset loaded successfully.")
+     return dataset
+
+ # Preprocessing function
  def preprocess_text(text):
+     """Preprocess text by lowercasing, removing special characters, and lemmatizing."""
      text = text.lower()  # Convert to lowercase
      text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
      words = word_tokenize(text)  # Tokenization
      words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
      return " ".join(words)

+ # Chunking function
  def chunk_text(text, chunk_size=3):
+     """Split text into chunks of sentences."""
      sentences = sent_tokenize(text)  # Split text into sentences
+     return [" ".join(sentences[i:i + chunk_size]) for i in range(0, len(sentences), chunk_size)]

+ # Generate embeddings
+ def generate_embeddings(dataset):
+     """Generate embeddings for the dataset."""
+     embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+     dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
+     dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})
+     dataset = dataset.map(lambda row: {"embedding": embed_model.encode(row["chunks"], convert_to_tensor=False).tolist()})
+     return dataset

  # Create FAISS index
+ def create_faiss_index(dataset):
+     """Create and save a FAISS index for the embeddings."""
+     embeddings_np = np.array([np.array(row["embedding"]).flatten().tolist() for row in dataset["train"]], dtype=np.float32)
+     index = faiss.IndexFlatL2(embeddings_np.shape[1])
+     index.add(embeddings_np)
+     faiss.write_index(index, "faiss_medical.index")
+     print("FAISS index created and saved.")
+
+ # Load FAISS index
+ def load_faiss_index():
+     """Load the FAISS index."""
+     index = faiss.read_index("faiss_medical.index")
+     print("FAISS index loaded.")
+     return index
+
+ # Retrieve medical summary
+ def retrieve_medical_summary(query, index, id_to_text, k=3):
+     """Retrieve the most relevant medical literature from FAISS."""
+     embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
      query_embedding = embed_model.encode([query])
      D, I = index.search(np.array(query_embedding).astype("float32"), k)
      retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
      retrieved_docs = [doc if isinstance(doc, str) else " ".join(doc) for doc in retrieved_docs]
      return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."

+ # Generate medical answer using Groq
+ def generate_medical_answer_groq(query, index, id_to_text):
+     """Generate a medical response using Groq's API."""
+     retrieved_summary = retrieve_medical_summary(query, index, id_to_text)
      if not retrieved_summary or retrieved_summary == "No relevant data found.":
          return "No relevant medical data found. Please consult a healthcare professional."

+     client = Groq(api_key=os.getenv("GROQ_API_KEY"))
      try:
          response = client.chat.completions.create(
+             model="llama-3.3-70b-versatile",
              messages=[
                  {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                  {"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"}
              ],
+             max_tokens=500,
+             temperature=0.3
          )
+         return response.choices[0].message.content.strip()
      except Exception as e:
          return f"Error generating response: {str(e)}"

+ # Gradio interface
  def ask_medical_question(question):
+     """Gradio interface for asking medical questions."""
+     return generate_medical_answer_groq(question, index, id_to_text)
+
+ # Main function
+ def main():
+     """Main function to set up the system."""
+     global index, id_to_text
+
+     # Load and preprocess dataset
+     dataset = load_and_preprocess_dataset()
+     dataset = generate_embeddings(dataset)
+
+     # Create FAISS index
+     create_faiss_index(dataset)
+
+     # Load FAISS index
+     index = load_faiss_index()
+
+     # Create ID to text mapping
+     medical_texts = dataset["train"]["chunks"]
+     id_to_text = {idx: text for idx, text in enumerate(medical_texts)}
+     with open("id_to_text.json", "w") as f:
+         json.dump(id_to_text, f)
+
+     # Launch Gradio app
+     iface = gr.Interface(
+         fn=ask_medical_question,
+         inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
+         outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
+         title="Medical Question Answering System",
+         description="Ask any medical question, and the AI will provide an answer based on medical literature."
+     )
+     iface.launch()
+
+ # Run the main function
+ if __name__ == "__main__":
+     main()
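Assuming the file above is importable as the `app` module, that `main()` has already been run once (so `faiss_medical.index` and `id_to_text.json` exist on disk), and that `GROQ_API_KEY` is set in the environment, the refactored retrieval and generation functions could be exercised outside the Gradio UI roughly as in the sketch below. This is illustrative only and not part of the commit.

# Sketch: query the committed pipeline directly, reusing the functions defined in app.py.
import json
from app import load_faiss_index, retrieve_medical_summary, generate_medical_answer_groq

index = load_faiss_index()  # reads faiss_medical.index written by main()
with open("id_to_text.json", "r") as f:
    # JSON keys come back as strings, while FAISS returns integer IDs
    id_to_text = {int(k): v for k, v in json.load(f).items()}

query = "What are the symptoms of pneumonia?"
print(retrieve_medical_summary(query, index, id_to_text, k=3))   # top-3 retrieved chunks
print(generate_medical_answer_groq(query, index, id_to_text))    # Groq answer grounded in those chunks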