spandana30 commited on
Commit
2834a6c
Β·
verified Β·
1 Parent(s): ed6f467

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -48
app.py CHANGED
@@ -4,16 +4,19 @@ import time
4
  import base64
5
  from typing import Dict, List, TypedDict
6
  from langgraph.graph import StateGraph, END
7
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
 
9
- # Load CodeLLaMA locally
10
- model_id = "codellama/CodeLlama-7b-hf"
11
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
12
- tokenizer = AutoTokenizer.from_pretrained(model_id)
13
- generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
 
14
 
15
  class AgentState(TypedDict):
16
  messages: List[Dict[str, str]]
 
 
17
  design_specs: str
18
  html: str
19
  css: str
@@ -22,16 +25,16 @@ class AgentState(TypedDict):
22
  done: bool
23
  timings: Dict[str, float]
24
 
25
- PRODUCT_MANAGER_PROMPT = """You're a product manager. Given the user request:
26
  {user_request}
27
- Break it down into clear features and priorities."""
28
 
29
- PROJECT_MANAGER_PROMPT = """You're a project manager. Based on these features:
30
- {features}
31
- Draft a quick development plan with key tasks and timeline."""
32
 
33
- ARCHITECT_PROMPT = """You're a software architect. Create design specs for:
34
- {user_request}
35
  Include:
36
  1. Color palette (primary, secondary, accent)
37
  2. Font choices
@@ -55,16 +58,7 @@ Check for:
55
  1. Visual quality
56
  2. Responsiveness
57
  3. Functionality
58
- Reply "APPROVED" if perfect, or suggest improvements."""
59
-
60
- def call_model(prompt: str, max_retries=3) -> str:
61
- try:
62
- outputs = generator(prompt, max_new_tokens=1000, temperature=0.3)
63
- return outputs[0]["generated_text"]
64
- except Exception as e:
65
- st.error(f"Local model call failed: {str(e)}")
66
- st.stop()
67
- return "<html><body><h1>Error generating UI</h1></body></html>"
68
 
69
  def time_agent(agent_func, state: AgentState, label: str):
70
  start = time.time()
@@ -75,46 +69,58 @@ def time_agent(agent_func, state: AgentState, label: str):
75
  return result
76
 
77
  def product_manager_agent(state: AgentState):
78
- features = call_model(PRODUCT_MANAGER_PROMPT.format(user_request=state["messages"][-1]["content"]))
79
- return {"messages": state["messages"] + [{"role": "product_manager", "content": features}]}
 
 
 
80
 
81
  def project_manager_agent(state: AgentState):
82
- features_msg = next((m["content"] for m in state["messages"] if m["role"] == "product_manager"), "")
83
- plan = call_model(PROJECT_MANAGER_PROMPT.format(features=features_msg))
84
- return {"messages": state["messages"] + [{"role": "project_manager", "content": plan}]}
85
-
86
- def software_architect_agent(state: AgentState):
87
- specs = call_model(ARCHITECT_PROMPT.format(user_request=state["messages"][-1]["content"]))
88
- return {"design_specs": specs, "messages": state["messages"] + [{"role": "software_architect", "content": specs}]}
 
 
 
 
 
89
 
90
  def engineer_agent(state: AgentState):
91
- html = call_model(ENGINEER_PROMPT.format(design_specs=state["design_specs"]))
 
 
 
92
  if not html.strip().startswith("<!DOCTYPE"):
93
- html = f"""<!DOCTYPE html>
94
- <html><head><meta charset='UTF-8'><meta name='viewport' content='width=device-width, initial-scale=1.0'>
95
- <title>Generated UI</title></head><body>{html}</body></html>"""
96
  return {"html": html, "messages": state["messages"] + [{"role": "software_engineer", "content": html}]}
97
 
98
  def qa_agent(state: AgentState, max_iter: int):
99
- feedback = call_model(QA_PROMPT.format(html=state["html"]))
 
 
 
100
  done = "APPROVED" in feedback or state["iteration"] >= max_iter
101
  return {"feedback": feedback, "done": done, "iteration": state["iteration"] + 1,
102
  "messages": state["messages"] + [{"role": "qa", "content": feedback}]}
103
 
104
  def generate_ui(user_request: str, max_iter: int):
105
- state = {"messages": [{"role": "user", "content": user_request}], "design_specs": "", "html": "",
106
- "css": "", "feedback": "", "iteration": 0, "done": False, "timings": {}}
107
 
108
  workflow = StateGraph(AgentState)
109
  workflow.add_node("product_manager", lambda s: time_agent(product_manager_agent, s, "product_manager"))
110
  workflow.add_node("project_manager", lambda s: time_agent(project_manager_agent, s, "project_manager"))
111
- workflow.add_node("software_architect", lambda s: time_agent(software_architect_agent, s, "software_architect"))
112
  workflow.add_node("software_engineer", lambda s: time_agent(engineer_agent, s, "software_engineer"))
