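"""Gradio Space for test-driving GGUF models with llama-cpp-python.

Enter a Hugging Face repository URL, pick a quantization type, load the model,
and chat with it. Conversations are not persisted.
"""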
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download, HfApi
import os
import requests
from tqdm import tqdm # For progress bars
MODEL_PATH = "./"
# Optional local fallback (placeholders, unset by default): point these at a
# repo id and GGUF filename to allow loading without entering a repository URL.
MODEL_REPO = ""
MODEL_FILENAME = ""
llm = None
api = HfApi()
DEFAULT_SYSTEM_PROMPT = "You are Doll, a smart and capable AI; a silly, obliging, and affable slave, dedicated to serving and caring for your master."
def download_file(url, local_filename):
"""Downloads a file from a URL with a progress bar."""
try:
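        # Stream the body in chunks so multi-gigabyte GGUF files are never held in memory.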
with requests.get(url, stream=True) as r:
r.raise_for_status()
            total_length = int(r.headers.get("content-length", 0))
with open(local_filename, "wb") as f:
with tqdm(total=total_length, unit="B", unit_scale=True, desc=local_filename) as pbar:
for chunk in r.iter_content(chunk_size=8192):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
return True
except Exception as e:
print(f"Error downloading {url}: {e}")
return False
def find_quantized_model_url(repo_url, quant_type="Q4_K_M"):
"""
Finds the URL of a specific quantized GGUF model file within a Hugging Face repository.
"""
try:
        repo_id = repo_url.replace("https://huggingface.co/", "").strip("/")
files = api.list_repo_files(repo_id=repo_id, repo_type="model")
        # list_repo_files returns plain relative paths (strings), e.g. "model.Q4_K_M.gguf"
        for file_name in files:
            if file_name.endswith(".gguf") and quant_type.lower() in file_name.lower():
                model_url = f"https://huggingface.co/{repo_id}/resolve/main/{file_name}"
                print(f"Found quantized model URL: {model_url}")
                return model_url
print(f"Quantized model with type {quant_type} not found in repository {repo_url}")
return None
except Exception as e:
print(f"Error finding quantized model: {e}")
return None
def load_model(repo_url=None, quant_type="Q4_K_M"):
"""Loads the Llama model, downloading the specified quantized version from a repository."""
global llm
global MODEL_PATH
try:
if repo_url:
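            # Preferred path: look up the requested quantization in the repo and download it.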
model_url = find_quantized_model_url(repo_url, quant_type)
if model_url is None:
return f"Quantized model ({quant_type}) not found in the repository."
print(f"Downloading model from {model_url}...")
downloaded_model_name = os.path.basename(model_url)
download_success = download_file(model_url, downloaded_model_name)
if not download_success:
return "Model download failed."
model_path = downloaded_model_name
        else:
            # Fallback path: use the locally configured GGUF, if any.
            if not MODEL_FILENAME:
                return "No repository URL provided and no local fallback model is configured."
            model_path = MODEL_PATH + MODEL_FILENAME
if not os.path.exists(model_path):
if not repo_url: # only try to download if a repo_url was not provided
hf_hub_download(
repo_id=MODEL_REPO,
filename=MODEL_FILENAME,
repo_type="model",
local_dir=".",
)
if not os.path.exists(model_path): # check again after attempting download
return f"Model file not found at {model_path}."
print(f"Loading model from {model_path}...")
llm = Llama(
model_path=model_path,
n_ctx=4096,
n_threads=2,
n_threads_batch=2,
verbose=False,
)
print("Model loaded successfully.")
return "Model loaded successfully."
except Exception as e:
error_message = f"Error loading model: {e}"
print(error_message)
llm = None
return error_message
def generate_response(message, history, system_prompt=DEFAULT_SYSTEM_PROMPT, temperature=0.7, top_p=0.9):
"""Generates a response from the Llama model."""
if llm is None:
yield "Model failed to load. Please check the console for error messages."
return
messages = [{"role": "system", "content": system_prompt}]
for human, assistant in history:
messages.append({"role": "user", "content": human})
messages.append({"role": "assistant", "content": assistant})
messages.append({"role": "user", "content": message})
prompt = "".join([f"{m['role'].capitalize()}: {m['content']}\n" for m in messages])
try:
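        # Stream the completion so the UI can update as tokens arrive.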
        response = ""
        for chunk in llm.create_completion(
            prompt,
            max_tokens=1024,
            echo=False,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            # gr.ChatInterface replaces the displayed message with each yielded
            # value, so accumulate the deltas and yield the running response.
            response += chunk["choices"][0]["text"]
            yield response
except Exception as e:
error_message = f"Error during inference: {e}"
print(error_message)
yield error_message
def chat(message, history, system_prompt, temperature, top_p):
    """Wrapper generator for the chat interface; streams tokens from generate_response."""
    yield from generate_response(message, history, system_prompt, temperature, top_p)
def main():
"""Main function to load the model and launch the Gradio interface."""
def load_model_and_launch(repo_url, quant_type):
model_load_message = load_model(repo_url, quant_type)
return model_load_message
with gr.Blocks() as iface:
gr.Markdown("## llama.cpp Chat")
status_label = gr.Label(label="Model Loading Status")
repo_url_input = gr.Textbox(label="Repository URL", placeholder="Enter repository URL")
quant_type_input = gr.Dropdown(
label="Quantization Type",
choices=["Q4_K_M", "Q6", "Q4_K_S"],
value="Q4_K_M",
)
load_button = gr.Button("Load Model")
chat_interface = gr.ChatInterface(
fn=chat,
description="Test a GGUF model. Chats aren't persistent.",
additional_inputs=[
gr.Textbox(label="System Prompt", value=DEFAULT_SYSTEM_PROMPT, lines=3),
gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.8, step=0.1),
gr.Slider(label="Top P", minimum=0.1, maximum=1.0, value=0.9, step=0.1),
],
cache_examples=False,
)
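        # Wire the button: load_model runs on click and its status message is
        # shown in the label above.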
load_button.click(
load_model_and_launch,
inputs=[repo_url_input, quant_type_input],
outputs=status_label,
)
iface.launch()
if __name__ == "__main__":
main()