import asyncio
import json
import logging
from typing import Dict

from fastapi import HTTPException, WebSocket, status

class InRequest:
    """Pending responses for one socket, keyed by request ID."""

    def __init__(self):
        self.responses: Dict[str, asyncio.Future] = {}

class ConnectionManager:
    """Track WebSocket connections by socket ID and route replies to waiting requests."""

    def __init__(self):
        self.available = None
        self.active_connections: Dict[str, WebSocket] = {}  # Maps socket ID to WebSocket connection
        self.in_request: Dict[str, InRequest] = {}  # Store pending response futures

    async def connect(self, socket_id: str, websocket: WebSocket):
        await websocket.accept()
        self.active_connections[socket_id] = websocket
        if self.available is None:
            self.available = socket_id
        return socket_id

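    def disconnect(self, socket_id: str):
        if socket_id in self.active_connections:
            del self.active_connections[socket_id]
            if self.available == socket_id:
                self.available = None
        # Cancel pending futures so callers blocked in listen() do not wait forever
        pending = self.in_request.pop(socket_id, None)
        if pending is not None:
            for future in pending.responses.values():
                if not future.done():
                    future.cancel()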

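    async def broadcast(self, message: str):
        # Iterate over a snapshot: a connect/disconnect during an await would
        # otherwise change the dict while it is being iterated
        for connection in list(self.active_connections.values()):
            await connection.send_text(message)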

    async def receive_text(self, socket_id: str):
        websocket = self.active_connections.get(socket_id)
        if websocket:
            return await websocket.receive_text()
        else:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY, 
                detail=f"Socket ID {socket_id} not connected")

    async def send_text(self, socket_id: str, message: str):
        websocket = self.active_connections.get(socket_id)
        if websocket:
            await websocket.send_text(message)
        else:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Socket ID {socket_id} not connected")

    async def send_bytes(self, socket_id: str, binary_data: bytes):
        websocket = self.active_connections.get(socket_id)
        if websocket:
            await websocket.send_bytes(binary_data)  # Send binary data
        else:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY, 
                detail=f"Socket ID {socket_id} not connected")

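    async def listen(self, socket_id: str, request_id: str) -> str:
        # Create a Future that notify() resolves when the matching response arrives
        future = asyncio.get_running_loop().create_future()
        # Reuse the socket's entry so concurrent requests do not overwrite each other
        req = self.in_request.setdefault(socket_id, InRequest())
        req.responses[request_id] = future
        try:
            return await future  # Await the future until it's set with a response
        except asyncio.CancelledError:
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Socket ID {socket_id} not connected or canceled")
        finally:
            # Drop the entry whether the wait completed or was cancelled
            req.responses.pop(request_id, None)
            if not req.responses and self.in_request.get(socket_id) is req:
                self.in_request.pop(socket_id, None)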

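    async def notify(self, socket_id: str, message: str):
        logging.debug(message)
        # If there is a pending future for this request, set its result
        req = self.in_request.get(socket_id)
        if req is None:
            return
        request_id, payload = self.extract_message(message)
        # Ignore unknown or malformed request IDs instead of raising KeyError
        future = req.responses.pop(request_id, None) if isinstance(request_id, str) else None
        if future is not None and not future.done():
            future.set_result(payload)
        # Remove the socket's entry only once no requests are left waiting
        if not req.responses:
            self.in_request.pop(socket_id, None)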

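    def extract_message(self, message: str):
        """Return (request_id, payload) parsed from a JSON message; (None, None) on failure."""
        request_id = None
        payload = None

        try:
            o = json.loads(message)
            # Only JSON objects carry the fields; other JSON values are ignored
            if isinstance(o, dict):
                request_id, payload = o.get('request_id'), o.get('payload')
        except (TypeError, ValueError) as e:
            # json.JSONDecodeError is a subclass of ValueError
            logging.warning("extract_message error: %s", e)

        return request_id, payload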