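"""Gradio RAG demo: download a document from a link, index it with LlamaIndex,
and answer questions about it with Mixtral-8x7B-Instruct via the Hugging Face
Inference API."""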
from llama_index.llms import HuggingFaceInferenceAPI, ChatMessage, MessageRole
from llama_index.prompts import ChatPromptTemplate
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
import gradio as gr
import requests

def download_file(url, filename):
    """
    Download a file from the specified URL and save it locally under the given filename.
    """
    response = requests.get(url, stream=True)

    # Check if the request was successful
    if response.status_code == 200:
        with open(filename, 'wb') as file:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    file.write(chunk)
        print(f"Download complete: {filename}")
    else:
        print(f"Error: Unable to download file. HTTP status code: {response.status_code}")

def generate(prompt, history, file_link, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0):
    # Note: the sampling parameters (temperature, max_new_tokens, top_p,
    # repetition_penalty) are exposed in the UI but are not forwarded to the LLM here.
    mixtral = HuggingFaceInferenceAPI(
        model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
    )

    # Mixtral (remote) handles generation; a local BGE model handles embeddings for retrieval.
    service_context = ServiceContext.from_defaults(
        llm=mixtral, embed_model="local:BAAI/bge-small-en-v1.5"
    )

    # Fetch the linked document into the working directory (assumed to be /content
    # on Colab, so SimpleDirectoryReader below picks it up).
    download_file(file_link, file_link.split("/")[-1])

    documents = SimpleDirectoryReader("/content").load_data()
    index = VectorStoreIndex.from_documents(documents, service_context=service_context)

    # Text QA Prompt
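    # The query engine fills {context_str} with the retrieved chunks and
    # {query_str} with the user's question when rendering this template.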
    chat_text_qa_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "Context information is below.\n"
                "---------------------\n"
                "{context_str}\n"
                "---------------------\n"
                "Given the context information and not prior knowledge, "
                "answer the question: {query_str}\n"
            ),
        ),
    ]
    text_qa_template = ChatPromptTemplate(chat_text_qa_msgs)

    # Refine Prompt
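    # Used when the retrieved context does not fit in a single LLM call: the engine
    # revisits its draft answer with each additional batch of context.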
    chat_refine_msgs = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=(
                "Always answer the question, even if the context isn't helpful."
            ),
        ),
        ChatMessage(
            role=MessageRole.USER,
            content=(
                "We have the opportunity to refine the original answer "
                "(only if needed) with some more context below.\n"
                "------------\n"
                "{context_msg}\n"
                "------------\n"
                "Given the new context, refine the original answer to better "
                "answer the question: {query_str}. "
                "If the context isn't useful, output the original answer again.\n"
                "Original Answer: {existing_answer}"
            ),
        ),
    ]
    refine_template = ChatPromptTemplate(chat_refine_msgs)

    # Run the query (this call blocks); the answer is then yielded character by
    # character so the Gradio chat UI renders it as if it were streaming.
    response = index.as_query_engine(
        text_qa_template=text_qa_template,
        refine_template=refine_template,
        similarity_top_k=6,
    ).query(prompt)
    print(str(response))

    output = ""
    for char in str(response):
        output += char
        yield output

# Helper for a Gradio file-upload component; currently unused, since the demo
# takes a file link instead of an upload.
def upload_file(files):
    file_paths = [file.name for file in files]
    return file_paths

additional_inputs=[
    gr.Textbox(
        label="File Link",
        max_lines=1,
        interactive=True,
        value="https://arxiv.org/pdf/2401.10020.pdf"
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=1024,
        minimum=0,
        maximum=2048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
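
# gr.ChatInterface passes these controls positionally to generate() after
# (message, history): file link, temperature, max new tokens, top-p, repetition penalty.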

examples=[["Explain the paper and describe its novelty", None, None, None, None, None, ],
          ["Can you write a short story about a time-traveling detective who solves historical mysteries?", None, None, None, None, None,],
          ["I'm trying to learn French. Can you provide some common phrases that would be useful for a beginner, along with their pronunciations?", None, None, None, None, None,],
          ["I have chicken, rice, and bell peppers in my kitchen. Can you suggest an easy recipe I can make with these ingredients?", None, None, None, None, None,],
          ["Can you explain how the QuickSort algorithm works and provide a Python implementation?", None, None, None, None, None,],
          ["What are some unique features of Rust that make it stand out compared to other systems programming languages like C++?", None, None, None, None, None,],
         ]

gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="RAG Demo",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False, debug=True)