import json
import re
from datetime import datetime, timedelta

from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import AIMessage, HumanMessage


class MedicalMemoryManager:
    """Keep a chat transcript plus keyword-extracted patient context.

    The transcript lives in a LangChain ``ConversationBufferWindowMemory``.
    Alongside it, ``patient_context`` is a plain dict that collects the raw
    patient messages that mention symptoms, timing, medications or
    allergies, and any self-reported severity number.

    Extraction is a simple case-insensitive substring match against fixed
    keyword lists; whole messages are stored, not parsed entities.
    """

    # Substring stems (e.g. "vomit" also matches "vomiting").
    _SYMPTOM_KEYWORDS = (
        "pain", "ache", "hurt", "sore", "cough", "fever", "nausea",
        "headache", "dizzy", "tired", "fatigue", "vomit", "swollen",
        "rash", "itch", "burn", "cramp", "bleed", "shortness of breath",
    )
    _TIME_KEYWORDS = (
        "days", "weeks", "months", "hours", "yesterday", "today",
        "started", "began",
    )
    _MEDICATION_KEYWORDS = (
        "taking", "medication", "medicine", "pills", "prescribed", "drug",
    )
    _ALLERGY_KEYWORDS = ("allergic", "allergy", "allergies", "reaction")

    # A standalone number 1-10 followed, later on the same line, by
    # "pain", "severity" or "scale" (e.g. "7 out of 10 pain"). A number
    # that comes *after* the keyword ("pain is 7") does not match.
    _SEVERITY_PATTERN = re.compile(
        r'\b([1-9]|10)\b.*(?:pain|severity|scale)'
    )

    def __init__(self, k=10):
        """Create an empty session.

        Args:
            k: Forwarded to ``ConversationBufferWindowMemory`` as its
                ``k`` argument.
        """
        self.conversation_memory = ConversationBufferWindowMemory(
            k=k, return_messages=True
        )
        self.patient_context = self._new_patient_context()

    @staticmethod
    def _new_patient_context():
        """Return a fresh, empty patient-context dict stamped with now."""
        return {
            "symptoms": [],
            "medical_history": [],
            "medications": [],
            "allergies": [],
            "lifestyle_factors": [],
            "timeline": [],
            "severity_scores": {},
            "session_start": datetime.now().isoformat(),
        }

    @staticmethod
    def _mentions(text, keywords):
        """Return True if any of *keywords* occurs as a substring of *text*."""
        return any(keyword in text for keyword in keywords)

    def add_interaction(self, human_input, ai_response):
        """Record one exchange and mine the patient's message for context.

        Args:
            human_input: The patient's message.
            ai_response: The assistant's reply. It is stored in the
                transcript but not scanned for medical information.
        """
        chat = self.conversation_memory.chat_memory
        chat.add_user_message(human_input)
        chat.add_ai_message(ai_response)
        self._extract_medical_info(human_input)

    def _extract_medical_info(self, user_input):
        """Update ``patient_context`` from a single patient message.

        The original message text (not the lowercased copy) is what gets
        stored. One message may land in several categories.

        Args:
            user_input: The patient's message.
        """
        user_lower = user_input.lower()
        context = self.patient_context

        # Symptoms: store the message once. The duplicate check compares
        # whole messages case-insensitively; comparing a bare keyword to
        # the stored messages would practically never match and would let
        # repeated messages pile up.
        if self._mentions(user_lower, self._SYMPTOM_KEYWORDS):
            already_recorded = {entry.lower() for entry in context["symptoms"]}
            if user_lower not in already_recorded:
                context["symptoms"].append(user_input)

        if self._mentions(user_lower, self._TIME_KEYWORDS):
            context["timeline"].append(user_input)

        # Severity is kept as the matched digit string, keyed by the ISO
        # timestamp at which it was seen.
        severity_match = self._SEVERITY_PATTERN.search(user_lower)
        if severity_match:
            timestamp = datetime.now().isoformat()
            context["severity_scores"][timestamp] = severity_match.group(1)

        if self._mentions(user_lower, self._MEDICATION_KEYWORDS):
            context["medications"].append(user_input)

        if self._mentions(user_lower, self._ALLERGY_KEYWORDS):
            context["allergies"].append(user_input)

    def get_memory_context(self, max_messages=6):
        """Render the most recent messages as a plain-text transcript.

        Args:
            max_messages: How many of the latest messages to include.
                Defaults to 6; zero or a negative value yields an empty
                string.

        Returns:
            Newline-joined lines, ``"Patient: ..."`` for human messages and
            ``"Doctor: ..."`` for AI messages. Messages of any other type
            are skipped.
        """
        messages = self.conversation_memory.chat_memory.messages
        # Guard against ``messages[-0:]``, which would return everything.
        recent = messages[-max_messages:] if max_messages > 0 else []

        lines = []
        for msg in recent:
            if isinstance(msg, HumanMessage):
                lines.append(f"Patient: {msg.content}")
            elif isinstance(msg, AIMessage):
                lines.append(f"Doctor: {msg.content}")
        return "\n".join(lines)

    def _session_elapsed(self):
        """Return the time since ``session_start``, truncated to seconds."""
        started = datetime.fromisoformat(self.patient_context["session_start"])
        # Use the same tz-awareness as the stored start so the subtraction
        # is valid whether the stamp is naive or aware.
        elapsed = datetime.now(tz=started.tzinfo) - started
        return timedelta(seconds=max(0, int(elapsed.total_seconds())))

    def get_patient_summary(self):
        """Return a JSON string summarising the current session.

        Returns:
            A 2-space-indented JSON object with these keys:
            ``conversation_turns`` (stored messages // 2),
            ``session_duration`` (elapsed time since the session started,
            formatted ``H:MM:SS``), ``key_symptoms`` (last 3 symptom
            messages), ``timeline_info`` (last 2 timeline messages),
            ``medications``, ``allergies`` and ``severity_scores``.
        """
        context = self.patient_context
        message_count = len(self.conversation_memory.chat_memory.messages)
        summary = {
            "conversation_turns": message_count // 2,
            "session_duration": str(self._session_elapsed()),
            "key_symptoms": context["symptoms"][-3:],
            "timeline_info": context["timeline"][-2:],
            "medications": context["medications"],
            "allergies": context["allergies"],
            "severity_scores": context["severity_scores"],
        }
        return json.dumps(summary, indent=2)

    def reset_session(self):
        """Clear the transcript and start a new, empty patient context."""
        self.conversation_memory.clear()
        self.patient_context = self._new_patient_context()