113
  workflow.add_node("qa", lambda s: time_agent(lambda x: qa_agent(x, max_iter), s, "qa"))
114
 
115
  workflow.add_edge("product_manager", "project_manager")
116
- workflow.add_edge("project_manager", "software_architect")
117
- workflow.add_edge("software_architect", "software_engineer")
118
  workflow.add_edge("software_engineer", "qa")
119
  workflow.add_conditional_edges("qa", lambda s: END if s["done"] else "software_engineer")
120
  workflow.set_entry_point("product_manager")
@@ -138,11 +144,11 @@ def main():
138
  st.success("βœ… UI Generated Successfully!")
139
  st.components.v1.html(html, height=600, scrolling=True)
140
 
141
- st.subheader("πŸ“… Download HTML")
142
  b64 = base64.b64encode(html.encode()).decode()
143
  st.markdown(f'<a href="data:file/html;base64,{b64}" download="ui.html">Download HTML</a>', unsafe_allow_html=True)
144
 
145
- st.subheader("🧐 Agent Communication Log")
146
  history_text = ""
147
  for msg in final_state["messages"]:
148
  role = msg["role"].replace("_", " ").title()
@@ -154,15 +160,15 @@ def main():
154
  st.markdown(
155
  f'<a href="data:file/txt;base64,{b64_hist}" download="agent_communication.txt" '
156
  'style="padding: 0.4em 1em; background: #4CAF50; color: white; border-radius: 0.3em; text-decoration: none;">'
157
- 'πŸ“… Download Communication Log</a>',
158
  unsafe_allow_html=True
159
  )
160
 
161
  st.subheader("πŸ“Š Performance")
162
  st.write(f"⏱️ Total Time: {total_time:.2f} seconds")
163
  st.write(f"πŸ” Iterations: {final_state['iteration']}")
164
- for stage in ["product_manager", "project_manager", "software_architect", "software_engineer", "qa"]:
165
- st.write(f"πŸ€– {stage.title().replace('_', ' ')} Time: {final_state['timings'].get(stage, 0):.2f}s")
166
 
167
  if __name__ == "__main__":
168
- main()
 
4
  import base64
5
  from typing import Dict, List, TypedDict
6
  from langgraph.graph import StateGraph, END
7
+ from huggingface_hub import InferenceClient
8
 
9
+ # Individual clients per role
10
+ product_manager_client = InferenceClient("unsloth/mistral-7b-bnb-4bit", token=st.secrets["HF_TOKEN"])
11
+ project_manager_client = InferenceClient("unsloth/gemma-3-1b-it", token=st.secrets["HF_TOKEN"])
12
+ software_architect_client = InferenceClient("unsloth/c4ai-command-r-08-2024-bnb-4bit", token=st.secrets["HF_TOKEN"])
13
+ software_engineer_client = InferenceClient("codellama/CodeLlama-7b-hf", token=st.secrets["HF_TOKEN"])
14
+ qa_client = InferenceClient("codellama/CodeLlama-7b-hf", token=st.secrets["HF_TOKEN"])
15
 
16
  class AgentState(TypedDict):
17
  messages: List[Dict[str, str]]
18
+ product_vision: str
19
+ project_plan: str
20
  design_specs: str
21
  html: str
22
  css: str
 
25
  done: bool
26
  timings: Dict[str, float]
27
 
28
+ PRODUCT_MANAGER_PROMPT = """You're a Product Manager. Interpret this user request:
29
  {user_request}
30
+ Define the high-level product goals, features, and user stories."""
31
 
32
+ PROJECT_MANAGER_PROMPT = """You're a Project Manager. Based on this feature list:
33
+ {product_vision}
34
+ Create a project plan with key milestones and task assignments."""
35
 
36
+ DESIGNER_PROMPT = """You're a UI designer. Create design specs for:
37
+ {project_plan}
38
  Include:
39
  1. Color palette (primary, secondary, accent)
40
  2. Font choices
 
58
  1. Visual quality
59
  2. Responsiveness
60
  3. Functionality
61
+ Reply \"APPROVED\" if perfect, or suggest improvements."""
 
 
 
 
 
 
 
 
 
62
 
63
  def time_agent(agent_func, state: AgentState, label: str):
64
  start = time.time()
 
69
  return result
70
 
71
  def product_manager_agent(state: AgentState):
72
+ vision = product_manager_client.text_generation(
73
+ PRODUCT_MANAGER_PROMPT.format(user_request=state["messages"][-1]["content"]),
74
+ max_new_tokens=1000, temperature=0.3, return_full_text=False
75
+ )
76
+ return {"product_vision": vision, "messages": state["messages"] + [{"role": "product_manager", "content": vision}]}
77
 
