akhaliq HF Staff commited on
Commit
f012c97
·
1 Parent(s): b737269

Update app.py to add inference provider selection

Browse files
Files changed (1) hide show
  1. app.py +48 -27
app.py CHANGED
@@ -227,10 +227,11 @@ DEMO_LIST = [
227
 
228
  # HF Inference Client
229
  HF_TOKEN = os.getenv('HF_TOKEN')
 
 
230
 
231
- def get_inference_client(model_id):
232
- """Return an InferenceClient with provider based on model_id."""
233
- provider = "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else "auto"
234
  return InferenceClient(
235
  provider=provider,
236
  api_key=HF_TOKEN,
@@ -940,20 +941,24 @@ The HTML code above contains the complete original website structure with all im
940
  except Exception as e:
941
  return f"Error extracting website content: {str(e)}"
942
 
943
- def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[History], _current_model: Dict, enable_search: bool = False, language: str = "html"):
944
  if query is None:
945
  query = ''
946
  if _history is None:
947
  _history = []
948
-
 
 
 
 
949
  # Check if there's existing HTML content in history to determine if this is a modification request
950
  has_existing_html = False
951
- if _history:
952
- # Check the last assistant message for HTML content
953
- last_assistant_msg = _history[-1][1] if len(_history) > 0 else ""
954
  if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
955
  has_existing_html = True
956
-
957
  # Choose system prompt based on context
958
  if has_existing_html:
959
  # Use follow-up prompt for modifying existing HTML
@@ -964,9 +969,9 @@ def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optio
964
  system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
965
  else:
966
  system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language) if enable_search else GENERIC_SYSTEM_PROMPT.format(language=language)
967
-
968
  messages = history_to_messages(_history, system_prompt)
969
-
970
  # Extract file text and append to query if file is present
971
  file_text = ""
972
  if file:
@@ -974,7 +979,7 @@ def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optio
974
  if file_text:
975
  file_text = file_text[:5000] # Limit to 5000 chars for prompt size
976
  query = f"{query}\n\n[Reference file content below]\n{file_text}"
977
-
978
  # Extract website content and append to query if website URL is present
979
  website_text = ""
980
  if website_url and website_url.strip():
@@ -994,12 +999,12 @@ Since I couldn't extract the website content, please provide additional details
994
 
995
  This will help me create a better design for you."""
996
  query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
997
-
998
  # Enhance query with search if enabled
999
  enhanced_query = enhance_query_with_search(query, enable_search)
1000
-
1001
  # Use dynamic client based on selected model
1002
- client = get_inference_client(_current_model["id"])
1003
 
1004
  if image is not None:
1005
  messages.append(create_multimodal_message(enhanced_query, image))
@@ -1014,7 +1019,8 @@ This will help me create a better design for you."""
1014
  )
1015
  content = ""
1016
  for chunk in completion:
1017
- if chunk.choices[0].delta.content:
 
1018
  content += chunk.choices[0].delta.content
1019
  clean_code = remove_code_block(content)
1020
  search_status = " (with web search)" if enable_search and tavily_client else ""
@@ -1027,7 +1033,7 @@ This will help me create a better design for you."""
1027
  sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
1028
  }
1029
  else:
1030
- last_html = _history[-1][1] if _history else ""
1031
  modified_html = apply_search_replace_changes(last_html, clean_code)
1032
  clean_html = remove_code_block(modified_html)
1033
  yield {
@@ -1041,6 +1047,8 @@ This will help me create a better design for you."""
1041
  history_output: history_to_chatbot_messages(_history),
1042
  sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
1043
  }
 
 
1044
  # Handle response based on whether this is a modification or new generation
1045
  if has_existing_html:
1046
  # Fallback: If the model returns a full HTML file, use it directly
