DreadPoor committed
Commit 27be339 · verified · 1 Parent(s): 370ab1e

Update app.py

Files changed (1): app.py (+56 -25)
app.py CHANGED
@@ -6,12 +6,13 @@ import sys
 import time
 import requests
 from tqdm import tqdm # For progress bars
+from typing import Optional, List, Dict
 
 MODEL_PATH = "./" # Default model path
 llm = None # Initialize llm outside the try block
 api = HfApi() #initialize
 
-def download_file(url, local_filename):
+def download_file(url: str, local_filename: str) -> bool:
     """Downloads a file from a URL with a progress bar."""
     try:
         with requests.get(url, stream=True) as r:
@@ -28,44 +29,74 @@ def download_file(url, local_filename):
         error_message = f"Error downloading {url}: {e}"
         print(error_message)
         return False # Return False on failure
-
-def find_quantized_model_url(repo_url, quant_type="Q4_K_M"):
+
+def get_gguf_files_from_repo(repo_url: str) -> List[Dict[str, str]]:
     """
-    Finds the URL of a specific quantized GGUF model file within a Hugging Face repository.
+    Retrieves a list of GGUF files from a Hugging Face repository.
 
     Args:
         repo_url (str): The URL of the Hugging Face repository.
-        quant_type (str): The quantization type (e.g., "Q4_K_M", "Q8_0"). Defaults to "Q4_K_M".
 
     Returns:
-        str: The URL of the quantized model file, or None if not found.
+        List[Dict[str, str]]: A list of dictionaries, where each dictionary contains the file name
+        and its full URL. Returns an empty list if no GGUF files are found or an error occurs.
     """
+    gguf_files: List[Dict[str, str]] = []
     try:
         repo_id = repo_url.replace("https://huggingface.co/", "")
         files = api.list_repo_files(repo_id=repo_id, repo_type="model")
         for file_info in files:
-            if file_info.name.endswith(".gguf") and quant_type.lower() in file_info.name.lower():
-                # Construct the full URL. This is crucial.
-                model_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_info.name}"
-                print(f"Found quantized model URL: {model_url}")
-                return model_url
-        print(f"Quantized model with type {quant_type} not found in repository {repo_url}")
-        return None
+            if file_info.name.endswith(".gguf"):
+                file_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_info.name}"
+                gguf_files.append({"name": file_info.name, "url": file_url})
+        return gguf_files
     except Exception as e:
-        print(f"Error finding quantized model: {e}")
+        print(f"Error retrieving GGUF files from {repo_url}: {e}")
+        return []
+
+def find_best_gguf_model(repo_url: str, quant_type: str = "Q4_K_M") -> Optional[str]:
+    """
+    Intelligently finds the "best" GGUF model file from a Hugging Face repository,
+    prioritizing the specified quantization type.
+
+    Args:
+        repo_url (str): The URL of the Hugging Face repository.
+        quant_type (str): The desired quantization type (e.g., "Q4_K_M", "Q8_0").
+            Defaults to "Q4_K_M".
+
+    Returns:
+        Optional[str]: The URL of the best GGUF model file, or None if no suitable file is found.
+    """
+    gguf_files = get_gguf_files_from_repo(repo_url)
+    if not gguf_files:
         return None
 
+    # 1. Priority to exact quant type match
+    for file_data in gguf_files:
+        if quant_type.lower() in file_data["name"].lower():
+            print(f"Found exact match: {file_data['url']}")
+            return file_data["url"]
+
+    # 2. Fallback: Find any GGUF file (if no exact match) - Less ideal, but handles cases where the user doesn't specify.
+    if gguf_files:
+        print(f"Found a GGUF file: {gguf_files[0]['url']}")
+        return gguf_files[0]["url"]
+
+    print(f"No suitable GGUF model found in {repo_url} for quant type {quant_type}")
+    return None
+
+
 
-def load_model(repo_url=None, quant_type="Q4_K_M"):
-    """Loads the Llama model, downloading the specified quantized version from a repository."""
+def load_model(repo_url: Optional[str] = None, quant_type: str = "Q4_K_M") -> str:
+    """Loads the Llama model, downloading the specified version from a repository."""
     global llm
     global MODEL_PATH # Use the global MODEL_PATH
     try:
         if repo_url:
-            # 1. Find the quantized model URL
-            model_url = find_quantized_model_url(repo_url, quant_type)
+            # 1. Find the model URL
+            model_url = find_best_gguf_model(repo_url, quant_type)
             if model_url is None:
-                return f"Quantized model ({quant_type}) not found in the repository."
+                return f"No suitable model found in the repository."
 
             # 2. Download the model
             print(f"Downloading model from {model_url}...")
@@ -101,7 +132,7 @@ def load_model(repo_url=None, quant_type="Q4_K_M"):
 DEFAULT_SYSTEM_PROMPT = "You are Doll, a smart and capable AI; A silly, obliging and affable slave, dedicated to serving and caring for your master."
 
 
-def generate_response(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=0.9, top_p=0.9):
+def generate_response(message: str, history: List[List[str]], system_prompt: str = DEFAULT_SYSTEM_PROMPT, temperature: float = 0.9, top_p: float = 0.9):
     """Generates a response from the Llama model."""
     if llm is None:
         yield "Model failed to load. Please check the console for error messages."
@@ -130,15 +161,15 @@ def generate_response(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT, tem
         yield error_message
 
 
-def chat(message, history, system_prompt, temperature, top_p):
+def chat(message: str, history: List[List[str]], system_prompt: str, temperature: float, top_p: float) -> str:
     """Wrapper function for the chat interface."""
     return generate_response(message, history, system_prompt, temperature, top_p)
 
 
 def main():
     """Main function to load the model and launch the Gradio interface."""
-    # Use a function to load the model, and pass the model_url from the text box.
-    def load_model_and_launch(repo_url, quant_type):
+    # Use a function to load the model, and pass the repo_url from the input.
+    def load_model_and_launch(repo_url: str, quant_type: str):
         model_load_message = load_model(repo_url, quant_type)
         return model_load_message
 
@@ -148,7 +179,7 @@ def main():
         repo_url_input = gr.Textbox(label="Repository URL", placeholder="Enter repository URL")
         quant_type_input = gr.Dropdown(
             label="Quantization Type",
-            choices=["Q4_K_M", "Q6_K", "Q4_K_S"], # Add more options as needed
+            choices=["Q4_K_M", "Q8_0", "Q4_K_S"], # Add more options as needed
             value="Q4_K_M", # Default value
         )
         load_button = gr.Button("Load Model") # added load button
@@ -172,4 +203,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
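
Note on the download step: the first hunk shows only the head of download_file, since the streaming body is unchanged and elided from the diff. For context, a minimal sketch of a requests + tqdm download loop matching the docstring could look like the following (an illustration under assumptions such as the timeout and chunk size, not the verbatim body of app.py):

    import requests
    from tqdm import tqdm

    def download_file(url: str, local_filename: str) -> bool:
        """Stream url to local_filename with a tqdm progress bar (sketch)."""
        try:
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                total = int(r.headers.get("content-length", 0))
                with open(local_filename, "wb") as f, tqdm(
                    total=total, unit="B", unit_scale=True, desc=local_filename
                ) as bar:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)  # write each streamed chunk to disk
                        bar.update(len(chunk))
            return True
        except requests.exceptions.RequestException as e:
            print(f"Error downloading {url}: {e}")
            return False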
 
6
  import time
7
  import requests
8
  from tqdm import tqdm # For progress bars
9
+ from typing import Optional, List, Dict
10
 
11
  MODEL_PATH = "./" # Default model path
12
  llm = None # Initialize llm outside the try block
13
  api = HfApi() #initialize
14
 
15
+ def download_file(url: str, local_filename: str) -> bool:
16
  """Downloads a file from a URL with a progress bar."""
17
  try:
18
  with requests.get(url, stream=True) as r:
 
29
  error_message = f"Error downloading {url}: {e}"
30
  print(error_message)
31
  return False # Return False on failure
32
+
33
+ def get_gguf_files_from_repo(repo_url: str) -> List[Dict[str, str]]:
34
  """
35
+ Retrieves a list of GGUF files from a Hugging Face repository.
36
 
37
  Args:
38
  repo_url (str): The URL of the Hugging Face repository.
 
39
 
40
  Returns:
41
+ List[Dict[str, str]]: A list of dictionaries, where each dictionary contains the file name
42
+ and its full URL. Returns an empty list if no GGUF files are found or an error occurs.
43
  """
44
+ gguf_files: List[Dict[str, str]] = []
45
  try:
46
  repo_id = repo_url.replace("https://huggingface.co/", "")
47
  files = api.list_repo_files(repo_id=repo_id, repo_type="model")
48
  for file_info in files:
49
+ if file_info.name.endswith(".gguf"):
50
+ file_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_info.name}"
51
+ gguf_files.append({"name": file_info.name, "url": file_url})
52
+ return gguf_files
 
 
 
53
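
One detail worth flagging in get_gguf_files_from_repo: in current huggingface_hub releases, HfApi.list_repo_files() returns plain path strings rather than objects, so file_info.name would raise AttributeError. If that holds for the version this Space pins, the loop would need to treat each entry as a string, along these lines:

    files = api.list_repo_files(repo_id=repo_id, repo_type="model")
    for filename in files:  # each entry is a plain "path/in/repo" string
        if filename.endswith(".gguf"):
            gguf_files.append({
                "name": filename,
                "url": f"https://huggingface.co/{repo_id}/resolve/main/{filename}",
            })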
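
Taken together, the two new helpers separate discovery from retrieval: find_best_gguf_model resolves a quantization preference to a concrete resolve/main URL, and download_file fetches it. A short usage sketch (the repository URL and output filename below are hypothetical placeholders, not repos this Space targets):

    repo_url = "https://huggingface.co/SomeUser/SomeModel-GGUF"  # hypothetical repo
    model_url = find_best_gguf_model(repo_url, quant_type="Q4_K_M")
    if model_url is not None:
        ok = download_file(model_url, "model.Q4_K_M.gguf")  # hypothetical filename
        print("Downloaded." if ok else "Download failed.")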
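
generate_response itself is mostly untouched by this commit, but its yield-based body implies token streaming. A hypothetical sketch of that pattern using llama-cpp-python's create_chat_completion(stream=True), which follows the OpenAI-style chunk format (the function and variable names below are illustrative, not taken from app.py):

    def stream_reply(llm, message: str, system_prompt: str):
        # Build an OpenAI-style message list for the model's chat template.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": message},
        ]
        partial = ""
        for chunk in llm.create_chat_completion(messages=messages, stream=True):
            delta = chunk["choices"][0]["delta"]
            if "content" in delta:
                partial += delta["content"]
                yield partial  # Gradio re-renders the growing reply on each yield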
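
Design note: constructing resolve/main URLs by hand works, but huggingface_hub also provides hf_hub_download, which adds local caching and auth-token handling. An alternative sketch of the download step (not what this commit does; both identifiers below are hypothetical, and the filename would come from the GGUF search above):

    from huggingface_hub import hf_hub_download

    local_path = hf_hub_download(
        repo_id="SomeUser/SomeModel-GGUF",  # hypothetical repo id
        filename="model.Q4_K_M.gguf",       # hypothetical file name
    )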