File size: 3,715 Bytes
031a3f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import HumanMessage, AIMessage
from datetime import datetime
import json
import re

class MedicalMemoryManager:
    """Track a patient conversation plus keyword-extracted clinical notes.

    Each turn is stored in a LangChain ``ConversationBufferWindowMemory``; the
    patient's utterance is additionally scanned with simple keyword heuristics
    and filed (verbatim) under ``patient_context`` categories: symptoms,
    timeline, severity_scores, medications and allergies.

    NOTE(review): extraction is plain substring matching, so it can produce
    false positives (e.g. "itch" inside "kitchen", "pain" inside "Spain",
    "taking" in "taking a walk"). Treat ``patient_context`` as a rough hint,
    not a clinical record.
    """

    # Substrings marking an utterance as describing a symptom. Order matters:
    # the first keyword that is new for this session decides the match.
    _SYMPTOM_KEYWORDS = (
        "pain", "ache", "hurt", "sore", "cough", "fever", "nausea",
        "headache", "dizzy", "tired", "fatigue", "vomit", "swollen", "rash",
        "itch", "burn", "cramp", "bleed", "shortness of breath",
    )
    # Substrings suggesting the utterance carries onset/timing information.
    _TIME_KEYWORDS = (
        "days", "weeks", "months", "hours", "yesterday", "today", "started",
        "began",
    )
    # Substrings suggesting the utterance mentions medication.
    _MEDICATION_KEYWORDS = (
        "taking", "medication", "medicine", "pills", "prescribed", "drug",
    )
    # Substrings suggesting the utterance mentions an allergy.
    _ALLERGY_KEYWORDS = ("allergic", "allergy", "allergies", "reaction")
    # A standalone number 1-10 followed later on the same line by a
    # pain/severity/scale word, e.g. "about a 7 on the pain scale".
    # NOTE(review): phrasings with the number after the keyword
    # ("my pain is 8") are not matched.
    _SEVERITY_RE = re.compile(r'\b([1-9]|10)\b.*(?:pain|severity|scale)')

    def __init__(self, k=10):
        """Create an empty session.

        Args:
            k: window size passed to ``ConversationBufferWindowMemory``.
        """
        self.conversation_memory = ConversationBufferWindowMemory(k=k, return_messages=True)
        self.patient_context = self._new_patient_context()

    @staticmethod
    def _new_patient_context():
        """Return a fresh, empty patient-context dict stamped with the current time."""
        return {
            "symptoms": [],
            "medical_history": [],
            "medications": [],
            "allergies": [],
            "lifestyle_factors": [],
            "timeline": [],
            "severity_scores": {},
            "session_start": datetime.now().isoformat(),
        }

    def add_interaction(self, human_input, ai_response):
        """Store one patient/doctor exchange and mine the patient's text.

        Args:
            human_input: the patient's message.
            ai_response: the assistant's reply.
        """
        self.conversation_memory.chat_memory.add_user_message(human_input)
        self.conversation_memory.chat_memory.add_ai_message(ai_response)
        self._extract_medical_info(human_input)

    def _extract_medical_info(self, user_input):
        """File *user_input* under every patient_context category it mentions.

        The raw utterance (not the matched keyword) is what gets stored, and a
        single utterance may land in several categories.
        """
        ctx = self.patient_context
        user_lower = user_input.lower()

        # Record the utterance as a symptom only if it mentions a symptom
        # keyword that no previously recorded symptom utterance mentions.
        # Stored entries are whole sentences, so this must be a substring
        # search -- list-equality membership would never de-duplicate.
        recorded = [s.lower() for s in ctx["symptoms"]]
        for keyword in self._SYMPTOM_KEYWORDS:
            if keyword in user_lower and not any(keyword in s for s in recorded):
                ctx["symptoms"].append(user_input)
                break

        if any(keyword in user_lower for keyword in self._TIME_KEYWORDS):
            ctx["timeline"].append(user_input)

        severity_match = self._SEVERITY_RE.search(user_lower)
        if severity_match:
            # Keyed by extraction time; the score is kept as the matched text.
            ctx["severity_scores"][datetime.now().isoformat()] = severity_match.group(1)

        if any(keyword in user_lower for keyword in self._MEDICATION_KEYWORDS):
            ctx["medications"].append(user_input)

        if any(keyword in user_lower for keyword in self._ALLERGY_KEYWORDS):
            ctx["allergies"].append(user_input)

    def get_memory_context(self, max_messages=6):
        """Return the most recent messages as a "Patient:/Doctor:" transcript.

        Args:
            max_messages: number of trailing messages to include. The default
                of 6 covers the last three exchanges when messages were added
                via ``add_interaction``. Values <= 0 yield an empty string.

        Returns:
            Transcript lines joined with newlines; messages that are neither
            ``HumanMessage`` nor ``AIMessage`` are skipped.

        NOTE(review): this reads ``chat_memory`` directly, so the window ``k``
        given to ``__init__`` presumably does not limit it -- confirm against
        the LangChain version in use.
        """
        if max_messages <= 0:
            # Guard: messages[-0:] would return the whole history.
            return ""
        lines = []
        for msg in self.conversation_memory.chat_memory.messages[-max_messages:]:
            if isinstance(msg, HumanMessage):
                lines.append(f"Patient: {msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"Doctor: {msg.content}")
        return "\n".join(lines)

    def get_patient_summary(self):
        """Return a JSON (2-space indented) summary of the session.

        Keys:
            conversation_turns: stored messages // 2.
            session_duration: elapsed time since ``session_start`` as
                ``str(timedelta)`` with whole seconds, e.g. "0:05:12".
            key_symptoms: last 3 symptom utterances.
            timeline_info: last 2 timeline utterances.
            medications, allergies, severity_scores: everything recorded.
        """
        ctx = self.patient_context
        # Truncate both ends to whole seconds so the duration reads cleanly.
        started = datetime.fromisoformat(ctx["session_start"]).replace(microsecond=0)
        elapsed = datetime.now().replace(microsecond=0) - started
        summary = {
            "conversation_turns": len(self.conversation_memory.chat_memory.messages) // 2,
            "session_duration": str(elapsed),
            "key_symptoms": ctx["symptoms"][-3:],
            "timeline_info": ctx["timeline"][-2:],
            "medications": ctx["medications"],
            "allergies": ctx["allergies"],
            "severity_scores": ctx["severity_scores"],
        }
        return json.dumps(summary, indent=2)

    def reset_session(self):
        """Clear the conversation memory and start a fresh patient context."""
        self.conversation_memory.clear()
        self.patient_context = self._new_patient_context()