import gc
import os
import time
from typing import Dict, List, TypedDict

import streamlit as st
import torch
from langgraph.graph import StateGraph, END
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

HF_TOKEN = os.getenv("HF_TOKEN")

# Each agent role maps to a base model plus a LoRA adapter fine-tuned for that role.
# The designer reuses the project-manager adapter, and QA reuses the engineer adapter.
AGENT_MODEL_CONFIG = {
    "product_manager": {
        "base_id": "unsloth/mistral-7b-bnb-4bit",
        "adapter_id": "spandana30/product-manager-mistral",
    },
    "project_manager": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma",
    },
    "designer": {
        "base_id": "unsloth/gemma-3-1b-it",
        "adapter_id": "spandana30/project-manager-gemma",
    },
    "software_engineer": {
        "base_id": "codellama/CodeLlama-7b-hf",
        "adapter_id": "spandana30/software-engineer-codellama",
    },
    "qa_engineer": {
        "base_id": "codellama/CodeLlama-7b-hf",
        "adapter_id": "spandana30/software-engineer-codellama",
    },
}


@st.cache_resource
def load_agent_model(base_id, adapter_id):
    """Load a 4-bit quantized base model, attach its LoRA adapter, and cache the pair."""
    base_model = AutoModelForCausalLM.from_pretrained(
        base_id,
        torch_dtype=torch.float16,
        device_map="auto",
        # Passing load_in_4bit directly to from_pretrained is deprecated;
        # wrap it in a BitsAndBytesConfig instead.
        quantization_config=BitsAndBytesConfig(load_in_4bit=True),
        token=HF_TOKEN,
    )
    model = PeftModel.from_pretrained(base_model, adapter_id, token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(adapter_id, token=HF_TOKEN)
    return model.eval(), tokenizer


def call_model(prompt: str, model, tokenizer) -> str:
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    # Greedy decoding: `temperature` is ignored when do_sample=False, so it is omitted.
    outputs = model.generate(**inputs, max_new_tokens=1024, do_sample=False)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


class AgentState(TypedDict):
    messages: List[Dict[str, str]]
    product_vision: str
    project_plan: str
    design_specs: str
    html: str
    feedback: str
    iteration: int
    done: bool
    timings: Dict[str, float]


def agent(template: str, state: AgentState, agent_key: str, timing_label: str):
    """Run one agent: load its model, fill its prompt from state, and record its output."""
    st.write(f"🛠 Running agent: {agent_key}")
    start = time.time()
    model, tokenizer = load_agent_model(**AGENT_MODEL_CONFIG[agent_key])
    prompt = template.format(
        user_request=state["messages"][0]["content"],
        product_vision=state.get("product_vision", ""),
        project_plan=state.get("project_plan", ""),
        design_specs=state.get("design_specs", ""),
        html=state.get("html", ""),
    )
    st.write(f"📤 Prompt for {agent_key}:", prompt)
    response = call_model(prompt, model, tokenizer)
    st.write(f"📥 Response from {agent_key}:", response[:500])
    state["messages"].append({"role": agent_key, "content": response})
    state["timings"][timing_label] = time.time() - start
    gc.collect()
    return response


PROMPTS = {
    "product_manager": (
        "You're a Product Manager. Interpret this user request:\n"
        "{user_request}\n"
        "Define the high-level product goals, features, and user stories."
    ),
    "project_manager": (
        "You're a Project Manager. Based on this feature list:\n"
        "{product_vision}\n"
        "Create a project plan with key milestones and task assignments."
    ),
    "designer": (
        "You're a UI designer. Create design specs for:\n"
        "{project_plan}\n"
        "Include:\n"
        "1. Color palette (primary, secondary, accent)\n"
        "2. Font choices\n"
        "3. Layout structure\n"
        "4. Component styles\n"
        "Don't write code - just design guidance."
    ),
    "software_engineer": (
        "Create a complete HTML page with embedded CSS for:\n"
        "{design_specs}\n"
        "Requirements:\n"
        "1. Full HTML document with <!DOCTYPE html>\n"
        "2. CSS inside <style>