# Hugging Face Spaces page status at capture time: "Runtime error"
from datasets import load_dataset
import pandas as pd

# Load dataset from Hugging Face (MedRAG textbook corpus).
dataset = load_dataset("MedRAG/textbooks")
# Preview the splits with their columns and row counts.
print(dataset)

# Convert the train split to a pandas DataFrame for inspection.
# Dataset.to_pandas() converts the underlying Arrow table directly, instead of
# letting pd.DataFrame iterate the Dataset row by row as Python dicts.
df = dataset["train"].to_pandas()
# Display first rows
print(df.head())
# Column dtypes of the DataFrame.
print(df.dtypes)
import nltk
import re
import shutil
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import sent_tokenize, word_tokenize

# Also search /root/nltk_data explicitly, and drop any NLTK resources already
# cached in memory so freshly downloaded data is picked up.
nltk.data.path.append('/root/nltk_data')
nltk.data.clear_cache()

# Download only the resources this script actually uses. nltk.download('all')
# fetches every NLTK corpus and model (a very large download) on each cold start.
nltk.download("stopwords")  # stopwords.words("english")
nltk.download("punkt")      # word_tokenize / sent_tokenize (older NLTK)
nltk.download("punkt_tab")  # word_tokenize / sent_tokenize (newer NLTK, 3.8.2+)
nltk.download("wordnet")    # WordNetLemmatizer
nltk.download("omw-1.4")    # WordNet multilingual data used alongside wordnet

# Build the stop-word set and the lemmatizer once; the text preprocessing
# reuses them for every row.
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
# Step 1: Preprocessing Function
def preprocess_text(text):
    """Normalise a raw passage before chunking and embedding.

    The text is lower-cased, and every character that is neither a word
    character nor whitespace is deleted, so the result carries no punctuation.
    The remaining tokens are filtered against the English stop-word set and
    lemmatised with WordNet.

    Args:
        text (str): Raw passage text.

    Returns:
        str: The kept lemmas joined by single spaces.
    """
    without_punctuation = re.sub(r"[^\w\s]", "", text.lower())
    kept_lemmas = [
        lemmatizer.lemmatize(token)
        for token in word_tokenize(without_punctuation)
        if token not in stop_words
    ]
    return " ".join(kept_lemmas)
# Add a "cleaned_content" column: preprocess_text() applied to each row's "content".
dataset = dataset.map(
    lambda example: {"cleaned_content": preprocess_text(example["content"])}
)
# Step 2: Chunking Function
def chunk_text(text, chunk_size=3):
    """Group the sentences of *text* into chunks.

    Args:
        text (str): Text to split into sentences with NLTK's sent_tokenize.
        chunk_size (int, optional): Maximum number of sentences per chunk.
            Defaults to 3.

    Returns:
        list[str]: Consecutive sentence groups joined by spaces; the last
        group may be shorter. Empty when no sentences are found.
    """
    sentences = sent_tokenize(text)
    chunks = []
    for start in range(0, len(sentences), chunk_size):
        chunks.append(" ".join(sentences[start:start + chunk_size]))
    return chunks
# Add a "chunks" column: the cleaned text of each row split into sentence groups.
# NOTE(review): preprocess_text removes all punctuation, so the sentence
# tokenizer likely finds no sentence boundaries and most rows end up as a single
# chunk. Confirm whether chunking should run on the raw "content" instead.
dataset = dataset.map(
    lambda example: {"chunks": chunk_text(example["cleaned_content"])}
)
import numpy as np
from sentence_transformers import SentenceTransformer

# Load the MiniLM sentence-embedding model (all-MiniLM-L6-v2).
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


def generate_embedding(row):
    """Attach exactly one flat embedding vector to a dataset row.

    Every chunk in ``row["chunks"]`` is encoded, and the per-chunk vectors are
    mean-pooled. A single-chunk row therefore stores that chunk's vector
    unchanged, and a multi-chunk row stores the average instead of a nested
    list. A row without chunks stores a zero vector, so it keeps its position.

    One flat vector per row keeps the "embedding" column type uniform, and it
    lets the row position serve as the FAISS ID later on.

    Args:
        row (dict): Dataset row with a "chunks" list of strings.

    Returns:
        dict: The same row with an "embedding" list of floats added.
    """
    chunk_vectors = np.asarray(
        embed_model.encode(row["chunks"], convert_to_tensor=False),
        dtype=np.float32,
    )
    if chunk_vectors.ndim == 2 and chunk_vectors.shape[0] > 0:
        vector = chunk_vectors.mean(axis=0)
    else:
        # No chunks were encoded (e.g. empty text).
        vector = np.zeros(embed_model.get_sentence_embedding_dimension(), dtype=np.float32)
    row["embedding"] = vector.tolist()
    return row


