wip-test / app.py
rifatramadhani's picture
wip
d067c3d
raw
history blame
9.13 kB
from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.concurrency import run_in_threadpool
from gradio_client import Client
import gradio as gr
import uvicorn
import httpx
import websockets
import asyncio
from urllib.parse import urljoin, urlparse, unquote
import logging
# Configure the root logger once at INFO for the whole process
# (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# ASGI application that every route in this module is registered on.
app = FastAPI()
class HttpClient:
    """Forward incoming FastAPI requests to arbitrary upstream URLs.

    Wraps a single ``httpx.AsyncClient`` so every proxied request shares one
    client instance; call :meth:`close` on shutdown to release it.
    """

    # Hop-by-hop headers (RFC 9110 §7.6.1) describe only the current
    # connection, so a proxy must not relay them in either direction.
    _HOP_BY_HOP_HEADERS = frozenset({
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    })

    # Request headers that must not reach the upstream:
    # - host: httpx derives it from target_url.
    # - content-length: httpx recomputes it from the buffered body.
    # - accept-encoding: let httpx advertise its own default codings, because
    #   the body is handed back to the caller already decoded by httpx.
    # - x-target-url: routing instruction for this proxy, not for the upstream.
    _REQUEST_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {
        "host",
        "content-length",
        "accept-encoding",
        "x-target-url",
    }

    # Response headers that would misdescribe the body we return: httpx has
    # already undone any content-encoding, so the upstream content-encoding and
    # content-length no longer match ``response.content``. Starlette recomputes
    # content-length from the body actually sent.
    _RESPONSE_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {
        "content-encoding",
        "content-length",
    }

    def __init__(self):
        # One shared client with a 30 s timeout; redirects are not followed so
        # the caller sees the upstream 3xx response unchanged.
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            follow_redirects=False,
        )

    async def forward_request(self, request: Request, target_url: str):
        """Forward ``request`` to ``target_url`` and relay the upstream reply.

        The method, body and end-to-end headers of ``request`` are sent on;
        the upstream status, headers (repeated ones such as ``Set-Cookie``
        included) and decoded body are returned.

        Args:
            request: The incoming request to proxy.
            target_url: Absolute URL the request is sent to.

        Returns:
            A ``Response`` mirroring the upstream reply, or a 504 / 502 / 500
            ``Response`` describing a timeout, network error or other failure.
        """
        try:
            method = request.method
            # items() keeps repeated header lines, unlike dict(request.headers).
            headers = [
                (name, value)
                for name, value in request.headers.items()
                if name.lower() not in self._REQUEST_HEADERS_TO_DROP
            ]
            body = await request.body()
            logger.info(f"Forwarding {method} request to {target_url}")
            response = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body,
            )

            proxied = Response(
                content=response.content,
                status_code=response.status_code,
            )
            # multi_items() yields every header line, so several Set-Cookie
            # headers survive instead of being comma-joined into one value.
            for name, value in response.headers.multi_items():
                if name.lower() not in self._RESPONSE_HEADERS_TO_DROP:
                    proxied.headers.append(name, value)
            if method == "HEAD" and "content-length" in response.headers:
                # A HEAD reply has no body; keep the size the upstream advertised
                # instead of the 0 Starlette would compute from the empty body.
                proxied.headers["content-length"] = response.headers["content-length"]
            return proxied
        except httpx.TimeoutException:
            logger.error(f"Timeout error while forwarding request to {target_url}")
            return Response(
                content="Request timeout error",
                status_code=504,
            )
        except httpx.NetworkError as e:
            logger.error(f"Network error while forwarding request: {str(e)}")
            return Response(
                content=f"Network error: {str(e)}",
                status_code=502,
            )
        except Exception as e:
            # Top-level boundary for this proxy call: keep the traceback in the log.
            logger.exception(f"Error forwarding request: {str(e)}")
            return Response(
                content=f"Request error: {str(e)}",
                status_code=500,
            )

    async def close(self):
        """Close the underlying ``httpx.AsyncClient``."""
        await self.client.aclose()
# Module-wide forwarder: the proxy route and the shutdown hook both use
# this one instance, so its httpx client is created exactly once.
http_client: HttpClient = HttpClient()
@app.get("/", response_class=HTMLResponse)
async def read_root():
    """Serve ``index.html`` (resolved against the working directory) as the landing page.

    The file read is blocking, so it runs in the threadpool rather than on the
    event loop; the file is decoded explicitly as UTF-8 instead of relying on
    the platform's default encoding.
    """

    def _read_index():
        with open("index.html", encoding="utf-8") as f:
            return f.read()

    return await run_in_threadpool(_read_index)
