yzhuang committed
Commit c2ec273 · Parent(s): 3d9b062
Files changed (2):
  1. app.py    +4 -2
  2. server.py +17 -60
app.py CHANGED
@@ -3,7 +3,7 @@ import json
 import requests
 import sseclient
 import gradio as gr
-import server
+from server import setup_mixinputs, launch_vllm_server

 API_URL = "http://localhost:8000/v1/chat/completions"

@@ -53,7 +53,7 @@ def stream_completion(message, history, max_tokens, temperature, top_p, beta):


 # ----------------------- UI ---------------------------------------------
-with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo", teardown=lambda: _kill_proc_tree(server_proc)) as demo:
+with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo") as demo:
     gr.Markdown(
         "## 🎨 Mixture of Inputs (MoI) Demo \n"
         "Streaming vLLM demo with dynamic **beta** adjustment in MoI, higher beta means less blending."
@@ -80,4 +80,6 @@ with gr.Blocks(title="🎨 Mixture of Inputs (MoI) Demo", teardown=lambda: _kill
     clear_btn.click(lambda: None, None, chatbot, queue=False)

 if __name__ == "__main__":
+    setup_mixinputs()
+    launch_vllm_server(beta=1.0)
     demo.launch()
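
Note: the stream_completion handler named in the hunk header above consumes the server's OpenAI-compatible endpoint at API_URL. The sketch below is illustrative only, not the code hidden by the hunk: it assumes the standard chat-completions payload fields and streams with plain requests rather than the sseclient wrapper that app.py imports.

    # Minimal streaming-client sketch (assumed payload; app.py's real handler may differ).
    import json
    import requests

    API_URL = "http://localhost:8000/v1/chat/completions"

    def stream_chat(prompt, max_tokens=512, temperature=0.7, top_p=0.95):
        payload = {
            "model": "Qwen/Qwen3-4B",          # the model served by server.py
            "messages": [{"role": "user", "content": prompt}],
            "stream": True,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        with requests.post(API_URL, json=payload, stream=True) as resp:
            resp.raise_for_status()
            for raw in resp.iter_lines():
                if not raw:
                    continue
                line = raw.decode("utf-8")
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":           # OpenAI-style streams end with [DONE]
                    break
                delta = json.loads(data)["choices"][0]["delta"].get("content", "")
                if delta:
                    yield delta

    # Usage: print tokens as they arrive.
    # for piece in stream_chat("Hello!"):
    #     print(piece, end="", flush=True)
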
server.py CHANGED
@@ -1,45 +1,20 @@
1
- # app.py ── launch vLLM inside a Hugging Face Space (with clean shutdown)
2
- import os, signal, sys, atexit, time, socket, subprocess
3
- import spaces # only needed for the GPU decorator
 
 
4
 
5
- # ----------------------------------------------------------------------
6
- # Helpers
7
- # ----------------------------------------------------------------------
8
- def _wait_for_port(host: str, port: int, timeout: int = 240):
9
- """Block until (host, port) accepts TCP connections or timeout."""
10
- deadline = time.time() + timeout
11
- while time.time() < deadline:
12
- with socket.socket() as sock:
13
- sock.settimeout(2)
14
- if sock.connect_ex((host, port)) == 0:
15
- return
16
- time.sleep(1)
17
- raise RuntimeError(f"vLLM server on {host}:{port} never came up")
18
-
19
- def _kill_proc_tree(proc: subprocess.Popen):
20
- """SIGTERM the whole process-group started by `proc` (if still alive)."""
21
- if proc and proc.poll() is None: # still running
22
- pgid = os.getpgid(proc.pid)
23
- os.killpg(pgid, signal.SIGTERM) # graceful
24
- try:
25
- proc.wait(15)
26
- except subprocess.TimeoutExpired:
27
- os.killpg(pgid, signal.SIGKILL) # force
28
-
29
- # ----------------------------------------------------------------------
30
- # Setup – runs on *CPU* only; fast.
31
- # ----------------------------------------------------------------------
32
  def setup_mixinputs():
 
33
  subprocess.run(["mixinputs", "setup"], check=True)
34
 
35
- # ----------------------------------------------------------------------
36
- # Serve runs on the GPU; heavy, so we mark it.
37
- # ----------------------------------------------------------------------
38
- def launch_vllm_server(beta: float = 1.0, port: int = 8000) -> subprocess.Popen:
39
  env = os.environ.copy()
40
  env["MIXINPUTS_BETA"] = str(beta)
41
  env["VLLM_USE_V1"] = "1"
42
 
 
43
  cmd = [
44
  "vllm", "serve",
45
  "Qwen/Qwen3-4B",
@@ -48,33 +23,15 @@ def launch_vllm_server(beta: float = 1.0, port: int = 8000) -> subprocess.Popen:
48
  "--max-model-len", "2048",
49
  "--max-seq-len-to-capture", "2048",
50
  "--max-num-seqs", "1",
51
- "--port", str(port)
52
  ]
 
53
 
54
- # new session its own process-group
55
- proc = subprocess.Popen(cmd, env=env, start_new_session=True)
56
- _wait_for_port("localhost", port) # block until ready
57
- return proc
58
-
59
- # ----------------------------------------------------------------------
60
- # MAIN
61
- # ----------------------------------------------------------------------
62
- if __name__ == "__main__":
63
- setup_mixinputs() # fast
64
- server_proc = launch_vllm_server() # heavy
65
 
66
- # Ensures the GPU process dies when the Space stops / reloads
67
- atexit.register(_kill_proc_tree, server_proc)
68
 
69
- # ---- your Gradio / FastAPI app goes below ----
70
- # e.g. import gradio as gr
71
- # with gr.Blocks(teardown=lambda: _kill_proc_tree(server_proc)) as demo:
72
- # ...
73
- # demo.launch(server_name="0.0.0.0", server_port=7860)
74
- #
75
- # For this snippet we’ll just block forever so the container
76
- # doesn’t exit immediately.
77
- try:
78
- server_proc.wait()
79
- except KeyboardInterrupt:
80
- pass
 
1
+ import subprocess
2
+ import threading
3
+ import os
4
+ import time
5
+ import spaces
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def setup_mixinputs():
8
+ # Step 1: Run mixinputs setup
9
  subprocess.run(["mixinputs", "setup"], check=True)
10
 
11
+ def launch_vllm_server(beta=1.0):
12
+ # Step 2: Set environment variables
 
 
13
  env = os.environ.copy()
14
  env["MIXINPUTS_BETA"] = str(beta)
15
  env["VLLM_USE_V1"] = "1"
16
 
17
+ # Step 3: Launch vLLM with custom options
18
  cmd = [
19
  "vllm", "serve",
20
  "Qwen/Qwen3-4B",
 
23
  "--max-model-len", "2048",
24
  "--max-seq-len-to-capture", "2048",
25
  "--max-num-seqs", "1",
26
+ "--port", "8000"
27
  ]
28
+ subprocess.run(cmd, env=env)
29
 
30
+ # # Step 1: Setup
31
+ # setup_mixinputs()
 
 
 
 
 
 
 
 
 
32
 
33
+ # # Step 2: Launch vLLM server in background
34
+ # threading.Thread(target=launch_vllm_server, daemon=True).start()
35
 
36
+ # # Step 3: Give time for server to initialize
37
+ # time.sleep(60)
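
Note: the rewritten launch_vllm_server wraps a blocking subprocess.run, and the commented-out steps above sketch starting it on a daemon thread instead. The snippet below is a hedged sketch, under the assumption that app.py's __main__ block adopts that threaded pattern; the placeholder Blocks stands in for the real MoI demo UI, and the 60-second sleep is the crude wait from the comments (a port probe like the removed _wait_for_port helper would be sturdier).

    # Sketch only: threaded startup mirroring server.py's commented-out Steps 1-3.
    import threading
    import time

    import gradio as gr
    from server import setup_mixinputs, launch_vllm_server

    # Placeholder UI so the sketch is self-contained; app.py's real `demo`
    # is the MoI Blocks app defined earlier in that file.
    with gr.Blocks(title="placeholder") as demo:
        gr.Markdown("placeholder")

    if __name__ == "__main__":
        setup_mixinputs()  # Step 1: one-off mixinputs setup (fast)

        # Step 2: launch_vllm_server blocks in subprocess.run, so run it on a
        # daemon thread; otherwise demo.launch() below would never be reached.
        threading.Thread(
            target=launch_vllm_server, kwargs={"beta": 1.0}, daemon=True
        ).start()

        # Step 3: crude readiness wait taken from the commented-out steps;
        # polling the port until it accepts connections would be more robust.
        time.sleep(60)

        demo.launch()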