Spaces:
Build error
Build error
# Updated multi-agent UI generation system with custom fine-tuned LoRA adapters | |
import streamlit as st | |
import time | |
import base64 | |
from typing import Dict, List, TypedDict | |
from langgraph.graph import StateGraph, END | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
from peft import PeftModel, PeftConfig | |
import torch | |
# Page-level configuration (browser tab title, full-width layout), issued
# before any other Streamlit output in this script.
st.set_page_config(layout="wide", page_title="Multi-Agent Collaboration")
# Agent model loading config: role -> Hugging Face base model id + LoRA adapter id.
# NOTE(review): "cohere/command-r" may not resolve on the Hub (Cohere checkpoints
# appear to be published under the CohereForAI org) — confirm. No agent function
# calls run_pipeline("software_architect"), yet this role is still loaded.
# The "qa" role reuses the software-engineer base model and adapter verbatim.
AGENT_MODEL_CONFIG = dict(
    product_manager=dict(
        base="mistralai/Mistral-7B-Instruct-v0.2",
        adapter="spandana30/product-manager-mistral",
    ),
    project_manager=dict(
        base="google/gemma-1.1-7b-it",
        adapter="spandana30/project-manager-gemma",
    ),
    software_architect=dict(
        base="cohere/command-r",  # update if you have a local base version
        adapter="spandana30/software-architect-cohere",
    ),
    software_engineer=dict(
        base="codellama/CodeLlama-7b-Instruct-hf",
        adapter="spandana30/software-engineer-codellama",
    ),
    qa=dict(
        base="codellama/CodeLlama-7b-Instruct-hf",
        adapter="spandana30/software-engineer-codellama",
    ),
)
@st.cache_resource
def load_agent_model(base_id, adapter_id):
    """Load a float16 base model, attach a LoRA adapter, and wrap it in a pipeline.

    Cached with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction. Without the cache, every rerun would
    reload every model from scratch. Roles that share the same
    (base_id, adapter_id) pair, such as ``qa`` and ``software_engineer``,
    also share one loaded pipeline instead of holding two copies.

    Args:
        base_id: Hugging Face model id of the causal-LM base.
        adapter_id: Hugging Face id of the PEFT/LoRA adapter for that base.

    Returns:
        A ``text-generation`` pipeline capped at 1024 new tokens.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        base_id, torch_dtype=torch.float16, device_map="auto"
    )
    model = PeftModel.from_pretrained(base_model, adapter_id)
    try:
        tokenizer = AutoTokenizer.from_pretrained(adapter_id)
    except (OSError, ValueError):
        # Adapter repos do not always ship tokenizer files; fall back to the
        # base model's tokenizer (presumably what the adapter was trained with
        # — confirm per adapter).
        tokenizer = AutoTokenizer.from_pretrained(base_id)
    return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=1024)
# One text-generation pipeline per configured role, built when the module runs.
AGENT_PIPELINES = dict(
    (role_name, load_agent_model(spec["base"], spec["adapter"]))
    for role_name, spec in AGENT_MODEL_CONFIG.items()
)
class AgentState(TypedDict):
    """Shared LangGraph state passed between the agent nodes.

    Each agent node returns a partial dict containing only the keys it updates.
    """

    # Conversation log: one {"role": ..., "content": ...} dict per message.
    messages: List[Dict[str, str]]
    # Raw request text entered by the user.
    user_request: str
    # Product manager's clarified version of the request.
    refined_request: str
    # Project manager's scope-and-constraints breakdown.
    scoped_request: str
    # Designer's textual design specs (palette, font, layout, component styles).
    design_specs: str
    # Latest HTML page produced by the software engineer node.
    html: str
    # Latest QA review text.
    feedback: str
    # Number of QA reviews completed so far.
    iteration: int
    # Set by the QA node; the conditional edge ends the graph when True.
    done: bool
    # Elapsed wall-clock seconds recorded per stage label.
    timings: Dict[str, float]
def run_pipeline(role: str, prompt: str) -> str:
    """Generate a greedy completion for *prompt* with the pipeline for *role*.

    Returns only the newly generated text, stripped of surrounding whitespace.
    ``return_full_text=False`` is essential: by default the text-generation
    pipeline prepends the prompt to ``generated_text``. That would leak the
    instructions into the HTML and the conversation log. It would also make
    the QA check always see "APPROVED", because that word appears in the QA
    prompt itself.
    """
    outputs = AGENT_PIPELINES[role](prompt, do_sample=False, return_full_text=False)
    return outputs[0]["generated_text"].strip()
# Prompt templates keyed by agent role. Each template is filled via str.format
# with the state field named in its single placeholder.
PROMPTS = dict(
    product_manager="""You're a Product Manager. Refine and clarify this request:
{user_request}
Ensure it's clear, feasible, and user-focused. Output the revised request only.""",
    project_manager="""You're a Project Manager. Given this refined request:
{refined_request}
Break it down into scope and constraints. Output the scoped request only.""",
    designer="""You're a UI designer. Create design specs for:
{scoped_request}
Include color palette, font, layout, and component styles. No code.""",
    software_engineer="""Create a full HTML page with embedded CSS for:
{design_specs}
Requirements:
- Semantic, responsive HTML
- Embedded CSS in <style> tag
- Output complete HTML only.""",
    qa="""Review this webpage:
{html}
Is it visually appealing, responsive, and functional? Reply "APPROVED" or suggest improvements.""",
)
def time_agent(agent_func, state: "AgentState", label: str):
    """Run *agent_func* on *state* and record its elapsed time under *label*.

    Returns the agent's partial-state dict extended with a fresh ``timings``
    dict. The incoming ``state["timings"]`` is copied, never mutated in place,
    so graph state is only changed through the returned update. Time for a
    label that runs more than once (the engineer/QA loop) is accumulated
    rather than overwritten, so per-stage totals cover every pass.
    """
    start = time.perf_counter()  # monotonic clock suited to measuring durations
    result = agent_func(state)
    elapsed = time.perf_counter() - start
    timings = dict(state.get("timings") or {})
    timings[label] = timings.get(label, 0.0) + elapsed
    result["timings"] = timings
    return result