@app.post("/gp/{repo_id:path}/{api_name:path}")
async def gradio_client(repo_id: str, api_name: str, request: Request):
    """Call endpoint ``/{api_name}`` of the Gradio app ``repo_id`` and return its result.

    The JSON request body must be an object whose ``"args"`` list supplies the
    positional arguments forwarded to ``Client.predict``.

    Returns:
        The prediction result, or a 400 ``Response`` if the body is not valid
        JSON or lacks an ``"args"`` list.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return Response(content="Request body must be valid JSON", status_code=400)

    args = data.get("args") if isinstance(data, dict) else None
    if not isinstance(args, list):
        return Response(
            content="Request body must be a JSON object with an 'args' list",
            status_code=400,
        )

    def _predict():
        # Constructing a gradio Client performs blocking network I/O (it
        # connects to the target app), so it belongs off the event loop
        # together with predict().
        client = Client(repo_id)
        return client.predict(*args, api_name=f"/{api_name}")

    return await run_in_threadpool(_predict)
def _is_http_target(url):
    """Return True if ``url`` parses as an absolute http/https URL with a host."""
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse raises ValueError for malformed input such as bad IPv6 brackets.
        return False
    return parsed.scheme.lower() in ("http", "https") and bool(parsed.netloc)


def _append_query_string(url, query):
    """Return ``url`` with the raw ``query`` string merged into its own query part."""
    if not query:
        return url
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    merged = f"{parts.query}&{query}" if parts.query else query
    return urlunsplit(parts._replace(query=merged))


@app.api_route("/wp/{path:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"])
async def web_client(request: Request, path: str):
    """
    Main web client endpoint that forwards all requests to the target URL
    specified in the path or the 'X-Target-Url' header.

    A path starting with ``http://`` / ``https://`` is used as the full target.
    Otherwise the ``X-Target-Url`` header supplies the base URL and the path is
    joined onto it. The incoming query string is re-attached to the target.
    Only http/https targets are accepted; anything else yields a 400.
    """
    if path.lower().startswith(("http://", "https://")):
        # The path itself is the complete target URL.
        target_url = path
    else:
        target_url = request.headers.get("X-Target-Url")
        if target_url and path:
            # Validate the header before joining so a bad base is reported as such.
            if not _is_http_target(target_url):
                return Response(
                    content="Invalid X-Target-Url header",
                    status_code=400,
                )
            target_url = urljoin(target_url.rstrip('/') + '/', path.lstrip('/'))

    if not target_url:
        return Response(
            content="Missing X-Target-Url header or URL in path",
            status_code=400,
        )

    if not _is_http_target(target_url):
        return Response(
            content="Invalid X-Target-Url header or URL in path",
            status_code=400,
        )

    # The ``path`` parameter never contains the query string; forward it too.
    target_url = _append_query_string(target_url, request.url.query)
    return await http_client.forward_request(request, target_url)
def _ws_resolve_target(websocket, path):
    """Find the upstream URL for a proxied WebSocket connection.

    The ``X-Target-Url`` header wins. Otherwise ``path`` is used, percent-decoded
    when it contains ``%``, if it starts with ``ws://`` or ``wss://``.

    Returns:
        ``(target_url, from_path)``, where ``from_path`` tells whether the URL
        came from ``path``; ``(None, False)`` when no usable target was supplied.
    """
    header_url = websocket.headers.get("X-Target-Url")
    if header_url:
        return header_url, False
    decoded_path = unquote(path) if path and "%" in path else path
    if decoded_path and decoded_path.lower().startswith(("ws://", "wss://")):
        return decoded_path, True
    return None, False


def _ws_append_query(url, query):
    """Return ``url`` with the raw ``query`` string merged into its own query part."""
    if not query:
        return url
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(url)
    merged = f"{parts.query}&{query}" if parts.query else query
    return urlunsplit(parts._replace(query=merged))


async def _ws_close_quietly(websocket, code=1000, reason=None):
    """Close the client socket unless a close was already sent; errors are only logged."""
    from fastapi.websockets import WebSocketState

    if websocket.application_state == WebSocketState.DISCONNECTED:
        return
    try:
        await websocket.close(code=code, reason=reason)
    except Exception as exc:
        # Best effort: the client may already have gone away.
        logger.debug(f"Ignoring error while closing client WebSocket: {exc}")


async def _ws_pump_client_to_target(websocket, target_ws):
    """Copy text and binary frames from the client to the upstream until the client leaves."""
    try:
        while True:
            message = await websocket.receive()
            if message["type"] == "websocket.disconnect":
                return
            if message.get("text") is not None:
                await target_ws.send(message["text"])
            elif message.get("bytes") is not None:
                await target_ws.send(message["bytes"])
    except websockets.ConnectionClosed:
        # The upstream closed while we were sending to it.
        return


async def _ws_pump_target_to_client(websocket, target_ws):
    """Copy frames from the upstream to the client, preserving text vs. binary."""
    try:
        async for data in target_ws:
            if isinstance(data, bytes):
                await websocket.send_bytes(data)
            else:
                await websocket.send_text(data)
    except websockets.ConnectionClosed:
        return


async def _ws_relay(websocket, target_ws):
    """Run both pumps until either side finishes, then stop the other one."""
    tasks = [
        asyncio.create_task(_ws_pump_client_to_target(websocket, target_ws)),
        asyncio.create_task(_ws_pump_target_to_client(websocket, target_ws)),
    ]
    try:
        # Waiting for *both* pumps would hang: when one peer closes, the other
        # pump stays blocked on a receive that may never complete.
        await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in tasks:
            task.cancel()
        # Collect results so no task exception goes unretrieved.
        outcomes = await asyncio.gather(*tasks, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception):
                logger.warning(f"WebSocket relay task failed: {outcome}")


@app.websocket("/wp/{path:path}")
async def websocket_client(websocket: WebSocket, path: str):
    """
    WebSocket endpoint that forwards WebSocket connections to the target URL
    specified in the 'X-Target-Url' header or in the path.

    http(s) header URLs are mapped to ws(s). For a header URL, the route path
    is appended unless the path is itself a ws/wss URL. The incoming query
    string is re-attached. Text and binary frames are relayed in both
    directions until either peer closes.
    """
    target_url, from_path = _ws_resolve_target(websocket, path)
    if not target_url:
        await websocket.close(code=1008, reason="Missing X-Target-Url header or invalid URL in path")
        return

    try:
        parsed_url = urlparse(target_url)
    except ValueError:
        parsed_url = None
    if parsed_url is None or not parsed_url.scheme or not parsed_url.netloc:
        await websocket.close(code=1008, reason="Invalid target URL")
        return

    await websocket.accept()

    # Convert an HTTP/HTTPS URL to the matching WebSocket scheme.
    lowered = target_url.lower()
    if lowered.startswith("https://"):
        ws_target_url = "wss://" + target_url[len("https://"):]
    elif lowered.startswith("http://"):
        ws_target_url = "ws://" + target_url[len("http://"):]
    else:
        ws_target_url = target_url

    # Only append the path when it is not itself the (possibly decoded) target URL.
    path_is_url = from_path or path.lower().startswith(("ws://", "wss://"))
    if path and not path_is_url:
        ws_target_url = urljoin(ws_target_url.rstrip('/') + '/', path.lstrip('/'))

    # The ``path`` parameter never contains the query string; forward it too.
    ws_target_url = _ws_append_query(ws_target_url, websocket.url.query)

    try:
        async with websockets.connect(ws_target_url) as target_ws:
            await _ws_relay(websocket, target_ws)
    except websockets.InvalidURI:
        await _ws_close_quietly(websocket, code=1008, reason="Invalid WebSocket URL")
    except websockets.InvalidHandshake:
        await _ws_close_quietly(websocket, code=1008, reason="WebSocket handshake failed")
    except Exception as e:
        logger.exception(f"Error in WebSocket connection: {str(e)}")
        await _ws_close_quietly(websocket, code=1011, reason="Internal server error")
    finally:
        # Normal close once relaying stops; a no-op if an error close was already sent.
        await _ws_close_quietly(websocket)
@app.get("/health")
async def health_check():
    """Liveness probe: always answers with a fixed JSON body and checks nothing else."""
    payload = dict(status="ok")
    return payload
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared forwarder's httpx client when the application stops."""
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of lifespan handlers; migrating also requires changing ``FastAPI()``.
    forwarder = http_client
    await forwarder.close()
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")