dataset = dataset.map(generate_embedding)
import numpy as np

# Build the FAISS input matrix with exactly one vector per dataset row, in row
# order. The id_to_text mapping is built by enumerating every row, so FAISS ID i
# must correspond to row i. Dropping rows here would shift every later ID onto
# the wrong text.
EMBEDDING_DIM = 384  # output size of all-MiniLM-L6-v2


def _row_vector(raw_embedding, dim=EMBEDDING_DIM):
    """Collapse one stored row embedding to a single float32 vector, or None.

    A flat list of length ``dim`` is used as-is. A list of per-chunk vectors
    of length ``dim`` is mean-pooled. Anything else (empty or wrong size)
    yields None.
    """
    arr = np.asarray(raw_embedding, dtype=np.float32)
    if arr.ndim == 1 and arr.shape[0] == dim:
        return arr
    if arr.ndim == 2 and arr.shape[0] > 0 and arr.shape[1] == dim:
        return arr.mean(axis=0)
    return None


valid_embeddings = []
missing_rows = 0
for row in dataset["train"]:
    vector = _row_vector(row["embedding"])
    if vector is None:
        # Keep the row's slot with a zero vector instead of skipping it.
        missing_rows += 1
        vector = np.zeros(EMBEDDING_DIM, dtype=np.float32)
    valid_embeddings.append(vector)

if missing_rows:
    print(
        f"Warning: {missing_rows} row(s) had no usable embedding; zero vectors "
        "were stored to keep FAISS IDs aligned with dataset rows."
    )

# Stack into a 2-D float32 matrix (the single copy of this step; it used to be duplicated).
embeddings_np = (
    np.vstack(valid_embeddings)
    if valid_embeddings
    else np.empty((0, EMBEDDING_DIM), dtype=np.float32)
)
# Check shape
print("β Fixed Embeddings Shape:", embeddings_np.shape)  # Expected: (num_samples, 384)
import faiss

# IndexFlatL2 expects a 2-D (num_vectors, dim) matrix.
if len(embeddings_np.shape) == 1:
    embeddings_np = embeddings_np.reshape(1, -1)
# Check final shape
print("Fixed Embeddings Shape:", embeddings_np.shape)
# Fail here with a clear message rather than building a degenerate index that
# only breaks at query time.
if embeddings_np.size == 0:
    raise ValueError("No embeddings to index: the embedding step produced no vectors.")

# Exact (brute-force) L2 index; vectors receive sequential IDs in insertion
# order, starting at 0.
index = faiss.IndexFlatL2(embeddings_np.shape[1])
index.add(embeddings_np)  # Add all embeddings
print("β Embeddings successfully stored in FAISS!")
print("Total embeddings in FAISS:", index.ntotal)
import os

FAISS_INDEX_PATH = "/content/faiss_medical.index"  # Save in Colab's file system
# Create the parent directory first so the save does not fail when it is
# missing (e.g. /content when not running on Colab). This is a no-op when it exists.
os.makedirs(os.path.dirname(FAISS_INDEX_PATH) or ".", exist_ok=True)
# Save the FAISS index
faiss.write_index(index, FAISS_INDEX_PATH)
print(f"β FAISS index successfully saved at: {FAISS_INDEX_PATH}")
# Read the index back from disk to confirm the saved file is loadable.
index = faiss.read_index(FAISS_INDEX_PATH)
print(f"β FAISS index loaded from: {FAISS_INDEX_PATH}")
print(f"Total embeddings stored: {index.ntotal}")
import json

