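"""Gradio RAG demo.

Downloads a document from a user-supplied link, indexes it with LlamaIndex using
local BAAI/bge-small-en-v1.5 embeddings, and answers questions about it with
mistralai/Mixtral-8x7B-Instruct-v0.1 via the Hugging Face Inference API.
Written against the legacy (pre-0.10) llama_index package layout used in the imports below.
"""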
from llama_index.llms import HuggingFaceInferenceAPI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
import gradio as gr
import requests
def download_file(url, filename):
    """
    Download a file from the specified URL and save it locally under the given filename.
    """
    response = requests.get(url, stream=True)
    # Check if the request was successful
    if response.status_code == 200:
        with open(filename, "wb") as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    file.write(chunk)
        print(f"Download complete: {filename}")
    else:
        print(f"Error: Unable to download file. HTTP status code: {response.status_code}")
def generate(prompt, history, file_link, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # Note: temperature, max_new_tokens, top_p and repetition_penalty are exposed in the
    # Gradio UI but are not forwarded to the LLM below; Mixtral is queried with its defaults.
    mixtral = HuggingFaceInferenceAPI(
        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )
    service_context = ServiceContext.from_defaults(
        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
    )
    # Download the linked document, then index every file in the working directory
    # (the hard-coded /content path assumes a Colab-style environment, as in the original).
    download_file(file_link, file_link.split("/")[-1])
    documents = SimpleDirectoryReader("/content").load_data()
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)
    # Text QA Prompt
    chat_text_qa_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information and not prior knowledge, "
                "answer the question: {query_str}\n"
            ),
        ),
    ]
    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)
    # Refine Prompt
    chat_refine_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "We have the opportunity to refine the original answer "
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{context_msg}\n"
                "------------\n"
                "Given the new context, refine the original answer to better "
                "answer the question: {query_str}. "
                "If the context isn't useful, output the original answer again.\n"
                "Original Answer: {existing_answer}"
            ),
        ),
    ]
    refine_template = ChatPromptTemplate(chat_refine_msgs)
    # Run the query with the custom prompts; the top 6 similar chunks are used as context.
    response = index.as_query_engine(
        text_qa_template=text_qa_template, refine_template=refine_template, similarity_top_k=6
    ).query(prompt)
    print(str(response))
    # Yield the answer character by character so the Gradio chatbot shows a typing effect.
    output = ""
    for char in str(response):
        output += char
        yield output
def upload_file(files):
    file_paths = [file.name for file in files]
    return file_paths
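

# Note: upload_file is not wired into the interface below. As a sketch (not part of the
# original UI), it could back a gr.UploadButton, e.g.:
#   upload_button = gr.UploadButton("Upload a file", file_count="multiple")
#   upload_button.upload(upload_file, upload_button, gr.File())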
additional_inputs = [
    gr.Textbox(
        label="File Link",
        max_lines=1,
        interactive=True,
        value="https://arxiv.org/pdf/2401.10020.pdf",
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum number of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
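
# gr.ChatInterface passes these values positionally to generate() after (message, history),
# i.e. as file_link, temperature, max_new_tokens, top_p and repetition_penalty. The example
# rows below follow the same layout, with the additional-input slots left as None.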
examples = [
    ["Explain the paper and describe its novelty", None, None, None, None, None],
    ["Can you write a short story about a time-traveling detective who solves historical mysteries?", None, None, None, None, None],
    ["I'm trying to learn French. Can you provide some common phrases that would be useful for a beginner, along with their pronunciations?", None, None, None, None, None],
    ["I have chicken, rice, and bell peppers in my kitchen. Can you suggest an easy recipe I can make with these ingredients?", None, None, None, None, None],
    ["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None],
    ["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None],
]
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="RAG Demo",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False, debug=True)
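
# Running this script requires (assumptions, not stated in the original file): gradio,
# requests, a pre-0.10 llama_index matching the imports above, the local embedding
# dependencies (torch/transformers or sentence-transformers, depending on the llama_index
# version), and a Hugging Face API token (e.g. via the HF_TOKEN environment variable)
# for the hosted Mixtral Inference API endpoint.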