# MDSS / app.py
from datasets import load_dataset
# Load dataset from Hugging Face
dataset = load_dataset("MedRAG/textbooks")
# Preview dataset
print(dataset)
import pandas as pd
# Convert to Pandas DataFrame
df = pd.DataFrame(dataset["train"])
# Display first rows
print(df.head())
# Check file format
print(df.dtypes)
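# The pipeline below relies on the "content" column of MedRAG/textbooks,
# which holds the raw textbook passages.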
import nltk
# Reset NLTK's resource lookup so freshly downloaded data is found; the
# targeted nltk.download(...) calls below fetch everything this script needs
nltk.data.path.append('/root/nltk_data')  # Add the nltk_data path
nltk.data.clear_cache()  # Clear the data cache
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from nltk.stem import WordNetLemmatizer
# Download necessary NLTK components
nltk.download("stopwords")
nltk.download("punkt")
nltk.download("punkt_tab")  # Needed by the tokenizers on newer NLTK releases
nltk.download("wordnet")
nltk.download("omw-1.4")
# Load stopwords and lemmatizer
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
# Step 1: Preprocessing Function
def preprocess_text(text):
    text = text.lower()  # Convert to lowercase
    # Remove special characters but keep sentence terminators (. ! ?) so that
    # sent_tokenize can still split the cleaned text into sentences later
    text = re.sub(r"[^\w\s.!?]", "", text)
    words = word_tokenize(text)  # Tokenization
    words = [lemmatizer.lemmatize(w) for w in words if w not in stop_words]  # Lemmatization & stopword removal
    return " ".join(words)

# Apply preprocessing before chunking
dataset = dataset.map(lambda row: {"cleaned_content": preprocess_text(row["content"])})
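# Illustrative example of the cleaning step (exact output depends on the
# installed NLTK data):
#   preprocess_text("Fevers are common.")  ->  "fever common ."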
# Step 2: Chunking Function
def chunk_text(text, chunk_size=3):
sentences = sent_tokenize(text) # Split text into sentences
return [" ".join(sentences[i:i+chunk_size]) for i in range(0, len(sentences), chunk_size)]
# Apply chunking on the cleaned text
dataset = dataset.map(lambda row: {"chunks": chunk_text(row["cleaned_content"])})
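# Quick sanity check (illustrative): each chunk should hold up to three
# cleaned sentences
print(dataset["train"][0]["chunks"][:2])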
from sentence_transformers import SentenceTransformer
# Load a compact sentence-embedding model (MiniLM) for fast embedding
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
def generate_embedding(row):
    # Encode every chunk in the row; each chunk gets its own 384-dim vector
    row["embedding"] = embed_model.encode(row["chunks"], convert_to_tensor=False).tolist()
    return row

dataset = dataset.map(generate_embedding)
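# Sanity check (illustrative): all-MiniLM-L6-v2 emits 384-dimensional vectors,
# a size the FAISS index below depends on
print(len(dataset["train"][0]["embedding"][0]))  # Expected: 384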
import numpy as np

# Flatten the per-row chunk embeddings into one (num_chunks, 384) matrix.
# Keep the chunk texts in the same order, so FAISS row IDs line up with texts.
all_chunks = []
all_embeddings = []
for row in dataset["train"]:
    for chunk, emb in zip(row["chunks"], row["embedding"]):
        all_chunks.append(chunk)
        all_embeddings.append(emb)

# Convert to NumPy array
embeddings_np = np.array(all_embeddings, dtype=np.float32)

# Check shape
print("✅ Embeddings Shape:", embeddings_np.shape)  # Expected: (num_chunks, 384)
import faiss
# Check if embeddings are 2D
if len(embeddings_np.shape) == 1:
embeddings_np = embeddings_np.reshape(1, -1) # Ensure it's (num_samples, embedding_dim)
# Check final shape
print("Fixed Embeddings Shape:", embeddings_np.shape)
# Create FAISS index
index = faiss.IndexFlatL2(embeddings_np.shape[1])
index.add(embeddings_np) # Add all embeddings
print("✅ Embeddings successfully stored in FAISS!")
print("Total embeddings in FAISS:", index.ntotal)
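# Optional sanity check (illustrative): a stored vector should normally be its
# own nearest neighbor, with distance close to 0
_D, _I = index.search(embeddings_np[:1], 1)
print("Self-lookup hit:", _I[0][0], "distance:", _D[0][0])

# Note: IndexFlatL2 ranks by Euclidean distance. For cosine similarity instead
# (not used here), L2-normalize with faiss.normalize_L2(embeddings_np) and
# build a faiss.IndexFlatIP index.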
FAISS_INDEX_PATH = "/content/faiss_medical.index"  # Save in Colab's file system

# Save the FAISS index
faiss.write_index(index, FAISS_INDEX_PATH)
print(f"✅ FAISS index successfully saved at: {FAISS_INDEX_PATH}")

# Load FAISS index from file
index = faiss.read_index(FAISS_INDEX_PATH)
print(f"✅ FAISS index loaded from: {FAISS_INDEX_PATH}")
print(f"Total embeddings stored: {index.ntotal}")
print("🔍 Dataset structure:", dataset)
print("🔍 Available columns in train:", dataset["train"].column_names)  # Should include "chunks"
print("✅ First 3 chunked texts:", all_chunks[:3])

# Use the flattened chunk list built above, so each FAISS row ID maps
# one-to-one onto the exact text that was embedded
medical_texts = all_chunks
import json
id_to_text = {idx: text for idx, text in enumerate(medical_texts)}
with open("id_to_text.json", "w") as f:
json.dump(id_to_text, f)
import os

# ✅ Check if file exists
if os.path.exists("id_to_text.json"):
    print("✅ `id_to_text.json` exists!")

    # ✅ Load the JSON file
    with open("id_to_text.json", "r") as f:
        id_to_text = json.load(f)

    # ✅ Compare number of records
    print(f"📊 Records in `id_to_text.json`: {len(id_to_text)}")
    print(f"📊 Records in `medical_texts`: {len(medical_texts)}")

    if len(id_to_text) == len(medical_texts):
        print("✅ JSON file contains the correct number of records!")
    else:
        print("❌ Mismatch! FAISS ID mapping and dataset size are different.")
else:
    print("❌ `id_to_text.json` was not found! Make sure it was saved correctly.")
import random

# ✅ Pick 3 random FAISS IDs
sample_ids = random.sample(list(id_to_text.keys()), 3)

# ✅ Print their corresponding texts
for faiss_id in sample_ids:
    print(f"FAISS ID {faiss_id} → Text: {id_to_text[faiss_id][:100]}...")  # Show only first 100 chars
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
# ✅ Load FAISS
FAISS_INDEX_PATH = "/content/faiss_medical.index"
index = faiss.read_index(FAISS_INDEX_PATH)

# ✅ Load Sentence Transformer model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# ✅ Test a retrieval query
query = "What are the symptoms of pneumonia?"
query_embedding = embed_model.encode([query])

# ✅ Perform FAISS search
D, I = index.search(np.array(query_embedding).astype("float32"), 3)  # Retrieve top 3 matches

# ✅ Print the FAISS results & compare with JSON mapping
print("🔍 FAISS Search Results:", I[0])
print("📏 FAISS Distances:", D[0])

# ✅ Load `id_to_text.json`
with open("id_to_text.json", "r") as f:
    id_to_text = json.load(f)
id_to_text = {int(k): v for k, v in id_to_text.items()}  # Ensure keys are integers

# ✅ Print the matching texts
for faiss_id in I[0]:
    print(f"FAISS ID {faiss_id} → Text: {id_to_text[faiss_id][:100]}...")  # Show first 100 characters
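# With IndexFlatL2, smaller distances mean closer matches, so I[0] is already
# ordered from best to worst hit.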
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import json
# ✅ Load FAISS index
FAISS_INDEX_PATH = "/content/faiss_medical.index"
index = faiss.read_index(FAISS_INDEX_PATH)

# ✅ Load embedding model
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# ✅ Load FAISS ID → Text Mapping
with open("id_to_text.json", "r") as f:
    id_to_text = json.load(f)

# ✅ Convert JSON keys to integers (FAISS returns int IDs)
id_to_text = {int(k): v for k, v in id_to_text.items()}
def retrieve_medical_summary(query, k=3):
"""
Retrieve the most relevant medical literature from FAISS.
Args:
query (str): The medical question.
k (int, optional): Number of closest documents to retrieve. Defaults to 3.
Returns:
str: The most relevant retrieved medical documents.
"""
# Convert query to embedding
query_embedding = embed_model.encode([query])
# Perform FAISS search
D, I = index.search(np.array(query_embedding).astype("float32"), k)
# Retrieve the closest matching text using FAISS index IDs
retrieved_docs = [id_to_text.get(int(idx), "No relevant data found.") for idx in I[0]]
    # ✅ Ensure all retrieved texts are strings (flatten lists if needed)
    retrieved_docs = [doc if isinstance(doc, str) else " ".join(doc) for doc in retrieved_docs]

    # ✅ Join multiple retrieved documents into one response
    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."
# ✅ Example Test
query = "What are the symptoms of pneumonia?"
retrieved_summary = retrieve_medical_summary(query, k=3)
print("📖 Retrieved Medical Summary:\n", retrieved_summary)
import os
from groq import Groq
# ✅ Read the API key from the environment; never hard-code secrets in source.
# Set GROQ_API_KEY before launching (e.g. as a Hugging Face Space secret).
client = Groq(api_key=os.environ["GROQ_API_KEY"])
def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tokens=500, temperature=0.3):
    """
    Generates a medical response using Groq's API with LLaMA 3.3-70B, after retrieving relevant literature from FAISS.

    Args:
        query (str): The patient's medical question.
        model (str, optional): The model to use. Defaults to "llama-3.3-70b-versatile".
        max_tokens (int, optional): Max number of tokens to generate. Defaults to 500.
        temperature (float, optional): Sampling temperature (higher = more creative). Defaults to 0.3.

    Returns:
        str: The AI-generated medical advice.
    """
    # ✅ Retrieve relevant medical literature from FAISS
    retrieved_summary = retrieve_medical_summary(query)

    print("\n🔍 Retrieved Medical Text for Query:", query)
    print(retrieved_summary, "\n")

    if not retrieved_summary or retrieved_summary == "No relevant data found.":
        return "No relevant medical data found. Please consult a healthcare professional."

    try:
        # ✅ Send request to Groq API
response = client.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
{"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"}
],
max_tokens=max_tokens,
temperature=temperature
)
return response.choices[0].message.content.strip() # Ensure clean output
except Exception as e:
return f"Error generating response: {str(e)}"
# ✅ Example Usage
query = "What are the symptoms of pneumonia?"
print("🩺 AI-Generated Response:", generate_medical_answer_groq(query))
import gradio as gr

# Gradio Interface
def ask_medical_question(question):
    return generate_medical_answer_groq(question)
# Create Gradio Interface
iface = gr.Interface(
fn=ask_medical_question,
inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
title="Medical Question Answering System",
description="Ask any medical question, and the AI will provide an answer based on medical literature."
)
# Launch the Gradio app
iface.launch()
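# When hosted as a Hugging Face Space, launch() is all that is needed; for
# local testing, launch(share=True) gives a temporary public link.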