Spaces:
Running
Running
update to add provider selection
Browse files
app.py
CHANGED
@@ -227,10 +227,11 @@ DEMO_LIST = [
|
|
227 |
|
228 |
# HF Inference Client
|
229 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
|
|
|
|
230 |
|
231 |
-
def get_inference_client(model_id):
|
232 |
-
"""Return an InferenceClient with provider based on model_id."""
|
233 |
-
provider = "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else "auto"
|
234 |
return InferenceClient(
|
235 |
provider=provider,
|
236 |
api_key=HF_TOKEN,
|
@@ -940,20 +941,24 @@ The HTML code above contains the complete original website structure with all im
|
|
940 |
except Exception as e:
|
941 |
return f"Error extracting website content: {str(e)}"
|
942 |
|
943 |
-
def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[History], _current_model: Dict, enable_search: bool = False, language: str = "html"):
|
944 |
if query is None:
|
945 |
query = ''
|
946 |
if _history is None:
|
947 |
_history = []
|
948 |
-
|
|
|
|
|
|
|
|
|
949 |
# Check if there's existing HTML content in history to determine if this is a modification request
|
950 |
has_existing_html = False
|
951 |
-
|
952 |
-
|
953 |
-
last_assistant_msg = _history[-1][1]
|
954 |
if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
|
955 |
has_existing_html = True
|
956 |
-
|
957 |
# Choose system prompt based on context
|
958 |
if has_existing_html:
|
959 |
# Use follow-up prompt for modifying existing HTML
|
@@ -964,9 +969,9 @@ def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optio
|
|
964 |
system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
|
965 |
else:
|
966 |
system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language) if enable_search else GENERIC_SYSTEM_PROMPT.format(language=language)
|
967 |
-
|
968 |
messages = history_to_messages(_history, system_prompt)
|
969 |
-
|
970 |
# Extract file text and append to query if file is present
|
971 |
file_text = ""
|
972 |
if file:
|
@@ -974,7 +979,7 @@ def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optio
|
|
974 |
if file_text:
|
975 |
file_text = file_text[:5000] # Limit to 5000 chars for prompt size
|
976 |
query = f"{query}\n\n[Reference file content below]\n{file_text}"
|
977 |
-
|
978 |
# Extract website content and append to query if website URL is present
|
979 |
website_text = ""
|
980 |
if website_url and website_url.strip():
|
@@ -994,12 +999,12 @@ Since I couldn't extract the website content, please provide additional details
|
|
994 |
|
995 |
This will help me create a better design for you."""
|
996 |
query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
|
997 |
-
|
998 |
# Enhance query with search if enabled
|
999 |
enhanced_query = enhance_query_with_search(query, enable_search)
|
1000 |
-
|
1001 |
# Use dynamic client based on selected model
|
1002 |
-
client = get_inference_client(_current_model["id"])
|
1003 |
|
1004 |
if image is not None:
|
1005 |
messages.append(create_multimodal_message(enhanced_query, image))
|
@@ -1014,7 +1019,8 @@ This will help me create a better design for you."""
|
|
1014 |
)
|
1015 |
content = ""
|
1016 |
for chunk in completion:
|
1017 |
-
if chunk.choices
|
|
|
1018 |
content += chunk.choices[0].delta.content
|
1019 |
clean_code = remove_code_block(content)
|
1020 |
search_status = " (with web search)" if enable_search and tavily_client else ""
|
@@ -1027,7 +1033,7 @@ This will help me create a better design for you."""
|
|
1027 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
1028 |
}
|
1029 |
else:
|
1030 |
-
last_html = _history[-1][1] if _history else ""
|
1031 |
modified_html = apply_search_replace_changes(last_html, clean_code)
|
1032 |
clean_html = remove_code_block(modified_html)
|
1033 |
yield {
|
@@ -1041,6 +1047,8 @@ This will help me create a better design for you."""
|
|
1041 |
history_output: history_to_chatbot_messages(_history),
|
1042 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
1043 |
}
|
|
|
|
|
1044 |
# Handle response based on whether this is a modification or new generation
|
1045 |
if has_existing_html:
|
1046 |
# Fallback: If the model returns a full HTML file, use it directly
|
@@ -1048,14 +1056,11 @@ This will help me create a better design for you."""
|
|
1048 |
if final_code.strip().startswith("<!DOCTYPE html>") or final_code.strip().startswith("<html"):
|
1049 |
clean_html = final_code
|
1050 |
else:
|
1051 |
-
last_html = _history[-1][1] if _history else ""
|
1052 |
modified_html = apply_search_replace_changes(last_html, final_code)
|
1053 |
clean_html = remove_code_block(modified_html)
|
1054 |
# Update history with the cleaned HTML
|
1055 |
-
_history
|
1056 |
-
'role': 'assistant',
|
1057 |
-
'content': clean_html
|
1058 |
-
}])
|
1059 |
yield {
|
1060 |
code_output: clean_html,
|
1061 |
history: _history,
|
@@ -1064,10 +1069,7 @@ This will help me create a better design for you."""
|
|
1064 |
}
|
1065 |
else:
|
1066 |
# Regular generation - use the content as is
|
1067 |
-
_history
|
1068 |
-
'role': 'assistant',
|
1069 |
-
'content': content
|
1070 |
-
}])
|
1071 |
yield {
|
1072 |
code_output: remove_code_block(content),
|
1073 |
history: _history,
|
@@ -1156,6 +1158,16 @@ with gr.Blocks(
|
|
1156 |
label="Model",
|
1157 |
visible=True # Always visible
|
1158 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1159 |
gr.Markdown("**Quick start**", visible=True)
|
1160 |
with gr.Column(visible=True) as quick_examples_col:
|
1161 |
for i, demo_item in enumerate(DEMO_LIST[:3]):
|
@@ -1251,7 +1263,7 @@ with gr.Blocks(
|
|
1251 |
|
1252 |
btn.click(
|
1253 |
generation_code,
|
1254 |
-
inputs=[input, image_input, file_input, website_url_input, setting, history, current_model, search_toggle, language_dropdown],
|
1255 |
outputs=[code_output, history, sandbox, history_output]
|
1256 |
)
|
1257 |
# Update preview when code or language changes
|
@@ -1259,5 +1271,14 @@ with gr.Blocks(
|
|
1259 |
language_dropdown.change(preview_logic, inputs=[code_output, language_dropdown], outputs=sandbox)
|
1260 |
clear_btn.click(clear_history, outputs=[history, history_output, file_input, website_url_input])
|
1261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1262 |
if __name__ == "__main__":
|
1263 |
demo.queue(api_open=False, default_concurrency_limit=20).launch(ssr_mode=True, mcp_server=False, show_api=False)
|
|
|
227 |
|
228 |
# HF Inference Client
|
229 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
230 |
+
if not HF_TOKEN:
|
231 |
+
raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
|
232 |
|
233 |
+
def get_inference_client(model_id, provider="auto"):
|
234 |
+
"""Return an InferenceClient with provider based on model_id and user selection."""
|
|
|
235 |
return InferenceClient(
|
236 |
provider=provider,
|
237 |
api_key=HF_TOKEN,
|
|
|
941 |
except Exception as e:
|
942 |
return f"Error extracting website content: {str(e)}"
|
943 |
|
944 |
+
def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[History], _current_model: Dict, enable_search: bool = False, language: str = "html", provider: str = "auto"):
|
945 |
if query is None:
|
946 |
query = ''
|
947 |
if _history is None:
|
948 |
_history = []
|
949 |
+
# Ensure _history is always a list of lists with at least 2 elements per item
|
950 |
+
if not isinstance(_history, list):
|
951 |
+
_history = []
|
952 |
+
_history = [h for h in _history if isinstance(h, list) and len(h) == 2]
|
953 |
+
|
954 |
# Check if there's existing HTML content in history to determine if this is a modification request
|
955 |
has_existing_html = False
|
956 |
+
last_assistant_msg = ""
|
957 |
+
if _history and len(_history[-1]) > 1:
|
958 |
+
last_assistant_msg = _history[-1][1]
|
959 |
if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
|
960 |
has_existing_html = True
|
961 |
+
|
962 |
# Choose system prompt based on context
|
963 |
if has_existing_html:
|
964 |
# Use follow-up prompt for modifying existing HTML
|
|
|
969 |
system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
|
970 |
else:
|
971 |
system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language) if enable_search else GENERIC_SYSTEM_PROMPT.format(language=language)
|
972 |
+
|
973 |
messages = history_to_messages(_history, system_prompt)
|
974 |
+
|
975 |
# Extract file text and append to query if file is present
|
976 |
file_text = ""
|
977 |
if file:
|
|
|
979 |
if file_text:
|
980 |
file_text = file_text[:5000] # Limit to 5000 chars for prompt size
|
981 |
query = f"{query}\n\n[Reference file content below]\n{file_text}"
|
982 |
+
|
983 |
# Extract website content and append to query if website URL is present
|
984 |
website_text = ""
|
985 |
if website_url and website_url.strip():
|
|
|
999 |
|
1000 |
This will help me create a better design for you."""
|
1001 |
query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
|
1002 |
+
|
1003 |
# Enhance query with search if enabled
|
1004 |
enhanced_query = enhance_query_with_search(query, enable_search)
|
1005 |
+
|
1006 |
# Use dynamic client based on selected model
|
1007 |
+
client = get_inference_client(_current_model["id"], provider)
|
1008 |
|
1009 |
if image is not None:
|
1010 |
messages.append(create_multimodal_message(enhanced_query, image))
|
|
|
1019 |
)
|
1020 |
content = ""
|
1021 |
for chunk in completion:
|
1022 |
+
# Only process if chunk.choices is non-empty
|
1023 |
+
if hasattr(chunk, "choices") and chunk.choices and hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
|
1024 |
content += chunk.choices[0].delta.content
|
1025 |
clean_code = remove_code_block(content)
|
1026 |
search_status = " (with web search)" if enable_search and tavily_client else ""
|
|
|
1033 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
1034 |
}
|
1035 |
else:
|
1036 |
+
last_html = _history[-1][1] if _history and len(_history[-1]) > 1 else ""
|
1037 |
modified_html = apply_search_replace_changes(last_html, clean_code)
|
1038 |
clean_html = remove_code_block(modified_html)
|
1039 |
yield {
|
|
|
1047 |
history_output: history_to_chatbot_messages(_history),
|
1048 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
1049 |
}
|
1050 |
+
# Skip chunks with empty choices (end of stream)
|
1051 |
+
# Do not treat as error
|
1052 |
# Handle response based on whether this is a modification or new generation
|
1053 |
if has_existing_html:
|
1054 |
# Fallback: If the model returns a full HTML file, use it directly
|
|
|
1056 |
if final_code.strip().startswith("<!DOCTYPE html>") or final_code.strip().startswith("<html"):
|
1057 |
clean_html = final_code
|
1058 |
else:
|
1059 |
+
last_html = _history[-1][1] if _history and len(_history[-1]) > 1 else ""
|
1060 |
modified_html = apply_search_replace_changes(last_html, final_code)
|
1061 |
clean_html = remove_code_block(modified_html)
|
1062 |
# Update history with the cleaned HTML
|
1063 |
+
_history.append([query, clean_html])
|
|
|
|
|
|
|
1064 |
yield {
|
1065 |
code_output: clean_html,
|
1066 |
history: _history,
|
|
|
1069 |
}
|
1070 |
else:
|
1071 |
# Regular generation - use the content as is
|
1072 |
+
_history.append([query, content])
|
|
|
|
|
|
|
1073 |
yield {
|
1074 |
code_output: remove_code_block(content),
|
1075 |
history: _history,
|
|
|
1158 |
label="Model",
|
1159 |
visible=True # Always visible
|
1160 |
)
|
1161 |
+
provider_choices = [
|
1162 |
+
"auto", "black-forest-labs", "cerebras", "cohere", "fal-ai", "featherless-ai", "fireworks-ai", "groq", "hf-inference", "hyperbolic", "nebius", "novita", "nscale", "openai", "replicate", "sambanova", "together"
|
1163 |
+
]
|
1164 |
+
provider_dropdown = gr.Dropdown(
|
1165 |
+
choices=provider_choices,
|
1166 |
+
value="auto",
|
1167 |
+
label="Provider",
|
1168 |
+
visible=True
|
1169 |
+
)
|
1170 |
+
provider_state = gr.State("auto")
|
1171 |
gr.Markdown("**Quick start**", visible=True)
|
1172 |
with gr.Column(visible=True) as quick_examples_col:
|
1173 |
for i, demo_item in enumerate(DEMO_LIST[:3]):
|
|
|
1263 |
|
1264 |
btn.click(
|
1265 |
generation_code,
|
1266 |
+
inputs=[input, image_input, file_input, website_url_input, setting, history, current_model, search_toggle, language_dropdown, provider_state],
|
1267 |
outputs=[code_output, history, sandbox, history_output]
|
1268 |
)
|
1269 |
# Update preview when code or language changes
|
|
|
1271 |
language_dropdown.change(preview_logic, inputs=[code_output, language_dropdown], outputs=sandbox)
|
1272 |
clear_btn.click(clear_history, outputs=[history, history_output, file_input, website_url_input])
|
1273 |
|
1274 |
+
def on_provider_change(provider):
|
1275 |
+
return provider
|
1276 |
+
|
1277 |
+
provider_dropdown.change(
|
1278 |
+
on_provider_change,
|
1279 |
+
inputs=provider_dropdown,
|
1280 |
+
outputs=provider_state
|
1281 |
+
)
|
1282 |
+
|
1283 |
if __name__ == "__main__":
|
1284 |
demo.queue(api_open=False, default_concurrency_limit=20).launch(ssr_mode=True, mcp_server=False, show_api=False)
|