DreadPoor committed
Commit 2e7e1d4 · verified · 1 Parent(s): 3605ca4

Update app.py

Files changed (1)
  1. app.py +47 -133
app.py CHANGED
@@ -1,162 +1,76 @@
  import gradio as gr
  from llama_cpp import Llama
- from huggingface_hub import hf_hub_download, HfApi
  import os
- import sys
- import time
- import requests
- from tqdm import tqdm  # For progress bars

- MODEL_PATH = "./"
- llm = None
- api = HfApi()

- DEFAULT_SYSTEM_PROMPT = "You are Doll, a smart and capable AI; A silly, obliging and affable slave, dedicated to serving and caring for your master."
-
- def download_file(url, local_filename):
-     """Downloads a file from a URL with a progress bar."""
-     try:
-         with requests.get(url, stream=True) as r:
-             r.raise_for_status()
-             total_length = int(r.headers.get("content-length"))
-             with open(local_filename, "wb") as f:
-                 with tqdm(total=total_length, unit="B", unit_scale=True, desc=local_filename) as pbar:
-                     for chunk in r.iter_content(chunk_size=8192):
-                         if chunk:
-                             f.write(chunk)
-                             pbar.update(len(chunk))
-         return True
-     except Exception as e:
-         print(f"Error downloading {url}: {e}")
-         return False
-
- def find_quantized_model_url(repo_url, quant_type="Q4_K_M"):
-     """
-     Finds the URL of a specific quantized GGUF model file within a Hugging Face repository.
-     """
-     try:
-         repo_id = repo_url.replace("https://huggingface.co/", "")
-         files = api.list_repo_files(repo_id=repo_id, repo_type="model")
-         for file_info in files:
-             if file_info.name.endswith(".gguf") and quant_type.lower() in file_info.name.lower():
-                 model_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_info.name}"
-                 print(f"Found quantized model URL: {model_url}")
-                 return model_url
-         print(f"Quantized model with type {quant_type} not found in repository {repo_url}")
-         return None
-     except Exception as e:
-         print(f"Error finding quantized model: {e}")
-         return None
-
- def load_model(repo_url=None, quant_type="Q4_K_M"):
-     """Loads the Llama model, downloading the specified quantized version from a repository."""
-     global llm
-     global MODEL_PATH
-     try:
-         if repo_url:
-             model_url = find_quantized_model_url(repo_url, quant_type)
-             if model_url is None:
-                 return f"Quantized model ({quant_type}) not found in the repository."
-             print(f"Downloading model from {model_url}...")
-             downloaded_model_name = os.path.basename(model_url)
-             download_success = download_file(model_url, downloaded_model_name)
-             if not download_success:
-                 return "Model download failed."
-             model_path = downloaded_model_name
-         else:
-             model_path = MODEL_PATH + MODEL_FILENAME
-
-         if not os.path.exists(model_path):
-             if not repo_url:  # only try to download if a repo_url was not provided
-                 hf_hub_download(
-                     repo_id=MODEL_REPO,
-                     filename=MODEL_FILENAME,
-                     repo_type="model",
-                     local_dir=".",
-                 )
-             if not os.path.exists(model_path):  # check again after attempting download
-                 return f"Model file not found at {model_path}."
-
-         print(f"Loading model from {model_path}...")
-         llm = Llama(
-             model_path=model_path,
-             n_ctx=4096,
-             n_threads=2,
-             n_threads_batch=2,
-             verbose=False,
          )
-         print("Model loaded successfully.")
-         return "Model loaded successfully."
-     except Exception as e:
-         error_message = f"Error loading model: {e}"
-         print(error_message)
-         llm = None
-         return error_message

  def generate_response(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=0.7, top_p=0.9):
-     """Generates a response from the Llama model."""
      if llm is None:
-         yield "Model failed to load. Please check the console for error messages."
          return
      messages = [{"role": "system", "content": system_prompt}]
      for human, assistant in history:
          messages.append({"role": "user", "content": human})
          messages.append({"role": "assistant", "content": assistant})
      messages.append({"role": "user", "content": message})
      prompt = "".join([f"{m['role'].capitalize()}: {m['content']}\n" for m in messages])
      try:
-         for chunk in llm.create_completion(
              prompt,
              max_tokens=1024,
              echo=False,
              temperature=temperature,
              top_p=top_p,
-             stream=True,
-         ):
-             text = chunk["choices"][0]["text"]
              yield text
      except Exception as e:
-         error_message = f"Error during inference: {e}"
-         print(error_message)
-         yield error_message

  def chat(message, history, system_prompt, temperature, top_p):
-     """Wrapper function for the chat interface."""
      return generate_response(message, history, system_prompt, temperature, top_p)

- def main():
-     """Main function to load the model and launch the Gradio interface."""
-     def load_model_and_launch(repo_url, quant_type):
-         model_load_message = load_model(repo_url, quant_type)
-         return model_load_message
-
-     with gr.Blocks() as iface:
-         gr.Markdown("## llama.cpp Chat")
-         status_label = gr.Label(label="Model Loading Status")
-         repo_url_input = gr.Textbox(label="Repository URL", placeholder="Enter repository URL")
-         quant_type_input = gr.Dropdown(
-             label="Quantization Type",
-             choices=["Q4_K_M", "Q6", "Q4_K_S"],
-             value="Q4_K_M",
-         )
-         load_button = gr.Button("Load Model")
-         chat_interface = gr.ChatInterface(
-             fn=chat,
-             description="Test a GGUF model. Chats aren't persistent.",
-             additional_inputs=[
-                 gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3),
-                 gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.8, step=0.1),
-                 gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.9, step=0.1),
-             ],
-             cache_examples=False,
-         )
-         load_button.click(
-             load_model_and_launch,
-             inputs=[repo_url_input, quant_type_input],
-             outputs=status_label,
-         )
-
-     iface.launch()

- if __name__ == "__main__":
-     main()

  import gradio as gr
  from llama_cpp import Llama
+ from huggingface_hub import hf_hub_download
  import os

+ MODEL_REPO = "DreadPoor/Irix_1.1-12B-Model_Stock-Q4_K_M-GGUF"
+ MODEL_FILENAME = "irix_1.1-12b-model_stock-q4_k_m.gguf"
+ MODEL_PATH = "./" + MODEL_FILENAME

+ try:
+     if not os.path.exists(MODEL_PATH):
+         hf_hub_download(
+             repo_id=MODEL_REPO,
+             filename=MODEL_FILENAME,
+             repo_type="model",
+             local_dir=".",
          )
+     llm = Llama(
+         model_path=MODEL_PATH,
+         n_ctx=4096,
+         n_threads=2,
+         n_threads_batch=2,
+         verbose=False,
+     )
+ except Exception as e:
+     print(f"Error loading model: {e}")
+     llm = None
+
+ DEFAULT_SYSTEM_PROMPT = "You are Doll, a smart and capable AI; A silly, obliging and affable slave, dedicated to serving and caring for your master."

  def generate_response(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=0.7, top_p=0.9):
      if llm is None:
+         yield "Model failed to load."
          return
+
      messages = [{"role": "system", "content": system_prompt}]
      for human, assistant in history:
          messages.append({"role": "user", "content": human})
          messages.append({"role": "assistant", "content": assistant})
      messages.append({"role": "user", "content": message})
+
      prompt = "".join([f"{m['role'].capitalize()}: {m['content']}\n" for m in messages])
+
      try:
+         stream = llm(
              prompt,
              max_tokens=1024,
              echo=False,
              temperature=temperature,
              top_p=top_p,
+             stream=True,  # Enable streaming
+         )
+
+         for output in stream:
+             text = output["choices"][0]["text"]
              yield text
+
      except Exception as e:
+         yield f"Error during inference: {e}"

  def chat(message, history, system_prompt, temperature, top_p):
      return generate_response(message, history, system_prompt, temperature, top_p)

+ iface = gr.ChatInterface(
+     fn=chat,
+     title="llama.cpp Chat",
+     description="Test a GGUF model. Chats arent persistent",
+     additional_inputs=[
+         gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3),
+         gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.8, step=0.1),
+         gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.9, step=0.1),
+     ],
+     cache_examples=False,
+ )

+ iface.launch()
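
Not part of this commit, but a minimal smoke test for the updated app.py, sketched under two assumptions: the file is importable as a module named app from the working directory, and llama-cpp-python plus network access to MODEL_REPO are available (importing the module runs the top-level hf_hub_download and Llama() load in the new file). It drains the generate_response generator directly, without going through Gradio:

# smoke_test.py (hypothetical helper script, not in the repository)
import app  # module-level code downloads the GGUF and loads the Llama model

history = []  # list of (user, assistant) tuples, the format generate_response iterates over
# generate_response yields incremental text chunks from the streamed completion
for piece in app.generate_response("Hello! Who are you?", history):
    print(piece, end="", flush=True)
print()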