# Multi-agent UI generator: Streamlit front end driving a LangGraph pipeline
# of four Gemma/PEFT agents (product manager -> project manager -> software
# engineer -> QA engineer).
# NOTE(review): removed stray build-log lines ("Spaces:" / "Build error")
# that were not valid Python.
import gc
import os
import time
from typing import Dict, List, TypedDict

import streamlit as st
import torch
from langgraph.graph import StateGraph, END
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face access token for gated/private repos; None if not set.
HF_TOKEN = os.getenv("HF_TOKEN")
# Agent model config — all four agents share the same Gemma base model.
# NOTE(review): every agent currently points at the *same* adapter
# ("spandana30/project-manager-gemma"); confirm whether role-specific
# adapters were intended.
AGENT_MODEL_CONFIG = {
    "product_manager": {
        "base": "unsloth/gemma-3-1b-it",
        "adapter": "spandana30/project-manager-gemma",
    },
    "project_manager": {
        "base": "unsloth/gemma-3-1b-it",
        "adapter": "spandana30/project-manager-gemma",
    },
    "software_engineer": {
        "base": "unsloth/gemma-3-1b-it",
        "adapter": "spandana30/project-manager-gemma",
    },
    "qa_engineer": {
        "base": "unsloth/gemma-3-1b-it",
        "adapter": "spandana30/project-manager-gemma",
    },
}
def load_agent_model(base_id: str, adapter_id: str):
    """Load a base causal LM, attach its PEFT adapter, and return it with a tokenizer.

    Uses fp16 + device_map="auto" when CUDA is available, fp32 on CPU.
    HF_TOKEN authenticates access to gated/private repos.

    Returns:
        (model, tokenizer): the adapted model in eval mode and its tokenizer.
    """
    base_model = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        token=HF_TOKEN,
    )
    model = PeftModel.from_pretrained(base_model, adapter_id, token=HF_TOKEN)
    # Tokenizer comes from the adapter repo so any tokens added during
    # fine-tuning stay in sync with the adapter weights.
    tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=HF_TOKEN)
    return model.eval(), tokenizer
def call_model(prompt: str, model, tokenizer) -> str:
    """Greedy-decode up to 512 new tokens for `prompt` with `model`.

    Returns the full decoded sequence (prompt + completion), with special
    tokens stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        # Greedy decoding. The original also passed temperature=0.3, which is
        # ignored when do_sample=False (and triggers a transformers warning),
        # so it has been removed.
        do_sample=False,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph workflow nodes."""

    messages: List[Dict[str, str]]  # chat log: {"role": ..., "content": ...}
    html: str                       # latest HTML produced by the software engineer
    feedback: str                   # latest QA feedback text
    iteration: int                  # number of completed QA review rounds
    done: bool                      # True once QA approves or the budget is hit
    timings: Dict[str, float]       # per-agent wall-clock seconds
def agent(prompt_template: str, state: AgentState, agent_key: str, timing_label: str) -> str:
    """Run one agent: load its model, build the prompt, generate a response.

    Appends the response to state["messages"] in place, records elapsed time
    in state["timings"][timing_label], and returns the response text.

    NOTE(review): the model/adapter pair is reloaded on every call — consider
    caching (e.g. st.cache_resource) if latency matters.
    """
    start = time.time()
    model, tokenizer = load_agent_model(**AGENT_MODEL_CONFIG[agent_key])
    # BUG FIX: the original used prompt_template.format(**state), but
    # str.format() cannot parse negative indices — in "{messages[-1][content]}"
    # it treats "-1" as a *string* key and raises TypeError on a list.
    # Substitute the two known placeholders explicitly instead.
    last_content = state["messages"][-1]["content"] if state["messages"] else ""
    prompt = (
        prompt_template
        .replace("{messages[-1][content]}", last_content)
        .replace("{html}", state["html"])
    )
    response = call_model(prompt, model, tokenizer)
    state["messages"].append({"role": agent_key, "content": response})
    state["timings"][timing_label] = time.time() - start
    gc.collect()  # release model memory between agent invocations
    return response
