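"""Gradio chat app for trying out GGUF models with llama.cpp.

The user points the app at a Hugging Face model repository, the app downloads the
GGUF file matching the selected quantization type, loads it with llama-cpp-python,
and serves a streaming chat interface.

Assumed runtime dependencies (not pinned here): gradio, llama-cpp-python,
huggingface_hub, requests, tqdm.
"""
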
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download, HfApi
import os
import sys
import time
import requests
from tqdm import tqdm  # For progress bars
MODEL_PATH = "./" | |
llm = None | |
api = HfApi() | |
DEFAULT_SYSTEM_PROMPT = "You are Doll, a smart and capable AI; A silly, obliging and affable slave, dedicated to serving and caring for your master." | |


def download_file(url, local_filename):
    """Downloads a file from a URL with a progress bar."""
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            # content-length may be absent; fall back to an indeterminate progress bar.
            total_length = int(r.headers.get("content-length", 0)) or None
            with open(local_filename, "wb") as f:
                with tqdm(total=total_length, unit="B", unit_scale=True, desc=local_filename) as pbar:
                    for chunk in r.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            pbar.update(len(chunk))
        return True
    except Exception as e:
        print(f"Error downloading {url}: {e}")
        return False


def find_quantized_model_url(repo_url, quant_type="Q4_K_M"):
    """
    Finds the URL of a specific quantized GGUF model file within a Hugging Face repository.
    """
    try:
        repo_id = repo_url.replace("https://huggingface.co/", "").strip("/")
        # list_repo_files returns plain filename strings.
        files = api.list_repo_files(repo_id=repo_id, repo_type="model")
        for filename in files:
            if filename.endswith(".gguf") and quant_type.lower() in filename.lower():
                model_url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
                print(f"Found quantized model URL: {model_url}")
                return model_url
        print(f"Quantized model with type {quant_type} not found in repository {repo_url}")
        return None
    except Exception as e:
        print(f"Error finding quantized model: {e}")
        return None


def load_model(repo_url=None, quant_type="Q4_K_M"):
    """Loads the Llama model, downloading the specified quantized version from a repository."""
    global llm
    global MODEL_PATH
    try:
        if repo_url:
            model_url = find_quantized_model_url(repo_url, quant_type)
            if model_url is None:
                return f"Quantized model ({quant_type}) not found in the repository."
            print(f"Downloading model from {model_url}...")
            downloaded_model_name = os.path.basename(model_url)
            download_success = download_file(model_url, downloaded_model_name)
            if not download_success:
                return "Model download failed."
            model_path = downloaded_model_name
        else:
            # Fall back to the optional default repo/filename configured above.
            if not MODEL_FILENAME:
                return "No repository URL provided and no default model configured."
            model_path = MODEL_PATH + MODEL_FILENAME

        if not os.path.exists(model_path):
            if not repo_url:  # only try to download if a repo_url was not provided
                hf_hub_download(
                    repo_id=MODEL_REPO,
                    filename=MODEL_FILENAME,
                    repo_type="model",
                    local_dir=".",
                )
            if not os.path.exists(model_path):  # check again after attempting download
                return f"Model file not found at {model_path}."

        print(f"Loading model from {model_path}...")
        llm = Llama(
            model_path=model_path,
            n_ctx=4096,
            n_threads=2,
            n_threads_batch=2,
            verbose=False,
        )
        print("Model loaded successfully.")
        return "Model loaded successfully."
    except Exception as e:
        error_message = f"Error loading model: {e}"
        print(error_message)
        llm = None
        return error_message


def generate_response(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=0.7, top_p=0.9):
    """Generates a streaming response from the Llama model."""
    if llm is None:
        yield "Model failed to load. Please check the console for error messages."
        return

    messages = [{"role": "system", "content": system_prompt}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    # Simple role-prefixed prompt; end with the assistant cue so the model continues as the assistant.
    prompt = "".join(f"{m['role'].capitalize()}: {m['content']}\n" for m in messages)
    prompt += "Assistant:"

    try:
        # gr.ChatInterface expects each yield to be the full response so far, so accumulate chunks.
        response = ""
        for chunk in llm.create_completion(
            prompt,
            max_tokens=1024,
            echo=False,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            response += chunk["choices"][0]["text"]
            yield response
    except Exception as e:
        error_message = f"Error during inference: {e}"
        print(error_message)
        yield error_message


def chat(message, history, system_prompt, temperature, top_p):
    """Wrapper generator for the chat interface (must itself be a generator so Gradio streams it)."""
    yield from generate_response(message, history, system_prompt, temperature, top_p)


def main():
    """Main function to load the model and launch the Gradio interface."""

    def load_model_and_launch(repo_url, quant_type):
        model_load_message = load_model(repo_url, quant_type)
        return model_load_message

    with gr.Blocks() as iface:
        gr.Markdown("## llama.cpp Chat")
        status_label = gr.Label(label="Model Loading Status")
        repo_url_input = gr.Textbox(label="Repository URL", placeholder="Enter repository URL")
        quant_type_input = gr.Dropdown(
            label="Quantization Type",
            choices=["Q4_K_M", "Q6", "Q4_K_S"],
            value="Q4_K_M",
        )
        load_button = gr.Button("Load Model")
        chat_interface = gr.ChatInterface(
            fn=chat,
            description="Test a GGUF model. Chats aren't persistent.",
            additional_inputs=[
                gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3),
                gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.8, step=0.1),
                gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.9, step=0.1),
            ],
            cache_examples=False,
        )
        load_button.click(
            load_model_and_launch,
            inputs=[repo_url_input, quant_type_input],
            outputs=status_label,
        )
    iface.launch()


if __name__ == "__main__":
    main()
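
# Usage sketch (hypothetical repo name): paste a GGUF repository URL such as
#   https://huggingface.co/someuser/some-model-GGUF
# into the "Repository URL" box, pick a quantization type (e.g. Q4_K_M), click
# "Load Model", and wait for the status label to report success before chatting.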