File size: 9,132 Bytes
d067c3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256

from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.concurrency import run_in_threadpool
from gradio_client import Client
import gradio as gr
import uvicorn
import httpx
import websockets
import asyncio
from urllib.parse import urljoin, urlparse, unquote
import logging

# Configure the root logger at INFO level and create this module's logger,
# which the handlers below use for request and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ASGI application object that every route in this module is registered on.
app = FastAPI()

class HttpClient:
    """Relay incoming FastAPI requests to an upstream URL over a shared httpx client.

    One ``httpx.AsyncClient`` is created per instance and reused for every
    forwarded request; call :meth:`close` to release its connections.
    """

    # Request headers that must not be copied to the upstream request.
    # ``accept-encoding`` is dropped so httpx advertises only the encodings it
    # can actually decode; otherwise the upstream could answer with an encoding
    # that reaches the caller undecoded but stripped of its label.
    _SKIPPED_REQUEST_HEADERS = ("host", "connection", "accept-encoding")

    # Response headers that describe a single network hop and therefore must
    # not be relayed to the caller.
    _HOP_BY_HOP_RESPONSE_HEADERS = frozenset({"connection", "transfer-encoding"})

    def __init__(self):
        # Configure the HTTP client with appropriate timeouts. Redirects are
        # not followed so 3xx responses are handed back to the caller as-is.
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            follow_redirects=False
        )

    async def forward_request(self, request: Request, target_url: str):
        """
        Forward an incoming request to a target URL.

        Args:
            request: The incoming request whose method, headers and body are
                replayed against the upstream server.
            target_url: Absolute URL of the upstream endpoint.

        Returns:
            A ``Response`` mirroring the upstream status, headers and body.
            Failures are reported as responses instead of being raised:
            504 on timeout, 502 on network errors, 500 on anything else.
        """
        try:
            # Extract method, headers, and body from the incoming request
            method = request.method
            headers = dict(request.headers)

            # Remove headers that shouldn't be forwarded
            for name in self._SKIPPED_REQUEST_HEADERS:
                headers.pop(name, None)

            # Get the request body
            body = await request.body()

            logger.info(f"Forwarding {method} request to {target_url}")

            # Forward the request to the target URL
            upstream = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body
            )

            # ``upstream.content`` is already decoded by httpx, so when the
            # upstream used a content encoding both its Content-Encoding and
            # its Content-Length describe bytes we are no longer sending.
            body_was_decoded = "content-encoding" in upstream.headers

            relayed = Response(
                content=upstream.content,
                status_code=upstream.status_code
            )

            # Copy headers pair by pair so repeated headers (notably
            # Set-Cookie) stay separate instead of being comma-joined.
            for name, value in upstream.headers.multi_items():
                key = name.lower()
                if key in self._HOP_BY_HOP_RESPONSE_HEADERS:
                    continue
                if key == "content-encoding":
                    continue
                if key == "content-length":
                    # Keep the upstream length when the body is untouched
                    # (this also preserves it for HEAD responses); otherwise
                    # rely on the length computed from the decoded body.
                    if not body_was_decoded:
                        relayed.headers["content-length"] = value
                    continue
                relayed.headers.append(name, value)

            return relayed

        except httpx.TimeoutException:
            logger.error(f"Timeout error while forwarding request to {target_url}")
            return Response(
                content="Request timeout error",
                status_code=504
            )
        except httpx.NetworkError as e:
            logger.error(f"Network error while forwarding request: {str(e)}")
            return Response(
                content=f"Network error: {str(e)}",
                status_code=502
            )
        except Exception as e:
            # Last-resort boundary: report the failure to the caller rather
            # than letting it escape as an unhandled server error.
            logger.error(f"Error forwarding request: {str(e)}")
            return Response(
                content=f"Request error: {str(e)}",
                status_code=500
            )

    async def close(self):
        """Close the underlying httpx client and its pooled connections."""
        await self.client.aclose()

# Module-level HttpClient instance: used by the /wp HTTP handler to forward
# requests and closed by the application's shutdown hook.
http_client = HttpClient()

