# Page-status text captured together with this source: "Spaces: Build error".
import gc
import os
import time
from typing import Dict, List, TypedDict

import streamlit as st
import torch
from langgraph.graph import StateGraph, END
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# Hugging Face access token read from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Per-agent model selection: each entry names the base model to load and the
# adapter applied on top of it. The keys are the agent names used by the graph.
AGENT_MODEL_CONFIG = {
    "product_manager": {
        "base_id": "unsloth/mistral-7b-bnb-4bit",
        "adapter_id": "spandana30/product-manager-mistral",
    },
    "project_manager": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma",
    },
    # NOTE(review): the designer reuses the project-manager adapter - confirm
    # this is intentional and not a copy-paste placeholder.
    "designer": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma",
    },
    # NOTE(review): the upstream repo appears to be published as
    # "codellama/CodeLlama-7b-hf"; confirm this casing resolves on the Hub.
    "software_engineer": {
        "base_id": "codellama/CodeLLaMA-7b-hf",
        "adapter_id": "spandana30/software-engineer-codellama",
    },
    # NOTE(review): the QA engineer reuses the software-engineer adapter -
    # confirm this is intentional.
    "qa_engineer": {
        "base_id": "codellama/CodeLLaMA-7b-hf",
        "adapter_id": "spandana30/software-engineer-codellama",
    },
}
def load_agent_model(base_id, adapter_id):
    """Load a 4-bit base model, apply its PEFT adapter, and load the tokenizer.

    Args:
        base_id: Model id of the base causal LM passed to
            ``AutoModelForCausalLM.from_pretrained``.
        adapter_id: Id of the PEFT adapter applied on top of the base model.
            The tokenizer is loaded from this same id.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is the adapter-wrapped model
        after ``.eval()``.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.float16,
        device_map="auto",
        # Passing load_in_4bit directly to from_pretrained is deprecated;
        # the quantization request goes through an explicit config object.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        token=HF_TOKEN,
    )
    adapted = PeftModel.from_pretrained(base_model, adapter_id, token=HF_TOKEN)
    # NOTE(review): assumes the adapter repo ships tokenizer files - confirm.
    tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=HF_TOKEN)
    return adapted.eval(), tokenizer
def call_model(prompt: str, model, tokenizer) -> str:
    """Generate a completion for ``prompt`` and return only the new text.

    Args:
        prompt: Fully rendered prompt text.
        model: Object exposing ``.device`` and ``.generate(**inputs, ...)``.
        tokenizer: Callable tokenizer that also exposes ``.decode``.

    Returns:
        The decoded continuation, without the prompt and without special
        tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # Greedy decoding. A temperature has no effect when do_sample=False and
    # only triggers a configuration warning, so it is not passed.
    outputs = model.generate(
        **inputs,
        max_new_tokens=1024,
        do_sample=False,
    )
    # A causal LM returns prompt + continuation in one sequence. Drop the
    # prompt tokens so the caller does not get its own prompt echoed back
    # (the QA prompt contains the word "APPROVED", so an echoed prompt made
    # the approval check pass unconditionally).
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # Conversation log; each entry has "role" and "content" keys.
    messages: List[Dict[str, str]]
    # Output of the product_manager agent.
    product_vision: str
    # Output of the project_manager agent.
    project_plan: str
    # Output of the designer agent.
    design_specs: str
    # Output of the software_engineer agent.
    html: str
    # Output of the qa_engineer agent.
    feedback: str
    # Number of QA rounds completed so far.
    iteration: int
    # Set by the QA node; the graph ends once this is True.
    done: bool
    # Elapsed seconds per timing label, written by agent().
    timings: Dict[str, float]
