Spaces:
Running
Running
Commit
·
b8f3012
1
Parent(s):
691b688
Update the model to work as agent and chatbot
Browse files
main.py
CHANGED
@@ -4,7 +4,7 @@ import requests
|
|
4 |
import traceback
|
5 |
import time
|
6 |
import os
|
7 |
-
from typing import Dict, Any, List, Optional
|
8 |
from datetime import datetime, timedelta
|
9 |
|
10 |
# Updated imports for pydantic
|
@@ -27,10 +27,34 @@ import numpy as np
|
|
27 |
from endpoints_documentation import endpoints_documentation
|
28 |
|
29 |
# Set environment variables for HuggingFace
|
30 |
-
# os.
|
|
|
|
|
|
|
31 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
32 |
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
class EndpointRequest(BaseModel):
|
35 |
"""Data model for API endpoint requests"""
|
36 |
endpoint: str = Field(..., description="The API endpoint path to call")
|
@@ -39,188 +63,147 @@ class EndpointRequest(BaseModel):
|
|
39 |
missing_required: List[str] = Field(default_factory=list, description="Any required parameters that are missing")
|
40 |
|
41 |
|
42 |
-
class
|
43 |
def __init__(self):
|
44 |
self.endpoints_documentation = endpoints_documentation
|
45 |
-
self.ollama_base_url = "http://localhost:11434"
|
46 |
-
|
47 |
-
self.
|
48 |
-
self.
|
49 |
-
self.
|
50 |
-
'Content-type': 'application/json'
|
51 |
-
}
|
52 |
-
self.user_id = '76ceed74-143a-45c1-843e-ba583c122dea'
|
53 |
self.max_retries = 3
|
54 |
-
self.retry_delay = 2
|
55 |
|
56 |
-
#
|
57 |
-
self.
|
|
|
58 |
|
59 |
-
# Initialize
|
|
|
60 |
self._initialize_llm()
|
61 |
self._initialize_parsers_and_chains()
|
62 |
-
|
63 |
-
# Add date parsing capabilities
|
64 |
self._initialize_date_parser()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def _initialize_language_tools(self):
|
67 |
-
"""Initialize
|
68 |
-
# Use multilingual embeddings for semantic understanding
|
69 |
-
self.embeddings = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large")
|
70 |
-
|
71 |
-
# Initialize language identification model
|
72 |
try:
|
|
|
73 |
self.language_classifier = pipeline(
|
74 |
"text-classification",
|
75 |
model="papluca/xlm-roberta-base-language-detection",
|
76 |
top_k=1
|
77 |
)
|
78 |
-
print("Language classification model loaded successfully")
|
79 |
-
except Exception as e:
|
80 |
-
print(f"Failed to load language classification model: {e}")
|
81 |
-
# Fallback to basic regex detection if model fails to load
|
82 |
-
self.language_classifier = None
|
83 |
-
|
84 |
-
# Add sentiment analysis for enhanced response generation
|
85 |
-
try:
|
86 |
self.sentiment_analyzer = pipeline(
|
87 |
"sentiment-analysis",
|
88 |
model="cardiffnlp/twitter-xlm-roberta-base-sentiment"
|
89 |
)
|
90 |
-
print("
|
91 |
except Exception as e:
|
92 |
-
print(f"
|
|
|
93 |
self.sentiment_analyzer = None
|
94 |
|
95 |
def _initialize_date_parser(self):
|
96 |
-
"""Initialize date parsing model
|
97 |
try:
|
98 |
self.date_parser = pipeline(
|
99 |
"token-classification",
|
100 |
model="Jean-Baptiste/roberta-large-ner-english",
|
101 |
aggregation_strategy="simple"
|
102 |
)
|
103 |
-
print("Date parsing model loaded successfully")
|
104 |
except Exception as e:
|
105 |
-
print(f"
|
106 |
self.date_parser = None
|
107 |
|
108 |
-
def detect_language(self, text):
    """Classify *text* as ``"arabic"`` or ``"english"``.

    The HuggingFace language classifier is consulted first (when it is
    loaded and the stripped text is longer than three characters). If it
    is unavailable, raises, or gives a low-confidence unsupported label,
    an Arabic-script regex decides; everything else is English.
    """
    classifier = self.language_classifier
    if classifier and len(text.strip()) > 3:
        try:
            prediction = classifier(text)
            detected_lang = prediction[0][0]['label']
            confidence = prediction[0][0]['score']

            print(f"Language detected: {detected_lang} with confidence {confidence:.4f}")

            # Collapse the classifier's label onto the two supported languages.
            for labels, language in (
                (['ar', 'arabic'], "arabic"),
                (['en', 'english'], "english"),
            ):
                if detected_lang in labels:
                    return language

            if confidence > 0.8:
                # Confident, but not a language we support: log it and
                # answer in English.
                print(f"Detected unsupported language: {detected_lang}")
                return "english"
        except Exception as e:
            # Fall through to the regex-based detection below.
            print(f"Error in language detection model: {e}")

    # Fallback: any character from the Arabic Unicode blocks means Arabic.
    has_arabic_script = re.search(
        r'[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF]+', text
    )
    return "arabic" if has_arabic_script else "english"
|
142 |
-
|
143 |
-
def analyze_sentiment(self, text):
    """Return ``{"sentiment": label, "score": score}`` for *text*.

    Uses the sentiment pipeline when it is loaded and the stripped text is
    longer than three characters; otherwise, or when the pipeline raises,
    a neutral result with score 0.5 is returned.
    """
    analyzer = self.sentiment_analyzer
    if analyzer and len(text.strip()) > 3:
        try:
            prediction = analyzer(text)
            label = prediction[0]['label']
            confidence = prediction[0]['score']
            return {"sentiment": label, "score": confidence}
        except Exception as e:
            print(f"Error in sentiment analysis: {e}")

    # Neutral default when no analysis could be produced.
    return {"sentiment": "NEUTRAL", "score": 0.5}
|
159 |
-
|
160 |
-
def extract_semantic_keywords(self, text, top_n=5):
    """Return up to *top_n* words of *text* most similar to the whole text.

    Words longer than three characters are embedded with
    ``self.embeddings.embed_query`` and ranked by cosine similarity to the
    embedding of the full text.

    Changes from the previous version:
    - words are de-duplicated in first-occurrence order, so ties are broken
      deterministically instead of by set iteration order;
    - a zero-norm embedding scores 0.0 instead of producing NaN;
    - the text embedding is computed first, so a failure there no longer
      wastes one embedding call per word.

    Args:
        text: The input string to mine for keywords.
        top_n: Maximum number of keywords to return.

    Returns:
        A list of keyword strings, best match first; ``[]`` when there are
        no candidate words or when embedding the text fails.
    """
    try:
        words = re.findall(r'\b\w+\b', text.lower())
        unique_words = list(dict.fromkeys(w for w in words if len(w) > 3))

        if not unique_words:
            return []

        text_embedding = np.asarray(self.embeddings.embed_query(text), dtype=float)
        text_norm = np.linalg.norm(text_embedding)

        similarities = []
        for word in unique_words:
            try:
                emb = np.asarray(self.embeddings.embed_query(word), dtype=float)
            except Exception as e:
                # A single bad word should not abort the whole extraction.
                print(f"Error embedding word {word}: {e}")
                continue

            denom = np.linalg.norm(emb) * text_norm
            # Guard against zero vectors: cosine similarity is undefined there.
            similarity = float(np.dot(emb, text_embedding) / denom) if denom else 0.0
            similarities.append((word, similarity))

        # list.sort is stable, so equal scores keep first-occurrence order.
        similarities.sort(key=lambda pair: pair[1], reverse=True)
        return [word for word, _ in similarities[:top_n]]

    except Exception as e:
        print(f"Error extracting keywords: {e}")
        return []
|
198 |
-
|
199 |
def _initialize_llm(self):
|
200 |
-
"""Initialize the LLM
|
201 |
-
# Set up the callback manager for streaming (optional)
|
202 |
callbacks = [StreamingStdOutCallbackHandler()]
|
203 |
-
|
204 |
-
# Initialize the Ollama LLM with updated parameters
|
205 |
self.llm = OllamaLLM(
|
206 |
model=self.model_name,
|
207 |
base_url=self.ollama_base_url,
|
208 |
callbacks=callbacks,
|
209 |
temperature=0.7,
|
210 |
-
num_ctx=8192,
|
211 |
top_p=0.9,
|
212 |
-
request_timeout=60,
|
213 |
)
|
214 |
|
215 |
def _initialize_parsers_and_chains(self):
|
216 |
-
"""Initialize
|
217 |
-
# Setup JSON parser for structured output
|
218 |
self.json_parser = JsonOutputParser(pydantic_object=EndpointRequest)
|
219 |
|
220 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
self.router_prompt_template = PromptTemplate(
|
222 |
template="""
|
223 |
-
|
224 |
|
225 |
=== ENDPOINT DOCUMENTATION ===
|
226 |
{endpoints_documentation}
|
@@ -292,66 +275,56 @@ class AIAgent:
|
|
292 |
- User asks "what appointments today" → Use appointment listing with date filter
|
293 |
- User wants to "update medication" → Use medication update endpoint with patient_id
|
294 |
|
295 |
-
Think step by step and be precise with your endpoint selection and parameter extraction
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
#
|
345 |
-
# - Include only what's necessary to answer the query
|
346 |
-
# - Maintain professional and polite tone
|
347 |
-
# - Use proper healthcare terminology
|
348 |
-
# """,
|
349 |
-
# input_variables=["user_query", "api_response", "detected_language",
|
350 |
-
# "sentiment_analysis", "extracted_keywords"]
|
351 |
-
# )
|
352 |
-
# Create user-friendly response template with enhanced context awareness
|
353 |
-
# Create user-friendly response template with enhanced context awareness
|
354 |
-
# Create user-friendly response template with enhanced context awareness
|
355 |
self.user_response_template = PromptTemplate(
|
356 |
template="""
|
357 |
You are a professional healthcare assistant. Generate clear, accurate responses using EXACT data from the system.
|
@@ -438,208 +411,159 @@ class AIAgent:
|
|
438 |
=== FINAL INSTRUCTION ===
|
439 |
Respond ONLY in the requested language. Do NOT provide translations, explanations, or additional text in any other language. Stop immediately after answering the user's question.
|
440 |
""",
|
441 |
-
input_variables=["user_query", "api_response", "detected_language",
|
442 |
-
"sentiment_analysis", "extracted_keywords"]
|
443 |
-
)
|
444 |
-
|
445 |
-
# Create LLM chains
|
446 |
-
self.router_chain = LLMChain(
|
447 |
-
llm=self.llm,
|
448 |
-
prompt=self.router_prompt_template,
|
449 |
-
output_key="route_result"
|
450 |
-
)
|
451 |
-
|
452 |
-
self.user_response_chain = LLMChain(
|
453 |
-
llm=self.llm,
|
454 |
-
prompt=self.user_response_template,
|
455 |
-
output_key="user_friendly_response"
|
456 |
)
|
457 |
|
458 |
-
|
459 |
-
|
460 |
-
|
461 |
-
|
462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
463 |
|
464 |
-
#
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
}
|
469 |
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
474 |
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
|
480 |
-
|
481 |
-
for
|
482 |
-
|
483 |
-
|
484 |
|
485 |
-
|
486 |
-
|
487 |
-
|
488 |
-
|
489 |
-
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
return (today + timedelta(days=1)).strftime('%Y-%m-%dT%H:%M:%S')
|
496 |
-
except Exception as e:
|
497 |
-
print(f"Error in date parsing: {e}")
|
498 |
|
499 |
-
#
|
500 |
-
|
|
|
501 |
|
502 |
-
def
|
503 |
-
"""
|
504 |
-
Process the user query through the LangChain pipeline and return a response
|
505 |
-
"""
|
506 |
try:
|
507 |
-
|
508 |
-
|
509 |
-
# Detect language of the query
|
510 |
-
detected_language = self.detect_language(user_query)
|
511 |
-
print(f"Detected language: {detected_language}")
|
512 |
-
|
513 |
-
# Enhanced context using Hugging Face models
|
514 |
-
sentiment_result = self.analyze_sentiment(user_query)
|
515 |
-
print(f"Sentiment analysis: {sentiment_result}")
|
516 |
-
|
517 |
-
extracted_keywords = self.extract_semantic_keywords(user_query)
|
518 |
-
print(f"Extracted keywords: {extracted_keywords}")
|
519 |
-
|
520 |
-
# Try to extract dates from query
|
521 |
-
parsed_date = self.parse_relative_date(user_query, detected_language)
|
522 |
-
if parsed_date:
|
523 |
-
print(f"Parsed relative date: {parsed_date}")
|
524 |
-
|
525 |
-
# 1. Route the query to determine which API endpoint to call
|
526 |
-
new_start_time = time.time()
|
527 |
-
router_result = self.router_chain.invoke({
|
528 |
-
"endpoints_documentation": json.dumps(self.endpoints_documentation, indent=2),
|
529 |
"user_query": user_query,
|
530 |
"detected_language": detected_language,
|
531 |
-
"
|
532 |
-
"
|
533 |
})
|
534 |
-
|
535 |
-
# TODO: remove the print statement in production
|
536 |
-
print('End time of generating response: ', time.time() - new_start_time)
|
537 |
|
538 |
-
#
|
539 |
-
|
540 |
-
|
541 |
-
|
542 |
-
# Clean the response first
|
543 |
-
cleaned_response = route_result
|
544 |
-
|
545 |
-
# Remove any comments (both single-line and multi-line)
|
546 |
-
cleaned_response = re.sub(r'//.*?$', '', cleaned_response, flags=re.MULTILINE)
|
547 |
cleaned_response = re.sub(r'/\*.*?\*/', '', cleaned_response, flags=re.DOTALL)
|
548 |
-
|
549 |
-
# Remove any trailing commas
|
550 |
cleaned_response = re.sub(r',(\s*[}\]])', r'\1', cleaned_response)
|
551 |
|
552 |
-
# Try different methods to parse the JSON response
|
553 |
try:
|
554 |
-
|
555 |
-
|
556 |
except json.JSONDecodeError:
|
557 |
-
|
558 |
-
|
559 |
-
|
560 |
-
|
561 |
-
|
562 |
-
|
563 |
-
|
564 |
-
|
565 |
-
|
566 |
-
|
567 |
-
|
568 |
-
|
569 |
-
|
570 |
-
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
|
582 |
-
|
583 |
-
if 'patient_id' in parsed_route['params']:
|
584 |
-
parsed_route['params']['patient_id'] = self.user_id
|
585 |
-
|
586 |
-
# Inject parsed date if available and a date parameter exists
|
587 |
-
date_params = ['appointment_date', 'date', 'schedule_date', 'date_time', 'new_date_time']
|
588 |
-
if parsed_date:
|
589 |
-
for param in date_params:
|
590 |
-
if param in parsed_route['params']:
|
591 |
-
parsed_route['params'][param] = parsed_date
|
592 |
-
|
593 |
-
print('Parsed route: ', parsed_route)
|
594 |
-
print(f"Routing completed in {time.time() - start_time:.2f} seconds")
|
595 |
-
|
596 |
-
# 3. Make the backend API call
|
597 |
-
backend_response = self.backend_call(parsed_route)
|
598 |
-
|
599 |
-
# 4. Generate user-friendly response
|
600 |
-
user_friendly_result = self.user_response_chain.invoke({
|
601 |
"user_query": user_query,
|
602 |
-
"api_response": json.dumps(backend_response, indent=2),
|
603 |
"detected_language": detected_language,
|
604 |
"sentiment_analysis": json.dumps(sentiment_result),
|
605 |
-
"
|
606 |
})
|
607 |
-
print('user response: ', user_friendly_result["user_friendly_response"])
|
608 |
-
|
609 |
-
print(f"Total processing time: {time.time() - start_time:.2f} seconds")
|
610 |
|
611 |
-
return
|
612 |
-
"routing_info": parsed_route,
|
613 |
-
"api_response": backend_response,
|
614 |
-
"user_friendly_response": user_friendly_result["user_friendly_response"],
|
615 |
-
"detected_language": detected_language,
|
616 |
-
"sentiment": sentiment_result,
|
617 |
-
"keywords": extracted_keywords
|
618 |
-
}
|
619 |
|
620 |
except Exception as e:
|
621 |
-
|
622 |
-
|
623 |
-
"
|
624 |
-
|
625 |
-
|
626 |
-
print(f"Error: {error_detail['error']}")
|
627 |
-
print(f"Traceback: {error_detail['traceback']}")
|
628 |
-
return error_detail
|
629 |
|
630 |
def backend_call(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
631 |
-
"""
|
632 |
-
Make the actual API call to the backend with retry logic
|
633 |
-
"""
|
634 |
endpoint_url = data.get('endpoint')
|
635 |
endpoint_method = data.get('method')
|
636 |
-
endpoint_params = data.get('params', {}).copy()
|
637 |
|
638 |
-
|
639 |
-
|
640 |
-
|
641 |
|
642 |
-
# Add retry logic for more robust API calls
|
643 |
retries = 0
|
644 |
while retries < self.max_retries:
|
645 |
try:
|
@@ -648,24 +572,17 @@ class AIAgent:
|
|
648 |
self.BASE_URL + endpoint_url,
|
649 |
params=endpoint_params,
|
650 |
headers=self.headers,
|
651 |
-
timeout=10 # Add timeout for backend calls
|
652 |
-
)
|
653 |
-
elif endpoint_method.upper() == 'POST': # POST or other methods
|
654 |
-
response = requests.post(
|
655 |
-
self.BASE_URL + endpoint_url,
|
656 |
-
json=endpoint_params,
|
657 |
-
headers=self.headers,
|
658 |
timeout=10
|
659 |
)
|
660 |
-
elif endpoint_method.upper()
|
661 |
-
response = requests.
|
|
|
662 |
self.BASE_URL + endpoint_url,
|
663 |
json=endpoint_params,
|
664 |
headers=self.headers,
|
665 |
timeout=10
|
666 |
)
|
667 |
|
668 |
-
# Check if response status is success
|
669 |
response.raise_for_status()
|
670 |
return response.json()
|
671 |
|
@@ -678,33 +595,207 @@ class AIAgent:
|
|
678 |
"status_code": getattr(e.response, 'status_code', None) if hasattr(e, 'response') else None
|
679 |
}
|
680 |
|
681 |
-
print(f"API call attempt {retries} failed, retrying in {self.retry_delay} seconds...")
|
682 |
time.sleep(self.retry_delay)
|
683 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
684 |
|
685 |
-
# Initialize the AI agent singleton
|
686 |
-
# ai_agent = AIAgent()
|
687 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
688 |
|
689 |
-
|
|
|
690 |
# if __name__ == "__main__":
|
691 |
-
#
|
|
|
692 |
|
693 |
-
#
|
694 |
-
#
|
695 |
-
#
|
696 |
-
#
|
697 |
-
#
|
698 |
|
699 |
-
#
|
700 |
-
#
|
701 |
-
#
|
702 |
-
#
|
703 |
-
#
|
704 |
-
#
|
705 |
-
|
706 |
-
#
|
707 |
-
|
708 |
|
709 |
# Fast api section
|
710 |
from fastapi import FastAPI, HTTPException
|
@@ -719,11 +810,10 @@ app = FastAPI(
|
|
719 |
)
|
720 |
|
721 |
# Initialize the AI agent
|
722 |
-
agent =
|
723 |
|
724 |
class QueryRequest(BaseModel):
|
725 |
query: str
|
726 |
-
language: Optional[str] = None
|
727 |
|
728 |
class QueryResponse(BaseModel):
|
729 |
routing_info: Dict[str, Any]
|
@@ -738,7 +828,7 @@ async def process_query(request: QueryRequest):
|
|
738 |
Process a user query and return a response
|
739 |
"""
|
740 |
try:
|
741 |
-
response = agent.
|
742 |
return response
|
743 |
except Exception as e:
|
744 |
raise HTTPException(status_code=500, detail=str(e))
|
|
|
4 |
import traceback
|
5 |
import time
|
6 |
import os
|
7 |
+
from typing import Dict, Any, List, Optional, Tuple
|
8 |
from datetime import datetime, timedelta
|
9 |
|
10 |
# Updated imports for pydantic
|
|
|
27 |
from endpoints_documentation import endpoints_documentation
|
28 |
|
29 |
# Set environment variables for HuggingFace
|
30 |
+
# if os.name == 'posix' and os.uname().sysname == 'Darwin': # Check if running on macOS
|
31 |
+
# os.environ["HF_HOME"] = os.path.expanduser("~/Library/Caches/huggingface")
|
32 |
+
# os.environ["TRANSFORMERS_CACHE"] = os.path.expanduser("~/Library/Caches/huggingface/transformers")
|
33 |
+
# else:
|
34 |
os.environ["HF_HOME"] = "/tmp/huggingface"
|
35 |
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
|
36 |
|
37 |
+
|
38 |
+
class ChatMessage(BaseModel):
    """A single inbound chat message sent by a user."""

    # Required identifiers and payload.
    message_id: str = Field(
        ...,
        description="Unique identifier for the message",
    )
    user_id: str = Field(
        ...,
        description="User identifier",
    )
    message: str = Field(
        ...,
        description="The user's message",
    )

    # Optional metadata; the timestamp defaults to the creation time.
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="When the message was sent",
    )
    language: str = Field(
        default="english",
        description="Detected language of the message",
    )
|
45 |
+
|
46 |
+
|
47 |
+
class ChatResponse(BaseModel):
    """A reply produced by the chatbot, with optional API call details."""

    # Required identity and content of the reply.
    response_id: str = Field(
        ...,
        description="Unique identifier for the response",
    )
    response_type: str = Field(
        ...,
        description="Type of response: 'conversation' or 'api_action'",
    )
    message: str = Field(
        ...,
        description="The chatbot's response message",
    )

    # Backend call details; left at their defaults for conversational replies.
    api_call_made: bool = Field(
        default=False,
        description="Whether an API call was made",
    )
    api_data: Optional[Dict[str, Any]] = Field(
        default=None,
        description="API response data if applicable",
    )

    # Response metadata; the timestamp defaults to the creation time.
    language: str = Field(
        default="english",
        description="Language of the response",
    )
    timestamp: datetime = Field(
        default_factory=datetime.now,
        description="When the response was generated",
    )
|
56 |
+
|
57 |
+
|
58 |
class EndpointRequest(BaseModel):
|
59 |
"""Data model for API endpoint requests"""
|
60 |
endpoint: str = Field(..., description="The API endpoint path to call")
|
|
|
63 |
missing_required: List[str] = Field(default_factory=list, description="Any required parameters that are missing")
|
64 |
|
65 |
|
66 |
+
class HealthcareChatbot:
|
67 |
def __init__(self):
    """Set connection settings and conversation state, then load all components.

    The setup steps run in a fixed order: language tools, LLM, prompt
    chains, date parser. A readiness line and the welcome banner are
    printed once they have all run.
    """
    # Endpoint catalogue plus LLM and backend connection settings.
    self.endpoints_documentation = endpoints_documentation
    self.ollama_base_url = "http://localhost:11434"
    self.model_name = "gemma3"
    self.BASE_URL = 'https://d623-105-196-69-205.ngrok-free.app'
    self.headers = {'Content-type': 'application/json'}
    self.user_id = '5c745974-f1e6-4a9d-b93f-0e0aa75c5b09'

    # Retry policy for backend calls.
    self.max_retries = 3
    self.retry_delay = 2

    # Rolling conversation memory, capped at the last 10 exchanges.
    self.conversation_history = []
    self.max_history_length = 10

    # Order matters: the chains are built on top of the LLM.
    for setup_step in (
        self._initialize_language_tools,
        self._initialize_llm,
        self._initialize_parsers_and_chains,
        self._initialize_date_parser,
    ):
        setup_step()

    print("Healthcare Chatbot initialized successfully!")
    self._print_welcome_message()
|
89 |
+
|
90 |
+
def _print_welcome_message(self):
    """Print the bilingual (English / Arabic) welcome banner to stdout."""
    rule = "=" * 60
    banner_lines = [
        "\n" + rule,
        "🏥 HEALTHCARE CHATBOT READY",
        rule,
        "English: Hello! I'm your healthcare assistant. I can help you with:",
        "• Booking and managing appointments",
        "• Finding hospital information",
        "• Viewing your medical records",
        "• General healthcare questions",
        "",
        "Arabic: مرحباً! أنا مساعدك الطبي. يمكنني مساعدتك في:",
        "• حجز وإدارة المواعيد",
        "• العثور على معلومات المستشفى",
        "• عرض سجلاتك الطبية",
        "• الأسئلة الطبية العامة",
        rule,
        "Type 'quit' or 'خروج' to exit\n",
    ]
    # One print per line keeps the emitted text identical to separate calls.
    for banner_line in banner_lines:
        print(banner_line)
|
108 |
|
109 |
def _initialize_language_tools(self):
    """Load the embedding, language-detection and sentiment models.

    Each model is loaded in its own ``try`` so that one failure no longer
    discards models that loaded successfully. Every attribute
    (``embeddings``, ``language_classifier``, ``sentiment_analyzer``) is
    always defined afterwards and is ``None`` when its model failed to
    load. Previously a failure while loading the embeddings left
    ``self.embeddings`` undefined.
    """
    loaders = (
        (
            "embeddings",
            lambda: HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-large"),
        ),
        (
            "language_classifier",
            lambda: pipeline(
                "text-classification",
                model="papluca/xlm-roberta-base-language-detection",
                top_k=1
            ),
        ),
        (
            "sentiment_analyzer",
            lambda: pipeline(
                "sentiment-analysis",
                model="cardiffnlp/twitter-xlm-roberta-base-sentiment"
            ),
        ),
    )

    all_loaded = True
    for attr_name, load_model in loaders:
        try:
            setattr(self, attr_name, load_model())
        except Exception as e:
            # Best effort: callers already treat a falsy tool as "unavailable".
            print(f"⚠ Warning: Some language models failed to load: {e}")
            setattr(self, attr_name, None)
            all_loaded = False

    if all_loaded:
        print("✓ Language processing models loaded successfully")
|
127 |
|
128 |
def _initialize_date_parser(self):
    """Load the token-classification pipeline stored as ``self.date_parser``.

    On any loading error a warning is printed and ``self.date_parser`` is
    set to ``None``.
    """
    try:
        ner_pipeline = pipeline(
            "token-classification",
            model="Jean-Baptiste/roberta-large-ner-english",
            aggregation_strategy="simple"
        )
    except Exception as e:
        print(f"⚠ Warning: Date parsing model failed to load: {e}")
        ner_pipeline = None

    self.date_parser = ner_pipeline
|
139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
def _initialize_llm(self):
    """Create the Ollama LLM client and store it on ``self.llm``.

    The model name and base URL come from the instance; output is streamed
    to stdout through a ``StreamingStdOutCallbackHandler``.
    """
    stdout_streamer = StreamingStdOutCallbackHandler()
    llm_settings = {
        "model": self.model_name,
        "base_url": self.ollama_base_url,
        "callbacks": [stdout_streamer],
        "temperature": 0.7,
        "num_ctx": 8192,
        "top_p": 0.9,
        "request_timeout": 60,
    }
    self.llm = OllamaLLM(**llm_settings)
|
152 |
|
153 |
def _initialize_parsers_and_chains(self):
|
154 |
+
"""Initialize all prompt templates and chains"""
|
|
|
155 |
self.json_parser = JsonOutputParser(pydantic_object=EndpointRequest)
|
156 |
|
157 |
+
# Intent classification prompt
|
158 |
+
self.intent_classifier_template = PromptTemplate(
|
159 |
+
template="""
|
160 |
+
You are an intent classifier for a healthcare chatbot. Analyze the user's message and determine if it requires an API call or is conversational.
|
161 |
+
|
162 |
+
=== ANALYSIS CONTEXT ===
|
163 |
+
User Message: {user_query}
|
164 |
+
Language: {detected_language}
|
165 |
+
Conversation History: {conversation_history}
|
166 |
+
|
167 |
+
=== AVAILABLE API ENDPOINTS ===
|
168 |
+
{endpoints_documentation}
|
169 |
+
|
170 |
+
=== CLASSIFICATION TASK ===
|
171 |
+
Determine if the user's message requires:
|
172 |
+
1. API_ACTION: Specific healthcare action (book appointment, view records, etc.)
|
173 |
+
2. CONVERSATION: General chat, greeting, questions not requiring backend data
|
174 |
+
|
175 |
+
=== RESPONSE FORMAT ===
|
176 |
+
Respond with EXACTLY this JSON structure:
|
177 |
+
{{
|
178 |
+
"intent": "API_ACTION" or "CONVERSATION",
|
179 |
+
"confidence": 0.95,
|
180 |
+
"reasoning": "Brief explanation of classification decision",
|
181 |
+
"requires_backend": true or false
|
182 |
+
}}
|
183 |
+
|
184 |
+
=== CLASSIFICATION RULES ===
|
185 |
+
Choose API_ACTION for:
|
186 |
+
- Booking, canceling, or viewing appointments
|
187 |
+
- Requesting medical records or test results
|
188 |
+
- Hospital information queries (locations, hours, etc.)
|
189 |
+
- Medication management requests
|
190 |
+
- Specific patient data requests
|
191 |
+
|
192 |
+
Choose CONVERSATION for:
|
193 |
+
- Greetings and pleasantries
|
194 |
+
- General health advice (not patient-specific)
|
195 |
+
- Explanations of medical terms
|
196 |
+
- Small talk or casual questions
|
197 |
+
- Questions about the chatbot itself
|
198 |
+
|
199 |
+
Classify the intent:""",
|
200 |
+
input_variables=["user_query", "detected_language", "conversation_history", "endpoints_documentation"]
|
201 |
+
)
|
202 |
+
|
203 |
+
# API routing prompt (reuse existing router_prompt_template)
|
204 |
self.router_prompt_template = PromptTemplate(
|
205 |
template="""
|
206 |
+
You are a precise API routing assistant. Your job is to analyze user queries and select the correct API endpoint with proper parameters.
|
207 |
|
208 |
=== ENDPOINT DOCUMENTATION ===
|
209 |
{endpoints_documentation}
|
|
|
275 |
- User asks "what appointments today" → Use appointment listing with date filter
|
276 |
- User wants to "update medication" → Use medication update endpoint with patient_id
|
277 |
|
278 |
+
Think step by step and be precise with your endpoint selection and parameter extraction.:""",
|
279 |
+
input_variables=["endpoints_documentation", "user_query", "detected_language",
|
280 |
+
"extracted_keywords", "sentiment_analysis", "conversation_history"]
|
281 |
+
)
|
282 |
+
|
283 |
+
# Conversational response prompt
|
284 |
+
self.conversation_template = PromptTemplate(
|
285 |
+
template="""
|
286 |
+
You are a friendly and professional healthcare chatbot assistant.
|
287 |
+
|
288 |
+
=== RESPONSE GUIDELINES ===
|
289 |
+
- Respond ONLY in {detected_language}
|
290 |
+
- Be helpful, empathetic, and professional
|
291 |
+
- Keep responses concise but informative
|
292 |
+
- Use appropriate medical terminology when needed
|
293 |
+
- Maintain a caring and supportive tone
|
294 |
+
|
295 |
+
=== CONTEXT ===
|
296 |
+
User Message: {user_query}
|
297 |
+
Language: {detected_language}
|
298 |
+
Sentiment: {sentiment_analysis}
|
299 |
+
Conversation History: {conversation_history}
|
300 |
+
|
301 |
+
=== LANGUAGE-SPECIFIC INSTRUCTIONS ===
|
302 |
+
|
303 |
+
FOR ARABIC RESPONSES:
|
304 |
+
- Use Modern Standard Arabic (الفصحى)
|
305 |
+
- Be respectful and formal as appropriate in Arabic culture
|
306 |
+
- Use proper Arabic medical terminology
|
307 |
+
- Keep sentences clear and grammatically correct
|
308 |
+
|
309 |
+
FOR ENGLISH RESPONSES:
|
310 |
+
- Use clear, professional English
|
311 |
+
- Be warm and approachable
|
312 |
+
- Use appropriate medical terminology
|
313 |
+
|
314 |
+
=== RESPONSE RULES ===
|
315 |
+
1. Address the user's question or comment directly
|
316 |
+
2. Provide helpful information when possible
|
317 |
+
3. If you cannot help with something specific, explain what you CAN help with
|
318 |
+
4. Never provide specific medical advice - always recommend consulting healthcare professionals
|
319 |
+
5. Be encouraging and supportive
|
320 |
+
6. Do NOT mix languages in your response
|
321 |
+
7. End responses naturally without asking multiple questions
|
322 |
+
|
323 |
+
Generate a helpful conversational response:""",
|
324 |
+
input_variables=["user_query", "detected_language", "sentiment_analysis", "conversation_history"]
|
325 |
+
)
|
326 |
+
|
327 |
+
# API response formatting prompt (reuse existing user_response_template)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
self.user_response_template = PromptTemplate(
|
329 |
template="""
|
330 |
You are a professional healthcare assistant. Generate clear, accurate responses using EXACT data from the system.
|
|
|
411 |
=== FINAL INSTRUCTION ===
|
412 |
Respond ONLY in the requested language. Do NOT provide translations, explanations, or additional text in any other language. Stop immediately after answering the user's question.
|
413 |
""",
|
414 |
+
input_variables=["user_query", "api_response", "detected_language", "conversation_history"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
)
|
416 |
|
417 |
+
# Create chains
|
418 |
+
self.intent_chain = LLMChain(llm=self.llm, prompt=self.intent_classifier_template)
|
419 |
+
self.router_chain = LLMChain(llm=self.llm, prompt=self.router_prompt_template)
|
420 |
+
self.conversation_chain = LLMChain(llm=self.llm, prompt=self.conversation_template)
|
421 |
+
self.api_response_chain = LLMChain(llm=self.llm, prompt=self.user_response_template)
|
422 |
+
|
423 |
+
def detect_language(self, text):
    """Detect whether *text* is ``"arabic"`` or ``"english"``.

    The language classifier is tried first (when it is loaded and the
    stripped text is longer than three characters). If it is unavailable,
    fails, or returns a low-confidence unsupported label, an Arabic-script
    regex decides. English is the default.

    Changes from the previous version: the bare ``except: pass`` is replaced
    by ``except Exception`` with a printed warning, so classifier failures
    are visible and ``KeyboardInterrupt``/``SystemExit`` are no longer
    swallowed; empty or ``None`` input returns the English default.

    Args:
        text: The user message to classify.

    Returns:
        ``"arabic"`` or ``"english"``.
    """
    if not text:
        return "english"

    if self.language_classifier and len(text.strip()) > 3:
        try:
            top_prediction = self.language_classifier(text)[0][0]
            detected_lang = top_prediction['label']
            confidence = top_prediction['score']

            if detected_lang in ('ar', 'arabic'):
                return "arabic"
            if detected_lang in ('en', 'english'):
                return "english"
            if confidence > 0.8:
                # Confident but unsupported language: default to English.
                return "english"
        except Exception as e:
            # Best effort: fall through to the regex-based detection.
            print(f"⚠ Warning: Language classifier failed, using fallback detection: {e}")

    # Fallback: any character from the Arabic Unicode blocks means Arabic.
    if re.search(r'[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF]+', text):
        return "arabic"

    return "english"
|
446 |
+
|
447 |
+
def analyze_sentiment(self, text):
    """Analyze the sentiment of *text*.

    Uses ``self.sentiment_analyzer`` when it is set and the input has more
    than three non-blank characters; otherwise, or when the analyzer fails,
    a neutral default is returned.

    Args:
        text: The user message to analyze.

    Returns:
        A dict with keys ``"sentiment"`` (label) and ``"score"``.
    """
    if self.sentiment_analyzer and len(text.strip()) > 3:
        try:
            result = self.sentiment_analyzer(text)
            return {
                "sentiment": result[0]['label'],
                "score": result[0]['score']
            }
        except Exception as e:
            # Best effort only: report and fall back to neutral. ``Exception``
            # (not a bare except) keeps KeyboardInterrupt/SystemExit intact.
            print(f"Sentiment analysis failed, defaulting to neutral: {e}")

    return {"sentiment": "NEUTRAL", "score": 0.5}
|
460 |
+
|
461 |
+
def extract_keywords(self, text):
    """Extract up to five distinct keywords from *text*.

    Words are lower-cased; words of three characters or fewer and a small
    set of English stopwords are dropped. Duplicates are removed while
    keeping first-occurrence order, so the result is deterministic (a plain
    ``set`` would yield an arbitrary, run-dependent selection).

    Args:
        text: The user message.

    Returns:
        A list of at most five keywords, in order of first appearance.
    """
    # Simple keyword extraction
    words = re.findall(r'\b\w+\b', text.lower())
    # Filter out common words and keep meaningful ones
    stopwords = {'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by', 'is', 'are', 'was', 'were'}
    keywords = [w for w in words if len(w) > 3 and w not in stopwords]
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(keywords))[:5]
|
469 |
+
|
470 |
+
def get_conversation_context(self):
    """Summarize the most recent exchanges as a single context string.

    Reads the last three entries of ``self.conversation_history`` and joins
    them as ``"User: ... | Bot: ..."``. Bot replies longer than 100
    characters are cut to 100 and marked with ``"..."``; shorter replies
    are included verbatim (no ellipsis, since nothing was removed).

    Returns:
        The joined context, or ``"No previous conversation"`` when the
        history is empty.
    """
    if not self.conversation_history:
        return "No previous conversation"

    context = []
    for item in self.conversation_history[-3:]:  # Last 3 exchanges
        bot_reply = item['bot_response']
        if len(bot_reply) > 100:
            # Only flag truncation when text was actually dropped.
            bot_reply = bot_reply[:100] + "..."
        context.append(f"User: {item['user_message']}")
        context.append(f"Bot: {bot_reply}")

    return " | ".join(context)
|
481 |
+
|
482 |
+
def add_to_history(self, user_message, bot_response, response_type):
    """Record one user/bot exchange and trim the history to its cap.

    Appends an entry (with the current local timestamp) to
    ``self.conversation_history`` and then keeps only the newest
    ``self.max_history_length`` entries.

    Args:
        user_message: The text the user sent.
        bot_response: The text the bot replied with.
        response_type: The kind of response that was produced.
    """
    self.conversation_history.append({
        'timestamp': datetime.now(),
        'user_message': user_message,
        'bot_response': bot_response,
        'response_type': response_type
    })

    # Keep only recent history. Slice from the front by the overflow count
    # rather than with [-max:], because [-0:] would keep the whole list
    # when the cap is zero.
    overflow = len(self.conversation_history) - self.max_history_length
    if overflow > 0:
        self.conversation_history = self.conversation_history[overflow:]
|
494 |
|
495 |
+
def classify_intent(self, user_query, detected_language):
    """Classify whether the query needs an API action or is conversational.

    Invokes ``self.intent_chain`` and extracts a JSON object from the model
    output. The raw text is tried first, then a copy with ``//`` and
    ``/* */`` comments and trailing commas removed. In each, the first
    embedded object that decodes to a dict containing an ``"intent"`` key
    is used, so prose around the JSON and nested objects are tolerated.

    Args:
        user_query: The user's message.
        detected_language: Language label passed through to the prompt.

    Returns:
        The parsed intent dict (always containing ``"intent"``). When
        nothing usable can be parsed, or any error occurs, a default
        ``CONVERSATION`` classification with ``requires_backend`` False.
    """
    try:
        result = self.intent_chain.invoke({
            "user_query": user_query,
            "detected_language": detected_language,
            "conversation_history": self.get_conversation_context(),
            "endpoints_documentation": json.dumps(self.endpoints_documentation, indent=2)
        })

        intent_text = result["text"]

        # Comment/trailing-comma cleanup for sloppy model output. It can
        # damage legitimate string content (e.g. "http://..."), which is
        # why the untouched text is tried first below.
        cleaned_response = re.sub(r'//.*?$', '', intent_text, flags=re.MULTILINE)
        cleaned_response = re.sub(r'/\*.*?\*/', '', cleaned_response, flags=re.DOTALL)
        cleaned_response = re.sub(r',(\s*[}\]])', r'\1', cleaned_response)

        decoder = json.JSONDecoder()
        for candidate_text in (intent_text, cleaned_response):
            brace = candidate_text.find('{')
            while brace != -1:
                try:
                    # raw_decode parses one complete JSON value starting at
                    # this brace and ignores any trailing text.
                    intent_data, _ = decoder.raw_decode(candidate_text, brace)
                except json.JSONDecodeError:
                    intent_data = None
                if isinstance(intent_data, dict) and "intent" in intent_data:
                    return intent_data
                brace = candidate_text.find('{', brace + 1)

        # Default classification if parsing fails
        return {
            "intent": "CONVERSATION",
            "confidence": 0.5,
            "reasoning": "Failed to parse LLM response",
            "requires_backend": False
        }
    except Exception as e:
        print(f"Error in intent classification: {e}")
        return {
            "intent": "CONVERSATION",
            "confidence": 0.5,
            "reasoning": f"Error in classification: {str(e)}",
            "requires_backend": False
        }
|
537 |
+
|
538 |
+
def handle_conversation(self, user_query, detected_language, sentiment_result):
    """Produce a conversational (non-API) reply for the user.

    Invokes ``self.conversation_chain`` with the query, language, the
    JSON-encoded sentiment and the recent conversation context.

    Args:
        user_query: The user's message.
        detected_language: ``"arabic"`` selects the Arabic fallback text;
            any other value selects the English one.
        sentiment_result: JSON-serializable sentiment info for the prompt.

    Returns:
        The stripped model reply, or a fixed apology in the user's language
        when the chain (or input serialization) fails.
    """
    try:
        result = self.conversation_chain.invoke({
            "user_query": user_query,
            "detected_language": detected_language,
            "sentiment_analysis": json.dumps(sentiment_result),
            "conversation_history": self.get_conversation_context()
        })

        return result["text"].strip()

    except Exception as e:
        # Report the failure (as classify_intent does) instead of dropping
        # it silently, then fall back to a canned reply.
        print(f"Error in conversation handling: {e}")
        if detected_language == "arabic":
            return "أعتذر، واجهت مشكلة في المعالجة. كيف يمكنني مساعدتك؟"
        else:
            return "I apologize, I encountered a processing issue. How can I help you?"
|
|
|
|
|
|
|
556 |
|
557 |
def backend_call(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
558 |
+
"""Make API call to backend with retry logic"""
|
|
|
|
|
559 |
endpoint_url = data.get('endpoint')
|
560 |
endpoint_method = data.get('method')
|
561 |
+
endpoint_params = data.get('params', {}).copy()
|
562 |
|
563 |
+
# Inject patient_id if needed
|
564 |
+
if 'patient_id' in endpoint_params:
|
565 |
+
endpoint_params['patient_id'] = self.user_id
|
566 |
|
|
|
567 |
retries = 0
|
568 |
while retries < self.max_retries:
|
569 |
try:
|
|
|
572 |
self.BASE_URL + endpoint_url,
|
573 |
params=endpoint_params,
|
574 |
headers=self.headers,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
575 |
timeout=10
|
576 |
)
|
577 |
+
elif endpoint_method.upper() in ['POST', 'PUT', 'DELETE']:
|
578 |
+
response = requests.request(
|
579 |
+
endpoint_method.upper(),
|
580 |
self.BASE_URL + endpoint_url,
|
581 |
json=endpoint_params,
|
582 |
headers=self.headers,
|
583 |
timeout=10
|
584 |
)
|
585 |
|
|
|
586 |
response.raise_for_status()
|
587 |
return response.json()
|
588 |
|
|
|
595 |
"status_code": getattr(e.response, 'status_code', None) if hasattr(e, 'response') else None
|
596 |
}
|
597 |
|
|
|
598 |
time.sleep(self.retry_delay)
|
599 |
|
600 |
+
def handle_api_action(self, user_query, detected_language, sentiment_result, keywords):
    """Route the query to a backend endpoint and phrase the result.

    Steps: ask ``self.router_chain`` for a routing decision, parse it as
    JSON, pass it to ``self.backend_call``, then have
    ``self.api_response_chain`` turn the backend data into a user-facing
    message.

    Args:
        user_query: The user's message.
        detected_language: ``"arabic"`` selects the Arabic error text; any
            other value selects the English one.
        sentiment_result: JSON-serializable sentiment info for the prompt.
        keywords: Iterable of keyword strings, joined with ", ".

    Returns:
        A dict with ``"response"`` (message text), ``"api_data"`` (backend
        result, or ``{"error": ...}`` on failure) and ``"routing_info"``
        (the parsed route, or ``None`` on failure).
    """
    try:
        # Route the query to determine API endpoint
        router_result = self.router_chain.invoke({
            "endpoints_documentation": json.dumps(self.endpoints_documentation, indent=2),
            "user_query": user_query,
            "detected_language": detected_language,
            "extracted_keywords": ", ".join(keywords),
            "sentiment_analysis": json.dumps(sentiment_result),
            "conversation_history": self.get_conversation_context()
        })

        # Parse router response: strip //, /* */ comments and trailing commas
        route_text = router_result["text"]
        cleaned_response = re.sub(r'//.*?$', '', route_text, flags=re.MULTILINE)
        cleaned_response = re.sub(r'/\*.*?\*/', '', cleaned_response, flags=re.DOTALL)
        cleaned_response = re.sub(r',(\s*[}\]])', r'\1', cleaned_response)

        try:
            parsed_route = json.loads(cleaned_response)
        except json.JSONDecodeError:
            # The model wrapped the JSON in prose. Decode a complete object
            # starting at each '{' so nested objects (e.g. "params") are
            # handled; a non-greedy regex would cut at the first '}'.
            # Requiring "endpoint" avoids mistaking a nested object for
            # the route itself.
            parsed_route = None
            decoder = json.JSONDecoder()
            brace = cleaned_response.find('{')
            while brace != -1:
                try:
                    candidate, _ = decoder.raw_decode(cleaned_response, brace)
                except json.JSONDecodeError:
                    candidate = None
                if isinstance(candidate, dict) and "endpoint" in candidate:
                    parsed_route = candidate
                    break
                brace = cleaned_response.find('{', brace + 1)
            if parsed_route is None:
                raise ValueError("Could not parse routing response")

        # Make backend API call
        api_response = self.backend_call(parsed_route)

        # Generate user-friendly response
        user_response_result = self.api_response_chain.invoke({
            "user_query": user_query,
            "api_response": json.dumps(api_response, indent=2),
            "detected_language": detected_language,
            "conversation_history": self.get_conversation_context()
        })

        return {
            "response": user_response_result["text"].strip(),
            "api_data": api_response,
            "routing_info": parsed_route
        }

    except Exception as e:
        # Fallback error response; the cause is surfaced via api_data.
        if detected_language == "arabic":
            error_msg = "أعتذر، لم أتمكن من معالجة طلبك. يرجى المحاولة مرة أخرى أو صياغة السؤال بطريقة مختلفة."
        else:
            error_msg = "I apologize, I couldn't process your request. Please try again or rephrase your question."

        return {
            "response": error_msg,
            "api_data": {"error": str(e)},
            "routing_info": None
        }
|
657 |
+
|
658 |
+
def chat(self, user_message: str) -> ChatResponse:
    """Handle one user message and return the bot's reply.

    Exit commands (English or Arabic) short-circuit with a bilingual
    goodbye. Otherwise the message is analyzed (language, sentiment,
    keywords), its intent is classified, and it is dispatched either to
    ``handle_api_action`` (intent ``"API_ACTION"`` with ``requires_backend``
    true) or to ``handle_conversation``. The exchange is then appended to
    the conversation history.

    Args:
        user_message: The raw text entered by the user.

    Returns:
        A ``ChatResponse`` describing the reply. Any unexpected error
        yields a bilingual apology with ``response_type="conversation"``.
    """
    start_time = time.time()

    # Check for exit commands
    exit_commands = ['quit', 'exit', 'bye', 'خروج', 'وداعا', 'مع السلامة']
    if user_message.lower().strip() in exit_commands:
        return ChatResponse(
            response_id=f"resp_{int(time.time())}",
            response_type="conversation",
            message="Goodbye! Take care of your health! / وداعاً! اعتن بصحتك!",
            language="bilingual"
        )

    try:
        # Language detection and analysis
        detected_language = self.detect_language(user_message)
        sentiment_result = self.analyze_sentiment(user_message)
        keywords = self.extract_keywords(user_message)

        print(f"🔍 Language: {detected_language} | Sentiment: {sentiment_result['sentiment']} | Keywords: {keywords}")

        # Classify intent. Use .get so a classifier result without an
        # "intent" key falls through to the conversational branch instead
        # of raising KeyError and producing the generic error reply.
        intent_data = self.classify_intent(user_message, detected_language)
        intent = intent_data.get("intent")
        print(f"🎯 Intent: {intent} (confidence: {intent_data.get('confidence', 'N/A')})")

        # Handle based on intent
        if intent == "API_ACTION" and intent_data.get("requires_backend", False):
            # Handle API-based actions
            print("🔗 Processing API action...")
            action_result = self.handle_api_action(user_message, detected_language, sentiment_result, keywords)

            response = ChatResponse(
                response_id=f"resp_{int(time.time())}",
                response_type="api_action",
                message=action_result["response"],
                api_call_made=True,
                api_data=action_result["api_data"],
                language=detected_language
            )

        else:
            # Handle conversational responses
            print("💬 Processing conversational response...")
            conv_response = self.handle_conversation(user_message, detected_language, sentiment_result)

            response = ChatResponse(
                response_id=f"resp_{int(time.time())}",
                response_type="conversation",
                message=conv_response,
                api_call_made=False,
                language=detected_language
            )

        # Add to conversation history
        self.add_to_history(user_message, response.message, response.response_type)

        print(f"⏱️ Processing time: {time.time() - start_time:.2f}s")
        return response

    except Exception as e:
        print(f"❌ Error in chat processing: {e}")
        error_msg = "I apologize for the technical issue. Please try again. / أعتذر عن المشكلة التقنية. يرجى المحاولة مرة أخرى."

        return ChatResponse(
            response_id=f"resp_{int(time.time())}",
            response_type="conversation",
            message=error_msg,
            api_call_made=False,
            language="bilingual"
        )
|
729 |
+
|
730 |
+
def start_interactive_chat(self):
    """Run a console chat loop until the user leaves.

    Reads lines from standard input, sends each non-empty line to
    ``self.chat`` and prints the reply. The loop ends when the reply
    contains a goodbye marker, on Ctrl-C, or when standard input is closed
    (EOF). Any other error is reported and the loop continues.
    """
    print("🚀 Starting interactive chat session...")

    while True:
        try:
            # Get user input
            user_input = input("\n👤 You: ").strip()

            if not user_input:
                continue

            # Process the message
            print("🤖 Processing...")
            response = self.chat(user_input)

            # Display response
            print(f"\n🏥 Healthcare Bot: {response.message}")

            # Show additional info if API call was made
            if response.api_call_made and response.api_data:
                if "error" not in response.api_data:
                    print("✅ Successfully retrieved information from healthcare system")
                else:
                    print("⚠️ There was an issue accessing the healthcare system")

            # Check for exit (markers match the goodbye text used by chat)
            if "Goodbye" in response.message or "وداعاً" in response.message:
                break

        except (KeyboardInterrupt, EOFError):
            # EOFError must end the session too: once stdin is closed every
            # further input() raises again, so "continuing" would spin
            # forever in the generic handler below.
            print("\n\n👋 Chat session ended. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Unexpected error: {e}")
            print("The chat session will continue...")
|
766 |
|
|
|
|
|
767 |
|
768 |
+
# Create a simple function to start the chatbot
|
769 |
+
# def start_healthcare_chatbot():
|
770 |
+
# """Initialize and start the healthcare chatbot"""
|
771 |
+
# try:
|
772 |
+
# chatbot = HealthcareChatbot()
|
773 |
+
# chatbot.start_interactive_chat()
|
774 |
+
# except Exception as e:
|
775 |
+
# print(f"Failed to start chatbot: {e}")
|
776 |
+
# print("Please check your Ollama installation and endpoint documentation.")
|
777 |
|
778 |
+
|
779 |
+
# Test the chatbot
|
780 |
# if __name__ == "__main__":
|
781 |
+
# You can test individual messages like this:
|
782 |
+
# chatbot = HealthcareChatbot()
|
783 |
|
784 |
+
# Test conversational message
|
785 |
+
# print("\n=== TESTING CONVERSATIONAL MESSAGE ===")
|
786 |
+
# conv_response = chatbot.chat("Hello, how are you today?")
|
787 |
+
# print(f"Response: {conv_response.message}")
|
788 |
+
# print(f"Type: {conv_response.response_type}")
|
789 |
|
790 |
+
# Test API action message
|
791 |
+
# print("\n=== TESTING API ACTION MESSAGE ===")
|
792 |
+
# api_response = chatbot.chat("I want to book an appointment tomorrow at 2 PM")
|
793 |
+
# print(f"Response: {api_response.message}")
|
794 |
+
# print(f"Type: {api_response.response_type}")
|
795 |
+
# print(f"API Called: {api_response.api_call_made}")
|
796 |
+
|
797 |
+
# Start interactive session (uncomment to run)
|
798 |
+
# start_healthcare_chatbot()
|
799 |
|
800 |
# Fast api section
|
801 |
from fastapi import FastAPI, HTTPException
|
|
|
810 |
)
|
811 |
|
812 |
# Initialize the AI agent
|
813 |
+
agent = HealthcareChatbot()
|
814 |
|
815 |
class QueryRequest(BaseModel):
|
816 |
query: str
|
|
|
817 |
|
818 |
class QueryResponse(BaseModel):
|
819 |
routing_info: Dict[str, Any]
|
|
|
828 |
Process a user query and return a response
|
829 |
"""
|
830 |
try:
|
831 |
+
response = agent.chat(request.query)
|
832 |
return response
|
833 |
except Exception as e:
|
834 |
raise HTTPException(status_code=500, detail=str(e))
|