Reality123b commited on
Commit
1ea4540
·
verified ·
1 Parent(s): 2145ed0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -57
app.py CHANGED
@@ -1,120 +1,200 @@
1
- from fastapi import FastAPI, Request, HTTPException
2
- from fastapi.responses import StreamingResponse
3
  from pydantic import BaseModel
4
- from transformers import pipeline, TextStreamer
5
  import asyncio
6
  import queue
7
  import threading
8
  import time
 
9
  import httpx
10
- import json
11
 
12
  class ModelInput(BaseModel):
13
  prompt: str
14
- max_new_tokens: int = 128
15
 
16
  app = FastAPI()
17
 
18
- # Initialize generator once
19
  generator = pipeline(
20
  "text-generation",
21
  model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
22
  device="cpu"
23
  )
24
 
25
- # Shared knowledge graph, just a dict (in-memory)
 
 
 
 
 
 
 
 
26
  knowledge_graph = {}
27
 
28
- # --- Autonomous knowledge updater --- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  async def update_knowledge_graph_periodically():
 
 
 
 
 
 
 
30
  while True:
 
 
31
  try:
32
- # Pick a random query (here: hardcoded or you can improve)
33
- queries = ["latest tech startup news", "AI breakthroughs", "funding trends 2025"]
34
- import random
35
- query = random.choice(queries)
36
-
37
- # Use DuckDuckGo Instant Answer API (free, no API key)
38
- async with httpx.AsyncClient() as client:
39
- resp = await client.get(
40
- "https://api.duckduckgo.com/",
41
- params={"q": query, "format": "json", "no_redirect": "1", "no_html": "1"}
42
- )
43
- data = resp.json()
44
-
45
- # Extract some useful info (abstract text)
46
- abstract = data.get("AbstractText", "")
47
- related_topics = data.get("RelatedTopics", [])
48
-
49
- # Save/update knowledge graph (super basic example)
50
  knowledge_graph[query] = {
51
- "abstract": abstract,
52
- "related_topics": related_topics,
53
  "timestamp": time.time()
54
  }
55
-
56
- print(f"Knowledge graph updated for query: {query}")
57
 
58
  except Exception as e:
59
- print(f"Error updating knowledge graph: {e}")
60
 
61
- await asyncio.sleep(60) # wait 1 minute before next update
62
 
63
- # Kick off background task on startup
64
  @app.on_event("startup")
65
  async def startup_event():
66
  asyncio.create_task(update_knowledge_graph_periodically())
67
 
68
- # --- Streaming generation endpoint --- #
69
  @app.post("/generate/stream")
70
  async def generate_stream(input: ModelInput):
71
- prompt = input.prompt
72
- max_new_tokens = input.max_new_tokens
73
-
74
  q = queue.Queue()
75
-
76
  def run_generation():
77
  try:
 
78
  streamer = TextStreamer(generator.tokenizer, skip_prompt=True)
79
-
80
- # Monkey-patch streamer to push tokens to queue
81
- def queue_token(token):
82
  q.put(token)
83
-
84
- streamer.put = queue_token
85
-
86
- # Run generation with streamer attached
87
  generator(
88
- prompt,
89
- max_new_tokens=max_new_tokens,
90
  do_sample=False,
91
  streamer=streamer
92
  )
93
  except Exception as e:
94
  q.put(f"[ERROR] {e}")
95
  finally:
96
- q.put(None) # Sentinel to mark done
97
-
98
  thread = threading.Thread(target=run_generation)
99
  thread.start()
100
 
101
  async def event_generator():
 
102
  while True:
103
- token = q.get()
104
  if token is None:
105
  break
106
  yield token
107
-
108
  return StreamingResponse(event_generator(), media_type="text/plain")
109
 
110
-
111
- # Optional: Endpoint to query knowledge graph
112
  @app.get("/knowledge")
113
  async def get_knowledge():
114
  return knowledge_graph
115
 