def agent(template: str, state: AgentState, agent_key: str, timing_label: str):
    """Run one agent: load its model, render its prompt, and generate a reply.

    This function does not append to ``state["messages"]``. The graph nodes
    that call it already return the new message as part of their state
    update; appending here as well recorded every response twice.

    Args:
        template: ``str.format`` template. It may reference ``{user_request}``,
            ``{product_vision}``, ``{project_plan}``, ``{design_specs}``,
            ``{html}`` and ``{feedback}``; unused names are ignored.
        state: Current graph state. ``state["messages"][0]`` is read as the
            user request, and ``state["timings"]`` is written in place.
        agent_key: Key into ``AGENT_MODEL_CONFIG`` selecting the model pair.
        timing_label: Key under which the elapsed seconds are stored in
            ``state["timings"]``.

    Returns:
        The text returned by ``call_model`` for the rendered prompt.
    """
    # NOTE(review): the leading characters of the st.write labels below look
    # like mis-encoded emoji; they are kept as-is.
    st.write(f'π Running agent: {agent_key}')
    start = time.time()
    model, tokenizer = load_agent_model(**AGENT_MODEL_CONFIG[agent_key])
    try:
        prompt = template.format(
            user_request=state["messages"][0]["content"],
            product_vision=state.get("product_vision", ""),
            project_plan=state.get("project_plan", ""),
            design_specs=state.get("design_specs", ""),
            html=state.get("html", ""),
            feedback=state.get("feedback", ""),
        )
        st.write(f'π€ Prompt for {agent_key}:', prompt)
        response = call_model(prompt, model, tokenizer)
        elapsed = time.time() - start
    finally:
        # Drop the references before collecting: gc.collect() cannot reclaim
        # a model that a local name still points to. Then hand cached GPU
        # blocks back so the next agent's model has room to load.
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    st.write(f'π₯ Response from {agent_key}:', response[:500])
    state["timings"][timing_label] = elapsed
    return response