print("π Available columns:", dataset.column_names)  # Should include "chunks"
# One entry per train row: the list of chunk strings that was embedded for it.
medical_texts = dataset["train"]["chunks"]
print("π Dataset structure:", dataset)
print("π Available columns in train:", dataset["train"].column_names)
print("β First 3 chunked texts:", dataset["train"]["chunks"][:3])

# Map each row position to its chunks. This assumes FAISS vector i was built
# from train row i.
id_to_text = dict(enumerate(medical_texts))
with open("id_to_text.json", "w") as mapping_file:
    json.dump(id_to_text, mapping_file)
import os

# Sanity-check the FAISS-ID -> text mapping that was just written to disk.
if not os.path.exists("id_to_text.json"):
    print("β `id_to_text.json` was not found! Make sure it was saved correctly.")
else:
    print("β `id_to_text.json` exists!")
    # Reload the mapping from disk; JSON object keys come back as strings.
    with open("id_to_text.json", "r") as mapping_file:
        id_to_text = json.load(mapping_file)
    mapping_size = len(id_to_text)
    dataset_size = len(medical_texts)
    print(f"π Records in `id_to_text.json`: {mapping_size}")
    print(f"π Records in `medical_texts`: {dataset_size}")
    if mapping_size != dataset_size:
        print("β Mismatch! FAISS ID mapping and dataset size are different.")
    else:
        print("β JSON file contains the correct number of records!")
import random

# Spot-check up to 3 random entries of the FAISS-ID -> text mapping. min()
# avoids the ValueError random.sample raises on smaller mappings.
sample_ids = random.sample(list(id_to_text.keys()), min(3, len(id_to_text)))
for faiss_id in sample_ids:
    entry = id_to_text[faiss_id]
    # Values are lists of chunk strings, so join them before truncating.
    # Slicing the list itself would return up to 100 chunks, not 100 characters.
    snippet = entry if isinstance(entry, str) else " ".join(entry)
    print(f"FAISS ID {faiss_id} β Text: {snippet[:100]}...")  # Show only first 100 chars
import faiss
import json
import numpy as np
from sentence_transformers import SentenceTransformer

# Load the saved FAISS index.
FAISS_INDEX_PATH = "/content/faiss_medical.index"
index = faiss.read_index(FAISS_INDEX_PATH)
# Load the sentence-embedding model used to encode the query.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# Run a test retrieval query.
query = "What are the symptoms of pneumonia?"
query_embedding = embed_model.encode([query])
D, I = index.search(np.array(query_embedding).astype("float32"), 3)  # Retrieve top 3 matches
# Print the FAISS results so they can be compared with the JSON mapping.
print("π FAISS Search Results:", I[0])
print("π FAISS Distances:", D[0])
# Load the ID -> text mapping; JSON keys are strings, FAISS IDs are integers.
with open("id_to_text.json", "r") as f:
    id_to_text = json.load(f)
id_to_text = {int(k): v for k, v in id_to_text.items()}  # Ensure keys are integers
# Print the matching texts.
for faiss_id in I[0]:
    if faiss_id < 0:
        # FAISS pads the result with -1 when fewer than k neighbours exist.
        continue
    entry = id_to_text[int(faiss_id)]
    # Values are lists of chunk strings; join them so [:100] means 100 characters.
    snippet = entry if isinstance(entry, str) else " ".join(entry)
    print(f"FAISS ID {faiss_id} β Text: {snippet[:100]}...")  # Show first 100 characters
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
import json

# Load the saved FAISS index.
FAISS_INDEX_PATH = "/content/faiss_medical.index"
index = faiss.read_index(FAISS_INDEX_PATH)
# Load the sentence-embedding model used to encode queries.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
# Load the FAISS ID -> text mapping.
with open("id_to_text.json", "r") as f:
    id_to_text = json.load(f)
# JSON object keys are always strings, while FAISS returns integer IDs.
id_to_text = {int(k): v for k, v in id_to_text.items()}