78
  def project_manager_agent(state: AgentState):
79
+ plan = project_manager_client.text_generation(
80
+ PROJECT_MANAGER_PROMPT.format(product_vision=state["product_vision"]),
81
+ max_new_tokens=1000, temperature=0.3, return_full_text=False
82
+ )
83
+ return {"project_plan": plan, "messages": state["messages"] + [{"role": "project_manager", "content": plan}]}
84
+
85
+ def designer_agent(state: AgentState):
86
+ specs = software_architect_client.text_generation(
87
+ DESIGNER_PROMPT.format(project_plan=state["project_plan"]),
88
+ max_new_tokens=1000, temperature=0.3, return_full_text=False
89
+ )
90
+ return {"design_specs": specs, "messages": state["messages"] + [{"role": "designer", "content": specs}]}
91
 
92
  def engineer_agent(state: AgentState):
93
+ html = software_engineer_client.text_generation(
94
+ ENGINEER_PROMPT.format(design_specs=state["design_specs"]),
95
+ max_new_tokens=3000, temperature=0.3, return_full_text=False
96
+ )
97
  if not html.strip().startswith("<!DOCTYPE"):
98
+ html = f"""<!DOCTYPE html><html><head><meta charset='UTF-8'><meta name='viewport' content='width=device-width, initial-scale=1.0'><title>Generated UI</title></head><body>{html}</body></html>"""
 
 
99
  return {"html": html, "messages": state["messages"] + [{"role": "software_engineer", "content": html}]}
100
 
101
  def qa_agent(state: AgentState, max_iter: int):
102
+ feedback = qa_client.text_generation(
103
+ QA_PROMPT.format(html=state["html"]),
104
+ max_new_tokens=1000, temperature=0.3, return_full_text=False
105
+ )
106
  done = "APPROVED" in feedback or state["iteration"] >= max_iter
107
  return {"feedback": feedback, "done": done, "iteration": state["iteration"] + 1,
108
  "messages": state["messages"] + [{"role": "qa", "content": feedback}]}
109
 
110
  def generate_ui(user_request: str, max_iter: int):
111
+ state = {"messages": [{"role": "user", "content": user_request}], "product_vision": "", "project_plan": "",
112
+ "design_specs": "", "html": "", "css": "", "feedback": "", "iteration": 0, "done": False, "timings": {}}
113
 
114
  workflow = StateGraph(AgentState)
115
  workflow.add_node("product_manager", lambda s: time_agent(product_manager_agent, s, "product_manager"))
116
  workflow.add_node("project_manager", lambda s: time_agent(project_manager_agent, s, "project_manager"))
117
+ workflow.add_node("designer", lambda s: time_agent(designer_agent, s, "designer"))
118
  workflow.add_node("software_engineer", lambda s: time_agent(engineer_agent, s, "software_engineer"))
119
  workflow.add_node("qa", lambda s: time_agent(lambda x: qa_agent(x, max_iter), s, "qa"))
120
 
121
  workflow.add_edge("product_manager", "project_manager")
122
+ workflow.add_edge("project_manager", "designer")
123
+ workflow.add_edge("designer", "software_engineer")
124
  workflow.add_edge("software_engineer", "qa")
125
  workflow.add_conditional_edges("qa", lambda s: END if s["done"] else "software_engineer")
126
  workflow.set_entry_point("product_manager")
 
144
  st.success("βœ… UI Generated Successfully!")
145
  st.components.v1.html(html, height=600, scrolling=True)
146
 
147
+ st.subheader("πŸ“₯ Download HTML")
148
  b64 = base64.b64encode(html.encode()).decode()
149
  st.markdown(f'<a href="data:file/html;base64,{b64}" download="ui.html">Download HTML</a>', unsafe_allow_html=True)
150
 
151
+ st.subheader("🧠 Agent Communication Log")
152
  history_text = ""
153
  for msg in final_state["messages"]:
154
  role = msg["role"].replace("_", " ").title()
 
160
  st.markdown(
161
  f'<a href="data:file/txt;base64,{b64_hist}" download="agent_communication.txt" '
162
  'style="padding: 0.4em 1em; background: #4CAF50; color: white; border-radius: 0.3em; text-decoration: none;">'
163
+ 'πŸ“₯ Download Communication Log</a>',
164
  unsafe_allow_html=True
165
  )
166
 
167
  st.subheader("πŸ“Š Performance")
168
  st.write(f"⏱️ Total Time: {total_time:.2f} seconds")
169
  st.write(f"πŸ” Iterations: {final_state['iteration']}")
170
+ for stage in ["product_manager", "project_manager", "designer", "software_engineer", "qa"]:
171
+ st.write(f"🧩 {stage.title().replace('_', ' ')} Time: {final_state['timings'].get(stage, 0):.2f}s")
172
 
173
  if __name__ == "__main__":
174
+ main()