# Spaces: Running on Zero
from langchain.chains import ConversationChain, LLMChain | |
from langchain.prompts import PromptTemplate | |
from langchain.llms import HuggingFacePipeline | |
from langchain.memory import ConversationBufferMemory | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
import torch | |
import gradio as gr | |
# --- Model configuration ---------------------------------------------------
# Model IDs passed to `from_pretrained` when the pipelines are built.
LLAMA_MODEL = "meta-llama/Llama-2-7b-chat-hf"  # conversational interviewer
MEDITRON_MODEL = "epfl-llm/meditron-7b"  # medicine / home-remedy suggestions

# --- System prompts --------------------------------------------------------
# Instructions for the interviewing model: gather symptom details over a few
# exchanges, then summarize without diagnosing or prescribing.
SYSTEM_PROMPT = (
    "You are a professional virtual doctor. Your goal is to collect detailed "
    "information about the user's health condition, symptoms, medical history, "
    "medications, lifestyle, and other relevant data.\n"
    "Ask 1-2 follow-up questions at a time to gather more details about:\n"
    "- Detailed description of symptoms\n"
    "- Duration (when did it start?)\n"
    "- Severity (scale of 1-10)\n"
    "- Aggravating or alleviating factors\n"
    "- Related symptoms\n"
    "- Medical history\n"
    "- Current medications and allergies\n"
    "After collecting sufficient information (4-5 exchanges), summarize findings "
    "and suggest when they should seek professional care. Do NOT make specific "
    "diagnoses or recommend specific treatments.\n"
    "Respond empathetically and clearly. Always be professional and thorough."
)

