# ap3 / app.py
from fastapi import FastAPI
from fastapi.responses import StreamingResponse, HTMLResponse
from pydantic import BaseModel
from transformers import pipeline, TextStreamer
import asyncio
import queue
import threading
import time
import random
import httpx
class ModelInput(BaseModel):
prompt: str
max_new_tokens: int = 64000
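# Example request body for POST /generate/stream (fields from ModelInput above;
# the endpoint itself is defined further below):
#   {"prompt": "Write a haiku about startups", "max_new_tokens": 64}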
app = FastAPI()
# Your main generation model (DeepSeek)
generator = pipeline(
"text-generation",
model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
device="cpu"
)
# The summarization instruct model
summarizer = pipeline(
"text-generation",
model="HuggingFaceTB/SmolLM2-360M-Instruct",
device="cpu",
    max_new_tokens=256,  # cap the summary length (counts only generated tokens)
do_sample=False
)
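# NOTE: both pipelines download and load their weights when this module is
# imported, so the first startup on a CPU-only host can take a while.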
knowledge_graph = {}
async def fetch_ddg_search(query: str):
url = "https://api.duckduckgo.com/"
params = {
"q": query,
"format": "json",
"no_redirect": "1",
"no_html": "1",
"skip_disambig": "1"
}
async with httpx.AsyncClient() as client:
resp = await client.get(url, params=params, timeout=15)
data = resp.json()
return data
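# The JSON above is DuckDuckGo's Instant Answer payload; clean_ddg_text() below
# flattens its "AbstractText" and "RelatedTopics" fields into a single string.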
def clean_ddg_text(ddg_json):
# Take abstract text + top related topic texts concatenated
abstract = ddg_json.get("AbstractText", "")
related = ddg_json.get("RelatedTopics", [])
related_texts = []
for item in related:
if "Text" in item:
related_texts.append(item["Text"])
elif "Name" in item and "Topics" in item:
for sub in item["Topics"]:
if "Text" in sub:
related_texts.append(sub["Text"])
combined_text = abstract + " " + " ".join(related_texts)
# Simple clean up, trim length to avoid overloading
combined_text = combined_text.strip().replace("\n", " ")
if len(combined_text) > 1000:
combined_text = combined_text[:1000] + "..."
return combined_text
def summarize_text(text: str):
# Run the instruct model to summarize/clean the text
prompt = f"Summarize this information concisely:\n{text}\nSummary:"
    # return_full_text=False strips the prompt from the pipeline output so only
    # the generated summary text is returned.
    output = summarizer(prompt, do_sample=False, return_full_text=False)
    return output[0]["generated_text"].strip()
async def update_knowledge_graph_periodically():
queries = [
"latest tech startup news",
"AI breakthroughs 2025",
"funding trends in tech startups",
"popular programming languages 2025",
"open source AI models"
]
while True:
query = random.choice(queries)
print(f"[KG Updater] Searching DuckDuckGo for query: {query}")
try:
ddg_data = await fetch_ddg_search(query)
cleaned = clean_ddg_text(ddg_data)
if not cleaned:
cleaned = "No useful info found."
print(f"[KG Updater] DuckDuckGo cleaned text length: {len(cleaned)}")
# Summarize using your instruct model in a thread (blocking)
loop = asyncio.get_event_loop()
summary = await loop.run_in_executor(None, summarize_text, cleaned)
print(f"[KG Updater] Summary length: {len(summary)}")
knowledge_graph[query] = {
"raw_text": cleaned,
"summary": summary,
"timestamp": time.time()
}
print(f"[KG Updater] Knowledge graph updated for query: {query}")
except Exception as e:
print(f"[KG Updater] Error: {e}")
await asyncio.sleep(60)
@app.on_event("startup")
async def startup_event():
    # Keep a reference on app.state so the background task isn't garbage-collected.
    app.state.kg_task = asyncio.create_task(update_knowledge_graph_periodically())
# Manual streaming endpoint (kept as-is)
@app.post("/generate/stream")
async def generate_stream(input: ModelInput):
q = queue.Queue()
def run_generation():
try:
            # Reuse the tokenizer from the already-loaded generator rather than
            # building a second pipeline, and push each decoded text chunk into
            # the queue via TextStreamer's on_finalized_text hook (put() receives
            # raw token ids, not text).
            streamer = TextStreamer(generator.tokenizer, skip_prompt=True)
            streamer.on_finalized_text = lambda text, stream_end=False: q.put(text)
generator(
input.prompt,
max_new_tokens=input.max_new_tokens,
do_sample=False,
streamer=streamer
)
except Exception as e:
q.put(f"[ERROR] {e}")
finally:
            q.put(None)  # end-of-stream sentinel consumed by event_generator()
thread = threading.Thread(target=run_generation)
thread.start()
async def event_generator():
loop = asyncio.get_event_loop()
while True:
token = await loop.run_in_executor(None, q.get)
if token is None:
break
yield token
return StreamingResponse(event_generator(), media_type="text/plain")
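# Example call (assuming the app is served locally on port 8000 with uvicorn):
#   curl -N -X POST http://localhost:8000/generate/stream \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Hello", "max_new_tokens": 64}'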
# Endpoint to get KG
@app.get("/knowledge")
async def get_knowledge():
return knowledge_graph
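# /knowledge returns a JSON object keyed by search query; each entry holds the
# "raw_text", "summary", and "timestamp" fields written by the KG updater above.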
# Basic client page to test streaming
@app.get("/", response_class=HTMLResponse)
async def root():
return """
<!DOCTYPE html>
<html>
<head><title>Streaming Text Generation Client</title></head>
<body>
<h2>Streaming Text Generation Demo</h2>
<textarea id="prompt" rows="4" cols="60">Write me a poem about tech startup struggles</textarea><br/>
<button onclick="startStreaming()">Generate</button>
<pre id="output" style="white-space: pre-wrap; background:#eee; padding:10px; border-radius:5px; max-height:400px; overflow:auto;"></pre>
<h3>Knowledge Graph</h3>
<pre id="kg" style="background:#ddd; padding:10px; max-height:300px; overflow:auto;"></pre>
<script>
async function startStreaming() {
const prompt = document.getElementById("prompt").value;
const output = document.getElementById("output");
output.textContent = "";
const response = await fetch("/generate/stream", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ prompt: prompt, max_new_tokens: 64000 })
});
const reader = response.body.getReader();
const decoder = new TextDecoder();
while(true) {
const {done, value} = await reader.read();
if(done) break;
const chunk = decoder.decode(value, {stream: true});
output.textContent += chunk;
output.scrollTop = output.scrollHeight;
}
}
async function fetchKG() {
const kgPre = document.getElementById("kg");
const res = await fetch("/knowledge");
const data = await res.json();
kgPre.textContent = JSON.stringify(data, null, 2);
}
setInterval(fetchKG, 10000); // update KG display every 10s
window.onload = fetchKG;
</script>
</body>
</html>
"""