@@ -1048,14 +1056,11 @@ This will help me create a better design for you."""
1048
  if final_code.strip().startswith("<!DOCTYPE html>") or final_code.strip().startswith("<html"):
1049
  clean_html = final_code
1050
  else:
1051
- last_html = _history[-1][1] if _history else ""
1052
  modified_html = apply_search_replace_changes(last_html, final_code)
1053
  clean_html = remove_code_block(modified_html)
1054
  # Update history with the cleaned HTML
1055
- _history = messages_to_history(messages + [{
1056
- 'role': 'assistant',
1057
- 'content': clean_html
1058
- }])
1059
  yield {
1060
  code_output: clean_html,
1061
  history: _history,
@@ -1064,10 +1069,7 @@ This will help me create a better design for you."""
1064
  }
1065
  else:
1066
  # Regular generation - use the content as is
1067
- _history = messages_to_history(messages + [{
1068
- 'role': 'assistant',
1069
- 'content': content
1070
- }])
1071
  yield {
1072
  code_output: remove_code_block(content),
1073
  history: _history,
@@ -1156,6 +1158,16 @@ with gr.Blocks(
1156
  label="Model",
1157
  visible=True # Always visible
1158
  )
 
 
 
 
 
 
 
 
 
 
1159
  gr.Markdown("**Quick start**", visible=True)
1160
  with gr.Column(visible=True) as quick_examples_col:
1161
  for i, demo_item in enumerate(DEMO_LIST[:3]):
@@ -1251,7 +1263,7 @@ with gr.Blocks(
1251
 
1252
  btn.click(
1253
  generation_code,
1254
- inputs=[input, image_input, file_input, website_url_input, setting, history, current_model, search_toggle, language_dropdown],
1255
  outputs=[code_output, history, sandbox, history_output]
1256
  )
1257
  # Update preview when code or language changes
@@ -1259,5 +1271,14 @@ with gr.Blocks(
1259
  language_dropdown.change(preview_logic, inputs=[code_output, language_dropdown], outputs=sandbox)
1260
  clear_btn.click(clear_history, outputs=[history, history_output, file_input, website_url_input])
1261
 
 
 
 
 
 
 
 
 
 
1262
  if __name__ == "__main__":
1263
  demo.queue(api_open=False, default_concurrency_limit=20).launch(ssr_mode=True, mcp_server=False, show_api=False)
 
227
 
228
  # HF Inference Client
229
  HF_TOKEN = os.getenv('HF_TOKEN')
230
+ if not HF_TOKEN:
231
+ raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
232
 
233
+ def get_inference_client(model_id, provider="auto"):
234
+ """Return an InferenceClient with provider based on model_id and user selection."""
 
235
  return InferenceClient(
236
  provider=provider,
237
  api_key=HF_TOKEN,
 
941
  except Exception as e:
942
  return f"Error extracting website content: {str(e)}"
943
 
944
+ def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[History], _current_model: Dict, enable_search: bool = False, language: str = "html", provider: str = "auto"):
945
  if query is None:
946
  query = ''
947
  if _history is None:
948
  _history = []
949
+ # Ensure _history is always a list of lists with at least 2 elements per item
950
+ if not isinstance(_history, list):
951
+ _history = []
952
+ _history = [h for h in _history if isinstance(h, list) and len(h) == 2]
953
+
954
  # Check if there's existing HTML content in history to determine if this is a modification request
955
  has_existing_html = False
956
+ last_assistant_msg = ""
957
+ if _history and len(_history[-1]) > 1:
958
+ last_assistant_msg = _history[-1][1]
959
  if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
960
  has_existing_html = True
961
+
962
  # Choose system prompt based on context
963
  if has_existing_html:
964
  # Use follow-up prompt for modifying existing HTML
 
969
  system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
970
  else:
971
  system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language) if enable_search else GENERIC_SYSTEM_PROMPT.format(language=language)
972
+
973
  messages = history_to_messages(_history, system_prompt)
974
+
975
  # Extract file text and append to query if file is present
976
  file_text = ""
977
  if file:
 
979
  if file_text:
980
  file_text = file_text[:5000] # Limit to 5000 chars for prompt size
981
  query = f"{query}\n\n[Reference file content below]\n{file_text}"
982
+
983
  # Extract website content and append to query if website URL is present
984
  website_text = ""
985
  if website_url and website_url.strip():
 
999
 
1000
  This will help me create a better design for you."""
1001
  query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
1002
+
1003
  # Enhance query with search if enabled
1004
  enhanced_query = enhance_query_with_search(query, enable_search)
1005
+
1006
  # Use dynamic client based on selected model
1007
+ client = get_inference_client(_current_model["id"], provider)
1008
 
1009
  if image is not None:
1010
  messages.append(create_multimodal_message(enhanced_query, image))
 
1019
  )
1020
  content = ""
1021
  for chunk in completion:
1022
+ # Only process if chunk.choices is non-empty
1023
+ if hasattr(chunk, "choices") and chunk.choices and hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
1024
  content += chunk.choices[0].delta.content
1025
  clean_code = remove_code_block(content)
1026
  search_status = " (with web search)" if enable_search and tavily_client else ""
 
1033
  sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
1034
  }
1035
  else:
1036
+ last_html = _history[-1][1] if _history and len(_history[-1]) > 1 else ""
1037
  modified_html = apply_search_replace_changes(last_html, clean_code)
1038
  clean_html = remove_code_block(modified_html)
1039
  yield {
 
1047
  history_output: history_to_chatbot_messages(_history),
1048
  sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
1049
  }
1050
+ # Skip chunks with empty choices (end of stream)
1051
+ # Do not treat as error
1052
  # Handle response based on whether this is a modification or new generation
1053
  if has_existing_html:
1054
  # Fallback: If the model returns a full HTML file, use it directly
 
1056
  if final_code.strip().startswith("<!DOCTYPE html>") or final_code.strip().startswith("<html"):
1057
  clean_html = final_code
1058
  else:
1059
+ last_html = _history[-1][1] if _history and len(_history[-1]) > 1 else ""
1060
  modified_html = apply_search_replace_changes(last_html, final_code)
1061
  clean_html = remove_code_block(modified_html)
1062
  # Update history with the cleaned HTML
1063
+ _history.append([query, clean_html])
 
 
 
1064
  yield {
1065
  code_output: clean_html,
1066
  history: _history,
 
1069
  }
1070
  else:
1071
  # Regular generation - use the content as is
1072
+ _history.append([query, content])
 
 
 
1073
  yield {
1074
  code_output: remove_code_block(content),
1075
  history: _history,
 
1158
  label="Model",
1159
  visible=True # Always visible
1160
  )
1161
+ provider_choices = [
1162
+ "auto", "black-forest-labs", "cerebras", "cohere", "fal-ai", "featherless-ai", "fireworks-ai", "groq", "hf-inference", "hyperbolic", "nebius", "novita", "nscale", "openai", "replicate", "sambanova", "together"
1163
+ ]
1164
+ provider_dropdown = gr.Dropdown(
1165
+ choices=provider_choices,
1166
+ value="auto",
1167
+ label="Provider",
1168
+ visible=True
1169
+ )
1170
+ provider_state = gr.State("auto")
1171
  gr.Markdown("**Quick start**", visible=True)
1172
  with gr.Column(visible=True) as quick_examples_col:
1173
  for i, demo_item in enumerate(DEMO_LIST[:3]):
 
1263
 
1264
  btn.click(
1265
  generation_code,
1266
+ inputs=[input, image_input, file_input, website_url_input, setting, history, current_model, search_toggle, language_dropdown, provider_state],
1267
  outputs=[code_output, history, sandbox, history_output]
1268
  )
1269
  # Update preview when code or language changes
 
1271
  language_dropdown.change(preview_logic, inputs=[code_output, language_dropdown], outputs=sandbox)
1272
  clear_btn.click(clear_history, outputs=[history, history_output, file_input, website_url_input])
1273
 
1274
+ def on_provider_change(provider):
1275
+ return provider
1276
+
1277
+ provider_dropdown.change(
1278
+ on_provider_change,
1279
+ inputs=provider_dropdown,
1280
+ outputs=provider_state
1281
+ )
1282
+
1283
  if __name__ == "__main__":
1284
  demo.queue(api_open=False, default_concurrency_limit=20).launch(ssr_mode=True, mcp_server=False, show_api=False)