# ChatML-style template for the suggestion model. `{patient_info}` is the only
# placeholder; it is filled in by a LangChain PromptTemplate.
MEDITRON_PROMPT = (
    "<|im_start|>system\n"
    "You are a specialized medical assistant focusing ONLY on suggesting "
    "over-the-counter medicines and home remedies based on patient information.\n"
    "Based on the following patient information, provide ONLY:\n"
    "1. One specific over-the-counter medicine with proper adult dosing instructions\n"
    "2. One practical home remedy that might help\n"
    "3. Clear guidance on when to seek professional medical care\n"
    "Be concise, practical, and focus only on general symptom relief. Do not "
    "diagnose. Include a disclaimer that you are not a licensed medical "
    "professional.\n"
    "<|im_end|>\n"
    "<|im_start|>user\n"
    "Patient information: {patient_info}\n"
    "<|im_end|>\n"
    "<|im_start|>assistant\n"
)
def _load_text_generation_llm(model_id, max_new_tokens):
    """Load a causal LM in float16 and wrap it as a LangChain LLM.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        max_new_tokens: Generation budget configured on the pipeline.

    Returns:
        A ``(tokenizer, model, text_generation_pipeline, langchain_llm)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" lets transformers place the weights on the available devices.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    # Sampling settings are shared by both models; only the length budget differs.
    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    return tokenizer, model, generator, HuggingFacePipeline(pipeline=generator)


print("Loading Llama-2 model...")
llama_tokenizer, llama_model, llama_pipeline, llama_llm = _load_text_generation_llm(
    LLAMA_MODEL, max_new_tokens=512
)
print("Llama-2 model loaded successfully!")

print("Loading Meditron model...")
meditron_tokenizer, meditron_model, meditron_pipeline, meditron_llm = _load_text_generation_llm(
    MEDITRON_MODEL, max_new_tokens=256
)
print("Meditron model loaded successfully!")
# LangChain conversation for the Llama-2 model, with buffered memory.
# ConversationChain's default prompt is a plain-string template with a
# {history} slot. The memory must therefore return the history as a formatted
# "Human: ... / AI: ..." string (the default). With return_messages=True the
# repr of a list of message objects would be pasted into the prompt instead.
# NOTE(review): `memory` is a single module-level object, so every caller of
# `conversation` shares (and sees) the same history.
memory = ConversationBufferMemory()
conversation = ConversationChain(
    llm=llama_llm,
    memory=memory,
    # NOTE(review): verbose chains print the formatted prompts, which here
    # contain patient-provided health details.
    verbose=True,
)

# Prompt for the suggestion model: fills `{patient_info}` in MEDITRON_PROMPT.
meditron_template = PromptTemplate(
    input_variables=["patient_info"],
    template=MEDITRON_PROMPT,
)
meditron_chain = LLMChain(
    llm=meditron_llm,
    prompt=meditron_template,
    verbose=True,
)
# Module-level turn counter (starts at 0) and list for patient messages.
conversation_turns, patient_data = 0, []
# Llama-2-chat prompt delimiters (Meta's chat format).
_B_INST, _E_INST = "[INST]", "[/INST]"
_B_SYS, _E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

# ChatML markers used in MEDITRON_PROMPT. Any text after them in a completion
# is the model continuing the template rather than answering.
_CHATML_MARKERS = ("<|im_end|>", "<|im_start|>")

# Instruction placed before the patient's message once the summary turn is reached.
_SUMMARY_INSTRUCTION = (
    "Now summarize what you've learned and suggest when professional care may be needed."
)

# 1-based patient turn from which a summary and Meditron suggestions are produced.
_SUMMARY_TURN = 4


def _history_to_pairs(history):
    """Normalize Gradio chat history into a list of ``(user, assistant)`` strings.

    Accepts both the "tuples" format (``[user, assistant]`` pairs) and the
    "messages" format (dicts with ``role`` and ``content``). Entries whose
    content is not text (e.g. uploaded files) are skipped.
    """
    pairs = []
    for entry in history or []:
        if isinstance(entry, dict):
            content = entry.get("content")
            if not isinstance(content, str):
                continue
            role = entry.get("role")
            if role == "user":
                pairs.append([content, ""])
            elif role == "assistant" and pairs:
                previous = pairs[-1][1]
                pairs[-1][1] = f"{previous}\n\n{content}" if previous else content
        else:
            user, reply = entry[0], entry[1]
            if isinstance(user, str):
                pairs.append([user, reply if isinstance(reply, str) else ""])
    return [(user, reply) for user, reply in pairs]


def _build_llama_prompt(system_prompt, past_turns, user_message):
    """Render a dialogue in the Llama-2-chat ``[INST]`` prompt format.

    The system prompt is folded once into the first user turn inside
    ``<<SYS>>`` tags. Each completed turn is closed with ``</s><s>``. The prompt
    ends with an open ``[/INST]`` so the model writes the next reply.

    NOTE(review): the leading ``<s>`` (BOS) is left to the tokenizer. Some
    transformers versions tokenize text-generation inputs with
    ``add_special_tokens=False``, so confirm BOS is actually added.

    Args:
        system_prompt: Instructions placed in the ``<<SYS>>`` block.
        past_turns: Earlier ``(user, assistant)`` text pairs, oldest first.
        user_message: The new user message to answer.

    Returns:
        The prompt string for the Llama-2-chat pipeline.
    """
    user_texts = [user for user, _ in past_turns] + [user_message]
    user_texts[0] = f"{_B_SYS}{system_prompt}{_E_SYS}{user_texts[0]}"
    segments = [
        f"{_B_INST} {user.strip()} {_E_INST} {reply.strip()} </s><s>"
        for user, (_, reply) in zip(user_texts, past_turns)
    ]
    segments.append(f"{_B_INST} {user_texts[-1].strip()} {_E_INST}")
    return "".join(segments)


def _clean_meditron_output(text):
    """Cut a Meditron completion at the first ChatML marker and trim whitespace."""
    for marker in _CHATML_MARKERS:
        text = text.split(marker, 1)[0]
    return text.strip()


def generate_response(message, history):
    """Answer one patient message for the Gradio ``ChatInterface``.

    Conversation state is rebuilt from ``history`` on every call instead of
    living in module-level globals or a shared LangChain memory. A cleared
    chat therefore starts again at turn 1, and separate sessions never mix
    their messages.

    Args:
        message: The patient's newest message.
        history: The chat so far as supplied by Gradio, excluding ``message``
            (either ``[user, assistant]`` pairs or role/content dicts).

    Returns:
        Llama-2's reply. From the 4th patient message onward, the model is
        asked to summarize, and the reply is followed by a section of
        medicine and home-care suggestions from Meditron.

    NOTE(review): the prompt grows every turn and is never truncated, so very
    long chats may exceed the model's context window.
    """
    past_turns = _history_to_pairs(history)
    turn_number = len(past_turns) + 1
    wants_summary = turn_number >= _SUMMARY_TURN

    # The system prompt is sent once (in <<SYS>>), not glued onto every message.
    user_text = f"{_SUMMARY_INSTRUCTION}\n\n{message}" if wants_summary else message
    prompt = _build_llama_prompt(SYSTEM_PROMPT, past_turns, user_text)
    llama_response = llama_llm.invoke(prompt).strip()

    if not wants_summary:
        return llama_response

    # Meditron sees every patient message of this chat plus Llama-2's summary.
    patient_messages = [user for user, _ in past_turns] + [message]
    full_patient_info = "\n".join(patient_messages) + "\n\nSummary: " + llama_response
    medicine_suggestions = _clean_meditron_output(
        meditron_chain.run(patient_info=full_patient_info)
    )

    return (
        f"{llama_response}\n\n"
        "--- MEDICATION AND HOME CARE SUGGESTIONS ---\n\n"
        f"{medicine_suggestions}"
    )
# --- Gradio UI --------------------------------------------------------------
# Chat front end: each submitted message is handed to generate_response
# together with the chat history.
_EXAMPLE_PROMPTS = [
    "I have a cough and my throat hurts",
    "I've been having headaches for a week",
    "My stomach has been hurting since yesterday",
]

demo = gr.ChatInterface(
    fn=generate_response,
    title="Medical Assistant with Medicine Suggestions",
    description=(
        "Tell me about your symptoms, and after gathering enough "
        "information, I'll suggest potential remedies."
    ),
    examples=_EXAMPLE_PROMPTS,
    theme="soft",
)

if __name__ == "__main__":
    demo.launch()