# Prompt templates keyed by agent name. Each is rendered with str.format and
# references the single upstream field it depends on.
PROMPTS = {
    # Reads {user_request}.
    "product_manager": (
        "You're a Product Manager. Interpret this user request:\n"
        "{user_request}\n"
        "Define the high-level product goals, features, and user stories."
    ),
    # Reads {product_vision}.
    "project_manager": (
        "You're a Project Manager. Based on this feature list:\n"
        "{product_vision}\n"
        "Create a project plan with key milestones and task assignments."
    ),
    # Reads {project_plan}.
    "designer": (
        "You're a UI designer. Create design specs for:\n"
        "{project_plan}\n"
        "Include:\n"
        "1. Color palette (primary, secondary, accent)\n"
        "2. Font choices\n"
        "3. Layout structure\n"
        "4. Component styles\n"
        "Don't write code - just design guidance."
    ),
    # Reads {design_specs}.
    "software_engineer": (
        "Create a complete HTML page with embedded CSS for:\n"
        "{design_specs}\n"
        "Requirements:\n"
        "1. Full HTML document with <!DOCTYPE>\n"
        "2. CSS inside <style> tags in head\n"
        "3. Mobile-responsive\n"
        "4. Semantic HTML\n"
        "5. Ready-to-use (will work when saved as .html)\n"
        "Output JUST the complete HTML file content:"
    ),
    # Reads {html}. The literal word APPROVED is what the QA step looks for
    # in the reply.
    "qa_engineer": (
        "Review this website:\n"
        "{html}\n"
        "Check for:\n"
        "1. Visual quality\n"
        "2. Responsiveness\n"
        "3. Functionality\n"
        "Reply \"APPROVED\" if perfect, or suggest improvements."
    ),
}
def generate_ui(user_prompt: str, max_iter: int):
    """Run the agent pipeline for one request and return the final state.

    The graph runs product_manager -> project_manager -> designer ->
    software_engineer -> qa_engineer. After QA it either ends or goes back to
    software_engineer for another attempt.

    Args:
        user_prompt: The user's description of the UI to build.
        max_iter: Maximum number of QA rounds. The run stops after this many
            rounds even without approval.

    Returns:
        The final state returned by the compiled graph's ``invoke``.
    """
    state: AgentState = {
        "messages": [{"role": "user", "content": user_prompt}],
        "product_vision": "",
        "project_plan": "",
        "design_specs": "",
        "html": "",
        "feedback": "",
        "iteration": 0,
        "done": False,
        "timings": {},
    }

    def run_agent(s, agent_key, template=None):
        """Call agent() and build the state update shared by every node."""
        # Copy the history before agent() runs so the entry added below is
        # the only one for this turn, even if agent() also appends to the
        # list it is handed.
        history = list(s["messages"])
        chosen = PROMPTS[agent_key] if template is None else template
        response = agent(chosen, s, agent_key, agent_key)
        update = {
            "messages": history + [{"role": agent_key, "content": response}],
            # agent() writes its elapsed time into s["timings"]; return a
            # copy so the timing is an explicit part of the update.
            "timings": dict(s["timings"]),
        }
        return response, update

    def make_node(agent_key, output_field):
        """Build a node that stores the agent's reply under output_field."""
        def node(s):
            response, update = run_agent(s, agent_key)
            update[output_field] = response
            return update
        return node

    def software_engineer_node(s):
        """Generate the HTML, including QA feedback on retries."""
        template = PROMPTS["software_engineer"]
        feedback = s.get("feedback", "")
        if feedback:
            # Without the feedback a retry sends the identical prompt and,
            # with deterministic decoding, gets the identical page back.
            # Braces are doubled because agent() renders the template with
            # str.format and feedback may contain CSS or code.
            escaped = feedback.replace("{", "{{").replace("}", "}}")
            template = (
                template
                + "\n\nQA feedback on the previous attempt - address it:\n"
                + escaped
                + "\nOutput JUST the complete revised HTML file content:"
            )
        response, update = run_agent(s, "software_engineer", template)
        update["html"] = response
        return update

    def qa_node(s):
        """Review the HTML and decide whether the run is finished."""
        feedback, update = run_agent(s, "qa_engineer")
        # Count this round before comparing against the cap; comparing the
        # pre-increment value allowed max_iter + 1 QA rounds.
        completed_rounds = s["iteration"] + 1
        update["feedback"] = feedback
        update["iteration"] = completed_rounds
        update["done"] = "APPROVED" in feedback or completed_rounds >= max_iter
        return update

    workflow = StateGraph(AgentState)
    workflow.add_node("product_manager", make_node("product_manager", "product_vision"))
    workflow.add_node("project_manager", make_node("project_manager", "project_plan"))
    workflow.add_node("designer", make_node("designer", "design_specs"))
    workflow.add_node("software_engineer", software_engineer_node)
    workflow.add_node("qa_engineer", qa_node)

    workflow.add_edge("product_manager", "project_manager")
    workflow.add_edge("project_manager", "designer")
    workflow.add_edge("designer", "software_engineer")
    workflow.add_edge("software_engineer", "qa_engineer")
    # Loop back to the engineer until QA marks the run as done.
    workflow.add_conditional_edges(
        "qa_engineer",
        lambda s: END if s["done"] else "software_engineer",
    )
    workflow.set_entry_point("product_manager")

    app = workflow.compile()
    return app.invoke(state)
def main():
    """Render the Streamlit page and run the agent pipeline on demand.

    Shows a slider for the QA round limit and a text area for the request.
    When the generate button is pressed with a non-empty request, runs
    ``generate_ui`` and displays the resulting HTML and every agent message.
    """
    # NOTE(review): the leading characters of several labels below look like
    # mis-encoded emoji; they are kept as-is.
    st.set_page_config(page_title="Multi-Agent UI Generator", layout="wide")
    st.title("π§ Multi-Agent UI Generation System")
    max_iter = st.sidebar.slider("Max QA Iterations", 1, 5, 2)
    prompt = st.text_area(
        "What UI do you want to build?",
        "A coffee shop landing page with a hero image, menu, and contact form.",
        height=150,
    )

    if not st.button("π Generate"):
        return

    # Every agent loads a model, so do not start a run for an empty request.
    if not prompt or not prompt.strip():
        st.warning("Please describe the UI you want to build.")
        return

    with st.spinner("Agents working..."):
        final = generate_ui(prompt, max_iter)

    st.success("β UI Generated")
    st.subheader("π Final Output")
    st.components.v1.html(final["html"], height=600, scrolling=True)
    st.subheader("π§ Agent Messages")
    for msg in final["messages"]:
        st.markdown(f"**{msg['role'].title()}**:\n```\n{msg['content']}\n```")
# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()