116
-
117
- # Root
118
- @app.get("/")
119
  async def root():
120
- return {"message": "Welcome to the Streaming Model API with live knowledge graph updater!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import StreamingResponse, HTMLResponse
3
  from pydantic import BaseModel
4
+ from transformers import pipeline
5
  import asyncio
6
  import queue
7
  import threading
8
  import time
9
+ import random
10
  import httpx
 
11
 
12
  class ModelInput(BaseModel):
13
  prompt: str
14
+ max_new_tokens: int = 64000
15
 
16
  app = FastAPI()
17
 
18
+ # Your main generation model (DeepSeek)
19
  generator = pipeline(
20
  "text-generation",
21
  model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
22
  device="cpu"
23
  )
24
 
25
+ # The summarization instruct model
26
+ summarizer = pipeline(
27
+ "text-generation",
28
+ model="HuggingFaceTB/SmolLM2-360M-Instruct",
29
+ device="cpu",
30
+ max_length=512, # keep summary short
31
+ do_sample=False
32
+ )
33
+
34
  knowledge_graph = {}
35
 
36
+ async def fetch_ddg_search(query: str):
37
+ url = "https://api.duckduckgo.com/"
38
+ params = {
39
+ "q": query,
40
+ "format": "json",
41
+ "no_redirect": "1",
42
+ "no_html": "1",
43
+ "skip_disambig": "1"
44
+ }
45
+ async with httpx.AsyncClient() as client:
46
+ resp = await client.get(url, params=params, timeout=15)
47
+ data = resp.json()
48
+ return data
49
+
50
+ def clean_ddg_text(ddg_json):
51
+ # Take abstract text + top related topic texts concatenated
52
+ abstract = ddg_json.get("AbstractText", "")
53
+ related = ddg_json.get("RelatedTopics", [])
54
+ related_texts = []
55
+ for item in related:
56
+ if "Text" in item:
57
+ related_texts.append(item["Text"])
58
+ elif "Name" in item and "Topics" in item:
59
+ for sub in item["Topics"]:
60
+ if "Text" in sub:
61
+ related_texts.append(sub["Text"])
62
+ combined_text = abstract + " " + " ".join(related_texts)
63
+ # Simple clean up, trim length to avoid overloading
64
+ combined_text = combined_text.strip().replace("\n", " ")
65
+ if len(combined_text) > 1000:
66
+ combined_text = combined_text[:1000] + "..."
67
+ return combined_text
68
+
69
+ def summarize_text(text: str):
70
+ # Run the instruct model to summarize/clean the text
71
+ prompt = f"Summarize this information concisely:\n{text}\nSummary:"
72
+ output = summarizer(prompt, max_length=256, do_sample=False)
73
+ return output[0]["generated_text"].strip()
74
+
75
  async def update_knowledge_graph_periodically():
76
+ queries = [
77
+ "latest tech startup news",
78
+ "AI breakthroughs 2025",
79
+ "funding trends in tech startups",
80
+ "popular programming languages 2025",
81
+ "open source AI models"
82
+ ]
83
  while True:
84
+ query = random.choice(queries)
85
+ print(f"[KG Updater] Searching DuckDuckGo for query: {query}")
86
  try:
87
+ ddg_data = await fetch_ddg_search(query)
88
+ cleaned = clean_ddg_text(ddg_data)
89
+ if not cleaned:
90
+ cleaned = "No useful info found."
91
+ print(f"[KG Updater] DuckDuckGo cleaned text length: {len(cleaned)}")
92
+
93
+ # Summarize using your instruct model in a thread (blocking)
94
+ loop = asyncio.get_event_loop()
95
+ summary = await loop.run_in_executor(None, summarize_text, cleaned)
96
+ print(f"[KG Updater] Summary length: {len(summary)}")
97
+
 
 
 
 
 
 
 
98
  knowledge_graph[query] = {
99
+ "raw_text": cleaned,
100
+ "summary": summary,
101
  "timestamp": time.time()
102
  }
103
+ print(f"[KG Updater] Knowledge graph updated for query: {query}")
 
104
 
105
  except Exception as e:
106
+ print(f"[KG Updater] Error: {e}")
107
 
108
+ await asyncio.sleep(60)
109
 
 
110
  @app.on_event("startup")
111
  async def startup_event():
112
  asyncio.create_task(update_knowledge_graph_periodically())
113
 
114
+ # Manual streaming endpoint (kept as-is)
115
  @app.post("/generate/stream")
116
  async def generate_stream(input: ModelInput):
 
 
 
117
  q = queue.Queue()
 
118
  def run_generation():
119
  try:
120
+ streamer = pipeline("text-generation", model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", device="cpu").tokenizer
121
  streamer = TextStreamer(generator.tokenizer, skip_prompt=True)
122
+ def enqueue_token(token):
 
 
123
  q.put(token)
124
+ streamer.put = enqueue_token
 
 
 
125
  generator(
126
+ input.prompt,
127
+ max_new_tokens=input.max_new_tokens,
128
  do_sample=False,
129
  streamer=streamer
130
  )
131
  except Exception as e:
132
  q.put(f"[ERROR] {e}")
133
  finally:
134
+ q.put(None)
 
135
  thread = threading.Thread(target=run_generation)
136
  thread.start()
137
 
138
  async def event_generator():
139
+ loop = asyncio.get_event_loop()
140
  while True:
141
+ token = await loop.run_in_executor(None, q.get)
142
  if token is None:
143
  break
144
  yield token
 
145
  return StreamingResponse(event_generator(), media_type="text/plain")
146
 
147
+ # Endpoint to get KG
 
148
  @app.get("/knowledge")
149
  async def get_knowledge():
150
  return knowledge_graph
151
 
152
+ # Basic client page to test streaming
153
+ @app.get("/", response_class=HTMLResponse)
 
154
  async def root():
155
+ return """
156
+ <!DOCTYPE html>
157
+ <html>
158
+ <head><title>Streaming Text Generation Client</title></head>
159
+ <body>
160
+ <h2>Streaming Text Generation Demo</h2>
161
+ <textarea id="prompt" rows="4" cols="60">Write me a poem about tech startup struggles</textarea><br/>
162
+ <button onclick="startStreaming()">Generate</button>
163
+ <pre id="output" style="white-space: pre-wrap; background:#eee; padding:10px; border-radius:5px; max-height:400px; overflow:auto;"></pre>
164
+ <h3>Knowledge Graph</h3>
165
+ <pre id="kg" style="background:#ddd; padding:10px; max-height:300px; overflow:auto;"></pre>
166
+
167
+ <script>
168
+ async function startStreaming() {
169
+ const prompt = document.getElementById("prompt").value;
170
+ const output = document.getElementById("output");
171
+ output.textContent = "";
172
+ const response = await fetch("/generate/stream", {
173
+ method: "POST",
174
+ headers: { "Content-Type": "application/json" },
175
+ body: JSON.stringify({ prompt: prompt, max_new_tokens: 64000 })
176
+ });
177
+ const reader = response.body.getReader();
178
+ const decoder = new TextDecoder();
179
+ while(true) {
180
+ const {done, value} = await reader.read();
181
+ if(done) break;
182
+ const chunk = decoder.decode(value, {stream: true});
183
+ output.textContent += chunk;
184
+ output.scrollTop = output.scrollHeight;
185
+ }
186
+ }
187
+
188
+ async function fetchKG() {
189
+ const kgPre = document.getElementById("kg");
190
+ const res = await fetch("/knowledge");
191
+ const data = await res.json();
192
+ kgPre.textContent = JSON.stringify(data, null, 2);
193
+ }
194
+
195
+ setInterval(fetchKG, 10000); // update KG display every 10s
196
+ window.onload = fetchKG;
197
+ </script>
198
+ </body>
199
+ </html>
200
+ """