Spaces:
Build error
Build error
import streamlit as st | |
import os | |
import time | |
import gc | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from peft import PeftModel | |
from typing import Dict, List, TypedDict | |
from langgraph.graph import StateGraph, END | |
# Hugging Face access token (None when the HF_TOKEN env var is unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Model settings per agent role: each role gets its own dict with the base
# checkpoint and the PEFT adapter layered on top of it.
# NOTE(review): every role currently points at the same
# "project-manager-gemma" adapter; confirm whether role-specific adapters
# were intended.
AGENT_MODEL_CONFIG = {
    role: {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma",
    }
    for role in (
        "product_manager",
        "project_manager",
        "software_engineer",
        "qa_engineer",
    )
}
@st.cache_resource(show_spinner="Loading model...", max_entries=1)
def load_agent_model(base_id, adapter_id):
    """Load ``base_id`` with the PEFT adapter ``adapter_id`` applied, in eval mode.

    The result is cached with ``st.cache_resource``. Repeated agent calls and
    Streamlit reruns with the same (base_id, adapter_id) reuse the loaded
    model instead of re-reading and re-initialising the checkpoint on every
    agent invocation. ``max_entries=1`` keeps at most one (model, tokenizer)
    pair cached. That keeps resident memory close to the previous
    load-per-call behaviour if roles are ever given different adapters. With
    the config in this file, every role shares a single cached pair.

    Args:
        base_id: Hub id of the base causal LM.
        adapter_id: Hub id of the PEFT adapter; the tokenizer is loaded from
            this repo as well.

    Returns:
        ``(model, tokenizer)``. ``model`` is the adapter-wrapped model in eval
        mode: float16 with ``device_map="auto"`` on CUDA, otherwise float32 on CPU.
    """
    has_cuda = torch.cuda.is_available()
    base_model = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        device_map="auto" if has_cuda else None,
        token=HF_TOKEN,
    )
    model = PeftModel.from_pretrained(base_model, adapter_id, token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=HF_TOKEN)
    return model.eval(), tokenizer