@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve the landing page by returning the contents of ``index.html``.

    The file is looked up relative to the process working directory and is
    read as UTF-8 so the result does not depend on the host's locale.
    """
    with open("index.html", encoding="utf-8") as f:
        return f.read()

@app.post("/gp/{repo_id:path}/{api_name:path}")
async def gradio_client(repo_id: str, api_name: str, request: Request):
    """Call a Gradio app's API endpoint and return its prediction result.

    Args:
        repo_id: Identifier passed to ``gradio_client.Client``.
        api_name: Endpoint name; it is sent to ``predict`` prefixed with "/".
        request: Incoming request whose JSON body must be an object holding
            an ``"args"`` list of positional arguments for ``predict``.

    Returns:
        Whatever ``Client.predict`` returns, or a 400 ``Response`` when the
        body is not valid JSON or lacks an ``"args"`` list.
    """
    try:
        payload = await request.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass.
        return Response(
            content="Request body must be valid JSON",
            status_code=400
        )

    args = payload.get("args") if isinstance(payload, dict) else None
    if not isinstance(args, list):
        return Response(
            content='Request body must be a JSON object with an "args" list',
            status_code=400
        )

    # Both constructing the client and predicting are blocking calls, so they
    # run in the threadpool to keep the event loop responsive.
    client = await run_in_threadpool(Client, repo_id)
    result = await run_in_threadpool(client.predict, *args, api_name=f"/{api_name}")
    return result

def _has_scheme_and_host(url: str) -> bool:
    """Return True when *url* parses with both a scheme and a network location."""
    try:
        parts = urlparse(url)
    except ValueError:
        # urlparse rejects some malformed authorities (e.g. bad IPv6 literals).
        return False
    return bool(parts.scheme and parts.netloc)


@app.api_route("/wp/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def web_client(request: Request, path: str):
    """
    Main web client endpoint that forwards all requests to the target URL
    specified in the path or the 'X-Target-Url' header.

    Resolution order:
      1. If the path itself starts with http:// or https://, it is the target.
      2. Otherwise the 'X-Target-Url' header is the base and any path is
         appended to it.

    The query string of the incoming request is carried over to the target
    URL. Returns a 400 response when no usable target can be determined;
    otherwise returns whatever the forwarding client produces.

    NOTE(review): the target is fully caller-controlled, so this endpoint can
    reach any host the server can reach. Restrict access or add an allow-list
    if that is not intended.
    """
    # Prioritize URL in path if it starts with http:// or https://
    if path.lower().startswith(("http://", "https://")):
        target_url = path
    else:
        # Get the target URL from the header
        target_url = request.headers.get("X-Target-Url")

        # If we have a target URL from header and a path, combine them
        if target_url and path:
            if not _has_scheme_and_host(target_url):
                return Response(
                    content="Invalid X-Target-Url header",
                    status_code=400
                )
            # Force a trailing slash on the base so urljoin appends the path
            # instead of replacing the base's last segment.
            target_url = urljoin(target_url.rstrip('/') + '/', path.lstrip('/'))

    if not target_url:
        return Response(
            content="Missing X-Target-Url header or URL in path",
            status_code=400
        )

    # Validate the target URL
    if not _has_scheme_and_host(target_url):
        return Response(
            content="Invalid X-Target-Url header or URL in path",
            status_code=400
        )

    # The path parameter never contains the query string, so re-attach it;
    # otherwise "?a=b" on the incoming request would be silently dropped.
    query = request.url.query
    if query:
        separator = "&" if "?" in target_url else "?"
        target_url = f"{target_url}{separator}{query}"

    # Forward the request
    return await http_client.forward_request(request, target_url)

@app.websocket("/wp/{path:path}")
async def websocket_client(websocket: WebSocket, path: str):
    """
    WebSocket endpoint that forwards WebSocket connections to the target URL
    specified in the 'X-Target-Url' header or in the path.

    Target resolution:
      * With an 'X-Target-Url' header, the header is the base URL (http/https
        are mapped to ws/wss) and a non-URL path is appended to it.
      * Without the header, the (optionally URL-encoded) path must itself be
        a ws:// or wss:// URL.

    The query string of the incoming connection is carried over to the
    target. Text and binary frames are relayed in both directions until
    either side closes, at which point the other direction is cancelled and
    both sockets are closed. The connection is closed with code 1008 for an
    unusable target or failed upstream handshake and 1011 for unexpected
    errors.
    """
    ws_schemes = ("ws://", "wss://")

    # Get the target URL from the header or path
    header_target = websocket.headers.get("X-Target-Url")
    if header_target:
        target_url = header_target
    else:
        # Handle URL-encoded paths
        decoded_path = unquote(path) if path and '%' in path else path
        if not decoded_path or not decoded_path.lower().startswith(ws_schemes):
            await websocket.close(code=1008, reason="Missing X-Target-Url header or invalid URL in path")
            return
        target_url = decoded_path

    # Validate the target URL
    try:
        parsed_url = urlparse(target_url)
    except ValueError:
        parsed_url = None
    if parsed_url is None or not parsed_url.scheme or not parsed_url.netloc:
        await websocket.close(code=1008, reason="Invalid target URL")
        return

    # Accept the WebSocket connection
    await websocket.accept()

    # Convert HTTP/HTTPS URL to WebSocket URL
    lowered = target_url.lower()
    if lowered.startswith("https://"):
        ws_target_url = "wss://" + target_url[len("https://"):]
    elif lowered.startswith("http://"):
        ws_target_url = "ws://" + target_url[len("http://"):]
    else:
        ws_target_url = target_url

    # Append the path only when the base came from the header. When the
    # target was taken from the path, appending it again would duplicate it
    # (this used to happen for URL-encoded path targets).
    if header_target and path and not path.lower().startswith(ws_schemes):
        # Join the target URL with the path properly
        ws_target_url = urljoin(ws_target_url.rstrip('/') + '/', path.lstrip('/'))

    # The path parameter never contains the query string, so re-attach it.
    query = websocket.url.query
    if query:
        separator = "&" if "?" in ws_target_url else "?"
        ws_target_url = f"{ws_target_url}{separator}{query}"

    # Close code/reason used for the single close call in ``finally``.
    close_code, close_reason = 1000, None

    try:
        # Connect to the target WebSocket server
        async with websockets.connect(ws_target_url) as target_ws:

            async def forward_client_to_target():
                # receive() is used instead of receive_text() so binary
                # frames are relayed rather than raising.
                while True:
                    message = await websocket.receive()
                    if message["type"] == "websocket.disconnect":
                        return
                    text = message.get("text")
                    if text is not None:
                        await target_ws.send(text)
                        continue
                    payload = message.get("bytes")
                    if payload is not None:
                        await target_ws.send(payload)

            async def forward_target_to_client():
                try:
                    while True:
                        data = await target_ws.recv()
                        if isinstance(data, (bytes, bytearray)):
                            await websocket.send_bytes(bytes(data))
                        else:
                            await websocket.send_text(data)
                except websockets.ConnectionClosed:
                    pass

            relays = [
                asyncio.create_task(forward_client_to_target()),
                asyncio.create_task(forward_target_to_client()),
            ]
            try:
                # Stop as soon as either direction ends; waiting for both
                # would keep the upstream connection open after the client
                # has gone (and vice versa).
                await asyncio.wait(relays, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for relay in relays:
                    if not relay.done():
                        relay.cancel()
                # Collect results so relay errors and cancellations are
                # retrieved here instead of surfacing as unhandled task errors.
                await asyncio.gather(*relays, return_exceptions=True)
    except websockets.InvalidURI:
        close_code, close_reason = 1008, "Invalid WebSocket URL"
    except websockets.InvalidHandshake:
        close_code, close_reason = 1008, "WebSocket handshake failed"
    except Exception as e:
        logger.error(f"Error in WebSocket connection: {str(e)}")
        close_code, close_reason = 1011, "Internal server error"
    finally:
        try:
            await websocket.close(code=close_code, reason=close_reason)
        except Exception as close_error:
            # Best effort: the client may already have disconnected.
            logger.debug(f"WebSocket already closed: {close_error}")

@app.get("/health")
async def health_check():
    """Liveness probe that always reports an "ok" status."""
    status_payload = dict(status="ok")
    return status_payload

@app.on_event("shutdown")
async def shutdown_event():
    """Application shutdown hook: close the module-level HTTP client."""
    pending_close = http_client.close()
    await pending_close

# When executed as a script, serve the app with uvicorn, listening on all
# network interfaces on port 7860.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)