import streamlit as st
import streamlit.components.v1 as components
import os
import time
import gc
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from typing import Dict, List, TypedDict
from langgraph.graph import StateGraph, END
HF_TOKEN = os.getenv("HF_TOKEN")
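# Role -> model mapping. Note that all four roles currently share the same
# Gemma base model and the same LoRA adapter.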
AGENT_MODEL_CONFIG = {
    "product_manager": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma"
    },
    "project_manager": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma"
    },
    "software_engineer": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma"
    },
    "qa_engineer": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma"
    }
}
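# Cached so each (base, adapter) pair is loaded once per session instead of
# on every Streamlit rerun.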
@st.cache_resource
def load_agent_model(base_id, adapter_id):
    base_model = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None,
        token=HF_TOKEN
    )
    model = PeftModel.from_pretrained(base_model, adapter_id, token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=HF_TOKEN)
    return model.eval(), tokenizer
def call_model(prompt: str, model, tokenizer) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False  # greedy decoding; a temperature has no effect without sampling
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
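# Shared state threaded through every LangGraph node.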
class AgentState(TypedDict):
    messages: List[Dict[str, str]]
    html: str
    refined_request: str
    final_prompt: str
    feedback: str
    iteration: int
    done: bool
    timings: Dict[str, float]
def agent(template: str, state: AgentState, agent_key: str, timing_label: str):
    st.write(f"Running agent: {agent_key}")
    start = time.time()
    model, tokenizer = load_agent_model(**AGENT_MODEL_CONFIG[agent_key])
    latest_input = (
        state.get("final_prompt")
        or state.get("refined_request")
        or state["messages"][-1]["content"]
    )
    prompt = template.format(
        user_input=latest_input,
        html=state.get("html", ""),
        final_prompt=state.get("final_prompt", "")
    )
    st.write(f"Prompt for {agent_key}:", prompt)
    response = call_model(prompt, model, tokenizer)
    st.write(f"Response from {agent_key}:", response[:500])
    # The graph nodes append this response to state["messages"]; appending it
    # here too would record every message twice.
    state["timings"][timing_label] = time.time() - start
    gc.collect()
    return response
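# Per-role prompt templates. These are bare pass-throughs at the moment: each
# agent receives the relevant state field verbatim, with no role instructions.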
PROMPTS = {
    "product_manager": "{user_input}",
    "project_manager": "{user_input}",
    "software_engineer": "{final_prompt}",
    "qa_engineer": "{html}"
}
def generate_ui(user_prompt: str, max_iter: int):
    state: AgentState = {
        "messages": [{"role": "user", "content": user_prompt}],
        "html": "",
        "refined_request": "",
        "final_prompt": "",
        "feedback": "",
        "iteration": 0,
        "done": False,
        "timings": {}
    }
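    # Pipeline: product_manager -> project_manager -> software_engineer ->
    # qa_engineer, with QA looping back to the engineer until approval.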
    workflow = StateGraph(AgentState)

    def product_manager_fn(s):
        refined = agent(PROMPTS["product_manager"], s, "product_manager", "product_manager")
        return {
            "messages": s["messages"] + [{"role": "product_manager", "content": refined}],
            "refined_request": refined
        }

    def project_manager_fn(s):
        final_prompt = agent(PROMPTS["project_manager"], s, "project_manager", "project_manager")
        return {
            "messages": s["messages"] + [{"role": "project_manager", "content": final_prompt}],
            "final_prompt": final_prompt
        }

    def software_engineer_fn(s):
        html = agent(PROMPTS["software_engineer"], s, "software_engineer", "software_engineer")
        return {
            "messages": s["messages"] + [{"role": "software_engineer", "content": html}],
            "html": html
        }

    workflow.add_node("product_manager", product_manager_fn)
    workflow.add_node("project_manager", project_manager_fn)
    workflow.add_node("software_engineer", software_engineer_fn)
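    # QA reviews the generated HTML; the loop ends when QA's feedback contains
    # "APPROVED" or the iteration budget is exhausted.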
    def qa_fn(s):
        feedback = agent(PROMPTS["qa_engineer"], s, "qa_engineer", "qa_engineer")
        done = "APPROVED" in feedback or s["iteration"] >= max_iter
        return {
            "feedback": feedback,
            "done": done,
            "iteration": s["iteration"] + 1,
            "messages": s["messages"] + [{"role": "qa_engineer", "content": feedback}]
        }
    workflow.add_node("qa_engineer", qa_fn)
workflow.add_edge("product_manager", "project_manager")
workflow.add_edge("project_manager", "software_engineer")
workflow.add_edge("software_engineer", "qa_engineer")
workflow.add_conditional_edges("qa_engineer", lambda s: END if s["done"] else "software_engineer")
workflow.set_entry_point("product_manager")
app = workflow.compile()
final_state = app.invoke(state)
return final_state
def main():
    st.set_page_config(page_title="Multi-Agent UI Generator", layout="wide")
    st.title("Multi-Agent Collaboration (Gemma Only)")
    max_iter = st.sidebar.slider("Max QA Iterations", 1, 5, 2)
    prompt = st.text_area("Describe your UI:", "A landing page for a coffee shop with a hero image, menu, and contact form.", height=150)
    if st.button("Generate UI"):
        with st.spinner("Agents working..."):
            final = generate_ui(prompt, max_iter)
        st.success("UI Generated")
        st.write("Final state:", final)
        st.subheader("Output HTML")
        components.html(final["html"], height=600, scrolling=True)
        st.subheader("Agent Messages")
        for msg in final["messages"]:
            st.markdown(f"**{msg['role'].title()}**:\n```\n{msg['content']}\n```")

if __name__ == "__main__":
    main()