def call_model(prompt: str, model, tokenizer, max_new_tokens: int = 512) -> str:
    """Greedily generate a continuation of ``prompt`` and return only the new text.

    The prompt is tokenized as-is (no chat template is applied) and truncated
    to the tokenizer's maximum length.

    Args:
        prompt: Raw prompt text.
        model: Causal LM exposing ``device`` and ``generate``.
        tokenizer: Matching tokenizer (callable, with ``decode``).
        max_new_tokens: Upper bound on generated tokens (default 512).

    Returns:
        The decoded continuation with special tokens removed. The prompt
        itself is NOT included.
    """
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    prompt_len = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Greedy decoding. No temperature is passed: it only affects sampling
        # (do_sample=True) and would just trigger a generation-config warning.
        do_sample=False,
    )
    # For decoder-only models generate() returns prompt + continuation;
    # slice off the prompt so callers don't get their own input echoed back.
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# Shared LangGraph state passed between the agent nodes:
#   messages        - history of {"role": ..., "content": ...} dicts
#   html            - latest output of the software_engineer node
#   refined_request - output of the product_manager node
#   final_prompt    - output of the project_manager node
#   feedback        - latest output of the qa_engineer node
#   iteration       - number of QA passes run so far
#   done            - set by the QA node to end the engineer/QA loop
#   timings         - elapsed seconds per agent, keyed by timing label
AgentState = TypedDict(
    "AgentState",
    {
        "messages": List[Dict[str, str]],
        "html": str,
        "refined_request": str,
        "final_prompt": str,
        "feedback": str,
        "iteration": int,
        "done": bool,
        "timings": Dict[str, float],
    },
)
def agent(template: str, state: AgentState, agent_key: str, timing_label: str):
    """Run one agent role on the current state and return its text response.

    The model configured for ``agent_key`` in ``AGENT_MODEL_CONFIG`` is loaded
    and ``template`` is filled via ``str.format``. The primary input
    (``{user_input}``) is the first non-empty value of ``final_prompt``,
    ``refined_request``, or the last message's content.

    Side effects: writes progress to the Streamlit page and stores the elapsed
    seconds in ``state["timings"][timing_label]``. It does NOT append to
    ``state["messages"]``. Recording the response is left to the graph node
    that calls this, which previously led to every message being stored twice.

    Args:
        template: Format string; may reference ``{user_input}``, ``{html}``,
            ``{final_prompt}``, ``{refined_request}`` and ``{feedback}``.
        state: Current graph state.
        agent_key: Key into ``AGENT_MODEL_CONFIG``; also used in log output.
        timing_label: Key under which the elapsed time is recorded.

    Returns:
        The model's response text.
    """
    st.write(f'π Running agent: {agent_key}')
    start = time.time()
    model, tokenizer = load_agent_model(**AGENT_MODEL_CONFIG[agent_key])
    try:
        latest_input = (
            state.get("final_prompt")
            or state.get("refined_request")
            or state["messages"][-1]["content"]
        )
        # str.format ignores unused keyword arguments, so exposing extra
        # fields is harmless for templates that don't reference them.
        prompt = template.format(
            user_input=latest_input,
            html=state.get("html", ""),
            final_prompt=state.get("final_prompt", ""),
            refined_request=state.get("refined_request", ""),
            feedback=state.get("feedback", ""),
        )
        st.write(f'π€ Prompt for {agent_key}:', prompt)
        response = call_model(prompt, model, tokenizer)
        st.write(f'π₯ Response from {agent_key}:', response[:500])
        state["timings"][timing_label] = time.time() - start
    finally:
        # Drop our references even if generation fails, then return any
        # unused cached GPU blocks to the driver.
        del model, tokenizer
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return response
# Per-role prompt templates, filled with str.format inside agent().
# Each template is a bare placeholder, so the model receives the referenced
# field's text with no additional instructions.
# NOTE(review): the qa_engineer template never asks the model to answer
# "APPROVED", which is the token the QA node checks to stop the loop early.
PROMPTS = dict(
    product_manager="{user_input}",
    project_manager="{user_input}",
    software_engineer="{final_prompt}",
    qa_engineer="{html}",
)
def generate_ui(user_prompt: str, max_iter: int):
    """Run the multi-agent pipeline for ``user_prompt`` and return the final state.

    Graph: product_manager -> project_manager -> software_engineer ->
    qa_engineer. qa_engineer loops back to software_engineer until the QA
    reply contains "APPROVED" or ``max_iter`` QA passes have completed.

    Args:
        user_prompt: The user's UI description; seeds ``messages``.
        max_iter: Maximum number of QA passes (values below 1 behave like 1).

    Returns:
        The final ``AgentState`` dict returned by the compiled graph.
    """
    state: AgentState = {
        "messages": [{"role": "user", "content": user_prompt}],
        "html": "",
        "refined_request": "",
        "final_prompt": "",
        "feedback": "",
        "iteration": 0,
        "done": False,
        "timings": {},
    }

    def _text_node(role: str, output_key: str):
        """Build a node that runs agent ``role`` and stores its reply under ``output_key``."""

        def node(s):
            # Copy the history before calling agent() so the reply is appended
            # exactly once, even if agent() mutates s["messages"] in place.
            history = list(s["messages"])
            reply = agent(PROMPTS[role], s, role, role)
            return {
                output_key: reply,
                "messages": history + [{"role": role, "content": reply}],
            }

        return node

    def qa_fn(s):
        """Run QA, count the pass, and decide whether the loop should stop."""
        history = list(s["messages"])
        feedback = agent(PROMPTS["qa_engineer"], s, "qa_engineer", "qa_engineer")
        # Count this pass before comparing, so at most max_iter QA passes run
        # (comparing the pre-increment value allowed max_iter + 1).
        iteration = s["iteration"] + 1
        done = "APPROVED" in feedback or iteration >= max_iter
        return {
            "feedback": feedback,
            "done": done,
            "iteration": iteration,
            "messages": history + [{"role": "qa_engineer", "content": feedback}],
        }

    workflow = StateGraph(AgentState)
    workflow.add_node("product_manager", _text_node("product_manager", "refined_request"))
    workflow.add_node("project_manager", _text_node("project_manager", "final_prompt"))
    workflow.add_node("software_engineer", _text_node("software_engineer", "html"))
    workflow.add_node("qa_engineer", qa_fn)
    workflow.add_edge("product_manager", "project_manager")
    workflow.add_edge("project_manager", "software_engineer")
    workflow.add_edge("software_engineer", "qa_engineer")
    workflow.add_conditional_edges(
        "qa_engineer", lambda s: END if s["done"] else "software_engineer"
    )
    workflow.set_entry_point("product_manager")
    app = workflow.compile()
    return app.invoke(state)
def main():
    """Render the Streamlit UI and run the agent pipeline when the button is pressed."""
    # The components API lives in a submodule; import it explicitly rather
    # than relying on ``import streamlit`` having loaded it as a side effect.
    import streamlit.components.v1 as components

    st.set_page_config(page_title="Multi-Agent UI Generator", layout="wide")
    st.title("π€ Multi-Agent Collaboration (Gemma Only)")
    max_iter = st.sidebar.slider("Max QA Iterations", 1, 5, 2)
    prompt = st.text_area(
        "Describe your UI:",
        "A landing page for a coffee shop with a hero image, menu, and contact form.",
        height=150,
    )
    if not st.button("π Generate UI"):
        return

    with st.spinner("Agents working..."):
        final = generate_ui(prompt, max_iter)
    st.success("β UI Generated")
    st.write("π§ Final state:", final)
    st.subheader("π Output HTML")
    components.html(final["html"], height=600, scrolling=True)
    st.subheader("π§ Agent Messages")
    for msg in final["messages"]:
        st.markdown(f"**{msg['role'].title()}**:")
        # st.code renders the text verbatim. A hand-built ``` Markdown fence
        # would be terminated early by any ``` inside the model output.
        st.code(msg["content"], language=None)


# Entry point: build the UI when this file is executed as a script.
if __name__ == "__main__":
    main()