# Per-agent prompt templates. Placeholders are substituted by agent():
#   {messages[-1][content]} — content of the most recent chat message
#   {html}                  — the current generated HTML
PROMPTS = {
    "product_manager": "You're a Product Manager. Refine this user request:\n{messages[-1][content]}",
    "project_manager": "You're a Project Manager. Break down this refined request:\n{messages[-1][content]}",
    "software_engineer": "You're a Software Engineer. Generate HTML+CSS code for:\n{messages[-1][content]}",
    "qa_engineer": "You're a QA Engineer. Review this HTML:\n{html}\nGive feedback or reply APPROVED.",
}
def generate_ui(user_prompt: str, max_iter: int):
    """Drive the four-agent LangGraph workflow and return the final state.

    Pipeline: product_manager -> project_manager -> software_engineer ->
    qa_engineer, looping back to software_engineer until QA replies APPROVED
    or `max_iter` QA rounds have completed.
    """
    state: AgentState = {
        "messages": [{"role": "user", "content": user_prompt}],
        "html": "",
        "feedback": "",
        "iteration": 0,
        "done": False,
        "timings": {},
    }
    workflow = StateGraph(AgentState)

    # agent() already appends its response to s["messages"] in place, so the
    # node functions must NOT append a second copy (the original lambdas did,
    # duplicating every message). They simply pass the mutated list through.
    def pm_fn(s):
        agent(PROMPTS["product_manager"], s, "product_manager", "product_manager")
        return {"messages": s["messages"]}

    def proj_fn(s):
        agent(PROMPTS["project_manager"], s, "project_manager", "project_manager")
        return {"messages": s["messages"]}

    def se_fn(s):
        # BUG FIX: the original logged s["html"] (the *previous* iteration's
        # HTML) instead of the freshly generated response.
        html = agent(PROMPTS["software_engineer"], s, "software_engineer", "software_engineer")
        return {"html": html, "messages": s["messages"]}

    def qa_fn(s):
        feedback = agent(PROMPTS["qa_engineer"], s, "qa_engineer", "qa_engineer")
        iteration = s["iteration"] + 1
        # BUG FIX: compare the *incremented* count so at most max_iter QA
        # rounds run (the original off-by-one allowed max_iter + 1).
        done = "APPROVED" in feedback or iteration >= max_iter
        return {
            "feedback": feedback,
            "done": done,
            "iteration": iteration,
            "messages": s["messages"],
        }

    workflow.add_node("product_manager", pm_fn)
    workflow.add_node("project_manager", proj_fn)
    workflow.add_node("software_engineer", se_fn)
    workflow.add_node("qa_engineer", qa_fn)
    workflow.add_edge("product_manager", "project_manager")
    workflow.add_edge("project_manager", "software_engineer")
    workflow.add_edge("software_engineer", "qa_engineer")
    workflow.add_conditional_edges("qa_engineer", lambda s: END if s["done"] else "software_engineer")
    workflow.set_entry_point("product_manager")
    return workflow.compile().invoke(state)
def main():
    """Streamlit entry point: collect a prompt, run the agents, show results."""
    st.set_page_config(page_title="Multi-Agent UI Generator", layout="wide")
    # NOTE(review): several UI strings contained mojibake (broken emoji bytes
    # such as "π" / "β"); they have been replaced with plain text.
    st.title("Multi-Agent Collaboration")
    max_iter = st.sidebar.slider("Max QA Iterations", 1, 5, 2)
    prompt = st.text_area(
        "Describe your UI:",
        "A landing page for a coffee shop with a hero image, menu, and contact form.",
        height=150,
    )
    if st.button("Generate UI"):
        with st.spinner("Agents working..."):
            final = generate_ui(prompt, max_iter)
        st.success("UI Generated")
        st.subheader("Output HTML")
        st.components.v1.html(final["html"], height=600, scrolling=True)
        st.subheader("Agent Messages")
        for msg in final["messages"]:
            st.markdown(f"**{msg['role'].title()}**:\n```\n{msg['content']}\n```")


if __name__ == "__main__":
    main()