DreadPoor committed (verified)
Commit 3605ca4 · 1 Parent(s): 27be339

Update app.py

Files changed (1)
  1. app.py +46 -90
app.py CHANGED
@@ -6,111 +6,76 @@ import sys
  import time
  import requests
  from tqdm import tqdm # For progress bars
- from typing import Optional, List, Dict
 
- MODEL_PATH = "./" # Default model path
- llm = None # Initialize llm outside the try block
- api = HfApi() #initialize
+ MODEL_PATH = "./"
+ llm = None
+ api = HfApi()
 
- def download_file(url: str, local_filename: str) -> bool:
+ DEFAULT_SYSTEM_PROMPT = "You are Doll, a smart and capable AI; A silly, obliging and affable slave, dedicated to serving and caring for your master."
+
+ def download_file(url, local_filename):
      """Downloads a file from a URL with a progress bar."""
      try:
          with requests.get(url, stream=True) as r:
-             r.raise_for_status() # Raise an exception for bad status codes
+             r.raise_for_status()
              total_length = int(r.headers.get("content-length"))
              with open(local_filename, "wb") as f:
                  with tqdm(total=total_length, unit="B", unit_scale=True, desc=local_filename) as pbar:
                      for chunk in r.iter_content(chunk_size=8192):
-                         if chunk: # filter out keep-alive new chunks
+                         if chunk:
                              f.write(chunk)
                              pbar.update(len(chunk))
-         return True # Return True on success
+         return True
      except Exception as e:
-         error_message = f"Error downloading {url}: {e}"
-         print(error_message)
-         return False # Return False on failure
+         print(f"Error downloading {url}: {e}")
+         return False
 
- def get_gguf_files_from_repo(repo_url: str) -> List[Dict[str, str]]:
+ def find_quantized_model_url(repo_url, quant_type="Q4_K_M"):
      """
-     Retrieves a list of GGUF files from a Hugging Face repository.
-
-     Args:
-         repo_url (str): The URL of the Hugging Face repository.
-
-     Returns:
-         List[Dict[str, str]]: A list of dictionaries, where each dictionary contains the file name
-         and its full URL. Returns an empty list if no GGUF files are found or an error occurs.
+     Finds the URL of a specific quantized GGUF model file within a Hugging Face repository.
      """
-     gguf_files: List[Dict[str, str]] = []
      try:
          repo_id = repo_url.replace("https://huggingface.co/", "")
          files = api.list_repo_files(repo_id=repo_id, repo_type="model")
          for file_info in files:
-             if file_info.name.endswith(".gguf"):
-                 file_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_info.name}"
-                 gguf_files.append({"name": file_info.name, "url": file_url})
-         return gguf_files
+             if file_info.name.endswith(".gguf") and quant_type.lower() in file_info.name.lower():
+                 model_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_info.name}"
+                 print(f"Found quantized model URL: {model_url}")
+                 return model_url
+         print(f"Quantized model with type {quant_type} not found in repository {repo_url}")
+         return None
      except Exception as e:
-         print(f"Error retrieving GGUF files from {repo_url}: {e}")
-         return []
-
- def find_best_gguf_model(repo_url: str, quant_type: str = "Q4_K_M") -> Optional[str]:
-     """
-     Intelligently finds the "best" GGUF model file from a Hugging Face repository,
-     prioritizing the specified quantization type.
-
-     Args:
-         repo_url (str): The URL of the Hugging Face repository.
-         quant_type (str): The desired quantization type (e.g., "Q4_K_M", "Q8_0").
-             Defaults to "Q4_K_M".
-
-     Returns:
-         Optional[str]: The URL of the best GGUF model file, or None if no suitable file is found.
-     """
-     gguf_files = get_gguf_files_from_repo(repo_url)
-     if not gguf_files:
+         print(f"Error finding quantized model: {e}")
          return None
 
-     # 1. Priority to exact quant type match
-     for file_data in gguf_files:
-         if quant_type.lower() in file_data["name"].lower():
-             print(f"Found exact match: {file_data['url']}")
-             return file_data["url"]
-
-     # 2. Fallback: Find any GGUF file (if no exact match) - Less ideal, but handles cases where the user doesn't specify.
-     if gguf_files:
-         print(f"Found a GGUF file: {gguf_files[0]['url']}")
-         return gguf_files[0]["url"]
-
-     print(f"No suitable GGUF model found in {repo_url} for quant type {quant_type}")
-     return None
-
-
-
- def load_model(repo_url: Optional[str] = None, quant_type: str = "Q4_K_M") -> str:
-     """Loads the Llama model, downloading the specified version from a repository."""
+ def load_model(repo_url=None, quant_type="Q4_K_M"):
+     """Loads the Llama model, downloading the specified quantized version from a repository."""
      global llm
-     global MODEL_PATH # Use the global MODEL_PATH
+     global MODEL_PATH
      try:
          if repo_url:
-             # 1. Find the model URL
-             model_url = find_best_gguf_model(repo_url, quant_type)
+             model_url = find_quantized_model_url(repo_url, quant_type)
              if model_url is None:
-                 return f"No suitable model found in the repository."
-
-             # 2. Download the model
+                 return f"Quantized model ({quant_type}) not found in the repository."
              print(f"Downloading model from {model_url}...")
              downloaded_model_name = os.path.basename(model_url)
             download_success = download_file(model_url, downloaded_model_name)
             if not download_success:
                 return "Model download failed."
-             model_path = downloaded_model_name #set model path
-
+             model_path = downloaded_model_name
          else:
             model_path = MODEL_PATH + MODEL_FILENAME
 
         if not os.path.exists(model_path):
-             return f"Model file not found at {model_path}."
+             if not repo_url: # only try to download if a repo_url was not provided
+                 hf_hub_download(
+                     repo_id=MODEL_REPO,
+                     filename=MODEL_FILENAME,
+                     repo_type="model",
+                     local_dir=".",
+                 )
+             if not os.path.exists(model_path): # check again after attempting download
+                 return f"Model file not found at {model_path}."
 
         print(f"Loading model from {model_path}...")
         llm = Llama(
@@ -128,16 +93,11 @@ def load_model(repo_url: Optional[str] = None, quant_type: str = "Q4_K_M") -> str:
          llm = None
          return error_message
 
-
- DEFAULT_SYSTEM_PROMPT = "You are Doll, a smart and capable AI; A silly, obliging and affable slave, dedicated to serving and caring for your master."
-
-
- def generate_response(message: str, history: List[List[str]], system_prompt: str = DEFAULT_SYSTEM_PROMPT, temperature: float = 0.9, top_p: float = 0.9):
+ def generate_response(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=0.7, top_p=0.9):
      """Generates a response from the Llama model."""
      if llm is None:
          yield "Model failed to load. Please check the console for error messages."
          return
-
      messages = [{"role": "system", "content": system_prompt}]
      for human, assistant in history:
          messages.append({"role": "user", "content": human})
@@ -160,29 +120,26 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str = DEFAULT_SYSTEM_PROMPT, temperature: float = 0.9, top_p: float = 0.9):
          print(error_message)
          yield error_message
 
-
- def chat(message: str, history: List[List[str]], system_prompt: str, temperature: float, top_p: float) -> str:
+ def chat(message, history, system_prompt, temperature, top_p):
      """Wrapper function for the chat interface."""
      return generate_response(message, history, system_prompt, temperature, top_p)
 
-
  def main():
      """Main function to load the model and launch the Gradio interface."""
-     # Use a function to load the model, and pass the repo_url from the input.
-     def load_model_and_launch(repo_url: str, quant_type: str):
+     def load_model_and_launch(repo_url, quant_type):
          model_load_message = load_model(repo_url, quant_type)
          return model_load_message
 
      with gr.Blocks() as iface:
          gr.Markdown("## llama.cpp Chat")
-         status_label = gr.Label(label="Model Loading Status") # Add a status label
+         status_label = gr.Label(label="Model Loading Status")
          repo_url_input = gr.Textbox(label="Repository URL", placeholder="Enter repository URL")
          quant_type_input = gr.Dropdown(
              label="Quantization Type",
-             choices=["Q4_K_M", "Q8_0", "Q4_K_S"], # Add more options as needed
-             value="Q4_K_M", # Default value
+             choices=["Q4_K_M", "Q6", "Q4_K_S"],
+             value="Q4_K_M",
          )
-         load_button = gr.Button("Load Model") # added load button
+         load_button = gr.Button("Load Model")
          chat_interface = gr.ChatInterface(
              fn=chat,
              description="Test a GGUF model. Chats aren't persistent.",
@@ -193,14 +150,13 @@ def main():
              ],
              cache_examples=False,
          )
-         load_button.click( # on click, load the model.
+         load_button.click(
              load_model_and_launch,
-             inputs=[repo_url_input, quant_type_input], # Get input from the textbox
-             outputs=status_label, # update the status label.
+             inputs=[repo_url_input, quant_type_input],
+             outputs=status_label,
          )
 
      iface.launch()
 
-
  if __name__ == "__main__":
-     main()
+     main()
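
For reference, a minimal standalone sketch of the quant-type matching idea this commit introduces in find_quantized_model_url. It is not the committed code: it assumes huggingface_hub is installed, and that HfApi.list_repo_files returns plain filename strings (as in recent huggingface_hub releases), so it matches on the strings directly rather than on a .name attribute. The repository URL shown is a hypothetical placeholder.

# Standalone sketch (not part of the commit): pick a GGUF download URL by quant type.
# Assumes huggingface_hub is installed and list_repo_files returns filename strings.
from huggingface_hub import HfApi

def pick_gguf_url(repo_url: str, quant_type: str = "Q4_K_M"):
    """Return the resolve URL of the first .gguf file whose name contains quant_type, or None."""
    repo_id = repo_url.replace("https://huggingface.co/", "")
    for filename in HfApi().list_repo_files(repo_id=repo_id, repo_type="model"):
        if filename.endswith(".gguf") and quant_type.lower() in filename.lower():
            return f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
    return None

if __name__ == "__main__":
    # Hypothetical repository URL, for illustration only.
    print(pick_gguf_url("https://huggingface.co/someuser/some-model-GGUF", "Q4_K_M"))

The committed find_quantized_model_url reads file_info.name from the same list_repo_files call; with a huggingface_hub version that returns strings, matching on the filename directly as above is the equivalent check.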