def product_manager_agent(state: AgentState):
    """Ask the product-manager model to refine the raw user request.

    Returns a partial state update holding ``refined_request`` and the message
    log extended with the product manager's reply.
    """
    pm_prompt = PROMPTS["product_manager"].format(user_request=state["user_request"])
    refined = run_pipeline("product_manager", pm_prompt)
    log_entry = {"role": "product_manager", "content": refined}
    return {
        "refined_request": refined,
        "messages": state["messages"] + [log_entry],
    }
def project_manager_agent(state: AgentState):
    """Ask the project-manager model to break the refined request into scope.

    Returns a partial state update holding ``scoped_request`` and the message
    log extended with the project manager's reply.
    """
    scope_prompt = PROMPTS["project_manager"].format(refined_request=state["refined_request"])
    scope_text = run_pipeline("project_manager", scope_prompt)
    log_entry = {"role": "project_manager", "content": scope_text}
    return {
        "scoped_request": scope_text,
        "messages": state["messages"] + [log_entry],
    }
def designer_agent(state: AgentState):
    """Turn the scoped request into textual design specs (no code).

    Returns a partial state update holding ``design_specs`` and the message log
    extended with the designer's reply.

    NOTE(review): this runs on the "product_manager" pipeline; there is no
    "designer" entry in AGENT_MODEL_CONFIG — confirm this reuse is intended.
    """
    specs_prompt = PROMPTS["designer"].format(scoped_request=state["scoped_request"])
    specs = run_pipeline("product_manager", specs_prompt)
    log_entry = {"role": "designer", "content": specs}
    return {
        "design_specs": specs,
        "messages": state["messages"] + [log_entry],
    }
def engineer_agent(state: AgentState):
    """Generate (or regenerate) the HTML page from the design specs.

    On revision passes, the QA feedback from the previous review is appended
    to the prompt. Decoding is greedy (``do_sample=False`` in run_pipeline).
    Without new input, every rerun would reproduce byte-for-byte the same page,
    and the engineer/QA loop could never change the result.

    Returns a partial state update holding ``html`` and the message log
    extended with the engineer's output.
    """
    prompt = PROMPTS["software_engineer"].format(design_specs=state["design_specs"])
    feedback = state.get("feedback", "")
    if feedback:
        prompt += f"\n\nRevise the page to address this QA feedback:\n{feedback}"
    html = run_pipeline("software_engineer", prompt)
    return {"html": html, "messages": state["messages"] + [{"role": "software_engineer", "content": html}]}
def qa_agent(state: AgentState, max_iter: int):
    """Review the current HTML and decide whether the engineer/QA loop ends.

    The loop stops when the reviewer approves, or once *max_iter* reviews
    (including this one) have been completed. The original compared the
    pre-increment count, which allowed max_iter + 1 reviews. A reply containing
    "NOT APPROVED" (any case) is treated as a rejection, not an approval.

    Returns a partial state update with ``feedback``, ``done``, the incremented
    ``iteration`` count, and the message log extended with the review.
    """
    feedback = run_pipeline("qa", PROMPTS["qa"].format(html=state["html"]))
    completed = state["iteration"] + 1
    approved = "APPROVED" in feedback and "NOT APPROVED" not in feedback.upper()
    done = approved or completed >= max_iter
    return {"feedback": feedback, "done": done, "iteration": completed,
            "messages": state["messages"] + [{"role": "qa", "content": feedback}]}
