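"""Chat handlers for a two-phase AI medical consultation assistant.

Turns 1-3 use CONSULTATION_PROMPT to gather information about the patient's
symptoms; from turn 4 onward the assistant returns a comprehensive summary
plus medication and home-care suggestions generated with MEDICINE_PROMPT.
"""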
import spaces
import logging
from .model import ModelManager
from .memory import MedicalMemoryManager
from .prompts import CONSULTATION_PROMPT, MEDICINE_PROMPT

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')

# Module-level state shared across requests: the model wrapper, the patient
# memory store, and a turn counter used to switch consultation phases.
model_manager = ModelManager()
memory_manager = MedicalMemoryManager()
conversation_turns = 0


def build_me_llama_prompt(system_prompt, history, user_input):
    """Assemble a Llama-2 chat-format prompt from the system prompt, the
    stored memory context, recent chat history, and the new user input."""
    memory_context = memory_manager.get_memory_context()
    enhanced_system_prompt = f"{system_prompt}\n\nPrevious conversation context:\n{memory_context}"
    prompt = f"<s>[INST] <<SYS>>\n{enhanced_system_prompt}\n<</SYS>>\n\n"
    # Keep only the three most recent turns to bound prompt length
    # (history[-3:] already handles shorter histories).
    for user_msg, assistant_msg in history[-3:]:
        prompt += f"{user_msg} [/INST] {assistant_msg} </s><s>[INST] "
    prompt += f"{user_input} [/INST] "
    return prompt
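
# For reference, a first-turn prompt assembled above has this shape (Llama-2
# chat format; braces are schematic placeholders, not literal text):
#
#   <s>[INST] <<SYS>>
#   {system_prompt}
#
#   Previous conversation context:
#   {memory_context}
#   <</SYS>>
#
#   {user_input} [/INST]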

@spaces.GPU
def respond(message, chat_history):
    """Gradio chat handler for the two-phase consultation flow: turns 1-3
    gather information, turn 4 onward delivers a summary and suggestions."""
    global conversation_turns
    conversation_turns += 1
    logging.info(f"User input: {message}")
    # Phase 1 (turns 1-3): keep gathering information with follow-up questions.
    if conversation_turns < 4:
        logging.info("Using CONSULTATION_PROMPT for information gathering.")
        prompt = build_me_llama_prompt(CONSULTATION_PROMPT, chat_history, message)
        response = model_manager.generate(prompt, max_new_tokens=128)
        logging.info(f"Model response: {response}")
        memory_manager.add_interaction(message, response)
        chat_history.append((message, response))
        return "", chat_history
    else:
        # Phase 2 (turn 4+): produce a comprehensive summary, then generate
        # medication and home-care suggestions from it.
        logging.info("Using CONSULTATION_PROMPT for summary and MEDICINE_PROMPT for suggestions.")
        patient_summary = memory_manager.get_patient_summary()
        memory_context = memory_manager.get_memory_context()
        summary_prompt = build_me_llama_prompt(
            CONSULTATION_PROMPT + "\n\nNow provide a comprehensive summary based on all the information gathered. Include when professional care may be needed.",
            chat_history,
            message
        )
        logging.info("Generating summary with CONSULTATION_PROMPT.")
        summary = model_manager.generate(summary_prompt, max_new_tokens=400)
        logging.info(f"Summary response: {summary}")
        full_patient_info = f"Patient Summary: {patient_summary}\n\nDetailed Summary: {summary}"
        # MEDICINE_PROMPT is issued as a fresh single-turn instruction rather
        # than being threaded through the chat history.
        med_prompt = f"<s>[INST] {MEDICINE_PROMPT.format(patient_info=full_patient_info, memory_context=memory_context)} [/INST] "
        logging.info("Generating medicine suggestions with MEDICINE_PROMPT.")
        medicine_suggestions = model_manager.generate(med_prompt, max_new_tokens=200)
        logging.info(f"Medicine suggestions: {medicine_suggestions}")
        final_response = (
            f"**COMPREHENSIVE MEDICAL SUMMARY:**\n{summary}\n\n"
            f"**MEDICATION AND HOME CARE SUGGESTIONS:**\n{medicine_suggestions}\n\n"
            f"**PATIENT CONTEXT SUMMARY:**\n{patient_summary}\n\n"
            f"**DISCLAIMER:** This is AI-generated advice for informational purposes only. Please consult a licensed healthcare provider for proper medical diagnosis and treatment."
        )
        memory_manager.add_interaction(message, final_response)
        chat_history.append((message, final_response))
        return "", chat_history

def reset_chat():
    """Reset the turn counter and memory store, then seed a new chat session."""
    global conversation_turns
    conversation_turns = 0
    memory_manager.reset_session()
    reset_msg = "New consultation started. Please tell me about your symptoms or health concerns."
    logging.info("Session reset. New consultation started.")
    return [(None, reset_msg)], ""
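
# A minimal sketch of how these handlers might be wired into a Gradio Blocks
# UI (illustrative assumption only; the actual UI is defined elsewhere in the
# app, and names like `demo`, `msg`, and `new_btn` are hypothetical):
#
#   import gradio as gr
#
#   with gr.Blocks() as demo:
#       chatbot = gr.Chatbot()
#       msg = gr.Textbox(placeholder="Describe your symptoms...")
#       new_btn = gr.Button("New consultation")
#       msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
#       new_btn.click(reset_chat, inputs=None, outputs=[chatbot, msg])
#
#   demo.launch()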