def retrieve_medical_summary(query, k=3):
    """
    Retrieve the medical passages closest to *query* from the FAISS index.

    Args:
        query (str): The medical question.
        k (int, optional): Number of nearest passages to request. Defaults to 3.

    Returns:
        str: The retrieved passages, separated by a "---" divider line, or
        "No relevant data found." when no usable passage is returned.
    """
    if k <= 0:
        return "No relevant data found."
    # NOTE(review): the indexed vectors were computed from preprocess_text()
    # output (lower-cased, punctuation and stop words removed, lemmatised),
    # while the query is embedded raw. Consider normalising the query the same
    # way, and confirm with retrieval-quality checks.
    query_embedding = embed_model.encode([query])
    _, neighbour_ids = index.search(np.array(query_embedding).astype("float32"), k)

    retrieved_docs = []
    for raw_id in neighbour_ids[0]:
        doc_id = int(raw_id)
        # -1 is FAISS padding (fewer than k vectors available). An ID missing
        # from the mapping means the index and the JSON are out of sync. Skip
        # both rather than passing a placeholder to the LLM as if it were literature.
        if doc_id < 0 or doc_id not in id_to_text:
            continue
        doc = id_to_text[doc_id]
        # Mapping values are lists of chunk strings; flatten them to one string.
        retrieved_docs.append(doc if isinstance(doc, str) else " ".join(doc))

    # Only an empty result yields the sentinel, so callers can detect "nothing found".
    return "\n\n---\n\n".join(retrieved_docs) if retrieved_docs else "No relevant data found."


# Example Test
query = "What are the symptoms of pneumonia?"
retrieved_summary = retrieve_medical_summary(query, k=3)
print("π Retrieved Medical Summary:\n", retrieved_summary)
import os
from groq import Groq

# Read the Groq API key from the environment (e.g. a Hugging Face Space secret).
# Never hard-code it: a key committed to a public repository is readable by
# anyone and must be treated as leaked and rotated.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError(
        "GROQ_API_KEY is not set. Provide it as an environment variable "
        "(for example a Hugging Face Space secret)."
    )
client = Groq(api_key=GROQ_API_KEY)


def generate_medical_answer_groq(query, model="llama-3.3-70b-versatile", max_tokens=500, temperature=0.3):
    """
    Generate a medical answer with a Groq-hosted LLM, grounded in literature
    retrieved from FAISS by retrieve_medical_summary().

    Args:
        query (str): The patient's medical question.
        model (str, optional): Groq model ID. Defaults to "llama-3.3-70b-versatile".
        max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 500.
        temperature (float, optional): Sampling temperature (higher = more varied). Defaults to 0.3.

    Returns:
        str: The model's answer; a fallback message when retrieval finds nothing;
        or "Error generating response: ..." when the API call fails.
    """
    # Retrieve relevant medical literature from FAISS.
    retrieved_summary = retrieve_medical_summary(query)
    print("\nπ Retrieved Medical Text for Query:", query)
    print(retrieved_summary, "\n")

    if not retrieved_summary or retrieved_summary == "No relevant data found.":
        return "No relevant medical data found. Please consult a healthcare professional."

    try:
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": "You are an expert AI specializing in medical knowledge."},
                {"role": "user", "content": f"Summarize the following medical literature and provide a structured medical answer:\n\n### Medical Literature ###\n{retrieved_summary}\n\n### Patient Question ###\n{query}\n\n### Medical Advice ###"},
            ],
            max_tokens=max_tokens,
            temperature=temperature,
        )
        # The message content may be None (e.g. an empty completion); avoid
        # calling .strip() on None.
        return (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Deliberate best-effort for the UI: report the failure as text instead
        # of crashing the app.
        return f"Error generating response: {str(e)}"


# Example Usage
query = "What are the symptoms of pneumonia?"
print("π©Ί AI-Generated Response:", generate_medical_answer_groq(query))
# `gr` must be imported explicitly; without this import, building the
# Interface below raises NameError at startup.
import gradio as gr


# Gradio Interface
def ask_medical_question(question):
    """Gradio callback: answer *question* with generate_medical_answer_groq()."""
    return generate_medical_answer_groq(question)


# Create Gradio Interface
iface = gr.Interface(
    fn=ask_medical_question,
    inputs=gr.Textbox(lines=2, placeholder="Enter your medical question here..."),
    outputs=gr.Textbox(lines=10, placeholder="AI-generated medical advice will appear here..."),
    title="Medical Question Answering System",
    description="Ask any medical question, and the AI will provide an answer based on medical literature.",
)
# Launch the Gradio app
iface.launch()