def generate_ui(user_request: str, max_iter: int):
    """Build the agent graph, run it on *user_request*, and report the outcome.

    The pipeline is product manager -> project manager -> designer -> software
    engineer -> QA. QA loops back to the engineer until its ``done`` flag is set.

    Returns:
        (final HTML, final graph state, total wall-clock seconds for the run).
    """
    initial_state = {
        "messages": [{"role": "user", "content": user_request}],
        "user_request": user_request,
        "refined_request": "",
        "scoped_request": "",
        "design_specs": "",
        "html": "",
        "feedback": "",
        "iteration": 0,
        "done": False,
        "timings": {},
    }

    def timed(agent_fn, label):
        # Wrap an agent so its run time is recorded under *label*.
        return lambda s: time_agent(agent_fn, s, label)

    stages = [
        ("product_manager", product_manager_agent),
        ("project_manager", project_manager_agent),
        ("designer", designer_agent),
        ("software_engineer", engineer_agent),
        ("qa", lambda x: qa_agent(x, max_iter)),
    ]
    graph = StateGraph(AgentState)
    for label, agent_fn in stages:
        graph.add_node(label, timed(agent_fn, label))

    # Chain the stages linearly in declaration order.
    stage_names = [label for label, _ in stages]
    for src, dst in zip(stage_names, stage_names[1:]):
        graph.add_edge(src, dst)

    graph.add_conditional_edges("qa", lambda s: END if s["done"] else "software_engineer")
    graph.set_entry_point("product_manager")
    compiled = graph.compile()

    started = time.time()
    final_state = compiled.invoke(initial_state)
    elapsed = time.time() - started
    return final_state["html"], final_state, elapsed
def main():
    """Render the Streamlit app and run the agent workflow when requested.

    Shows the generated page inline, offers HTML and conversation-log
    downloads, and reports total and per-stage timings.
    """
    # Imported explicitly: ``import streamlit as st`` alone is not guaranteed
    # to expose the ``components.v1`` submodule as an attribute.
    import streamlit.components.v1 as components

    # Emoji in the labels below were mis-decoded (UTF-8 read as ISO-8859-7)
    # in the previous version; restored to the intended characters.
    st.title("🤖 Multi-Agent UI Generator")
    with st.sidebar:
        max_iter = st.slider("Max QA Iterations", 1, 5, 2)
    prompt = st.text_area(
        "📝 Describe the UI you want:",
        "A coffee shop landing page with hero, menu, and contact form.",
        height=150,
    )
    if not st.button("🚀 Generate UI"):
        return
    # Avoid running the whole model chain on an empty request.
    if not prompt.strip():
        st.warning("Please describe the UI you want before generating.")
        return

    with st.spinner("Agents working..."):
        html, final_state, total_time = generate_ui(prompt, max_iter)
    st.success("✅ UI Generated Successfully!")
    components.html(html, height=600, scrolling=True)

    # Standard MIME types (the previous "file/html" is not a registered type).
    b64 = base64.b64encode(html.encode("utf-8")).decode("ascii")
    st.markdown(
        f'<a href="data:text/html;charset=utf-8;base64,{b64}" download="ui.html">📥 Download HTML</a>',
        unsafe_allow_html=True,
    )

    st.subheader("🧠 Agent Communication Log")
    history_text = "".join(
        f"---\n{msg['role'].replace('_', ' ').title()}:\n{msg['content']}\n\n"
        for msg in final_state["messages"]
    )
    st.text_area("Agent Dialogue", value=history_text, height=300)
    b64_hist = base64.b64encode(history_text.encode("utf-8")).decode("ascii")
    st.markdown(
        f'<a href="data:text/plain;charset=utf-8;base64,{b64_hist}" download="agent_communication.txt">📥 Download Communication Log</a>',
        unsafe_allow_html=True,
    )

    st.subheader("📊 Performance")
    st.write(f"⏱️ Total Time: {total_time:.2f} seconds")
    st.write(f"🔁 Iterations: {final_state['iteration']}")
    for stage in ["product_manager", "project_manager", "designer", "software_engineer", "qa"]:
        st.write(f"🧩 {stage.replace('_', ' ').title()} Time: {final_state['timings'].get(stage, 0):.2f}s")


if __name__ == "__main__":
    main()