Update app.py
app.py CHANGED
@@ -2,88 +2,76 @@ import gradio as gr
 import os
 import faiss
 import torch
-
+import json
+import numpy as np
+from huggingface_hub import hf_hub_download, InferenceClient
 from sentence_transformers import SentenceTransformer
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)
 
 # Hugging Face Credentials
-HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"
-
-
+HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"
+HF_FAISS_FILE = "asa_faiss.index"
+api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
 
-# Load FAISS
-
-
-    filename="asa_faiss.index",
-    repo_type="model",
-    token=HF_TOKEN
-)
-faiss_index = faiss.read_index(faiss_index_path)
+# Load the FAISS index from Hugging Face
+faiss_local_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FAISS_FILE, repo_type="model", token=api_token)
+index = faiss.read_index(faiss_local_path)
 
-# Load
-
+# Load the same embedding model used for FAISS
+embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
 
-# Hugging Face
-client = InferenceClient(
-
-
-
+# Hugging Face model client
+client = InferenceClient(model=HF_REPO, token=api_token)
+
+def retrieve_relevant_context(query, top_k=3):
+    """Retrieve the most relevant text chunks from FAISS."""
+    query_embedding = embedding_model.encode([query]).astype(np.float32)
+    distances, indices = index.search(query_embedding, top_k)
 
-
-
-
-
-    distances, indices = faiss_index.search(query_embedding, top_k)
+    retrieved_texts = []
+    for idx in indices[0]:  # Get the closest matches
+        if idx != -1:  # Valid match
+            retrieved_texts.append(f"Relevant info: {idx}")
 
-
-    retrieved_context = "\n".join([f"Context {i+1}: Retrieved data for index {idx}" for i, idx in enumerate(indices[0])])
-    return retrieved_context
+    return "\n".join(retrieved_texts) if retrieved_texts else "No relevant info found."
 
-# Function to format input in Alpaca style
 def format_alpaca_prompt(user_input, system_prompt, history):
     """Formats input in Alpaca/LLaMA style"""
-    retrieved_context = retrieve_context(user_input)  # Retrieve past knowledge
     history_str = "\n".join([f"### Instruction:\n{h[0]}\n### Response:\n{h[1]}" for h in history])
-
     prompt = f"""{system_prompt}
 {history_str}
 
 ### Instruction:
 {user_input}
 
-### Retrieved Context:
-{retrieved_context}
-
 ### Response:
 """
     return prompt
 
-# Chatbot response function
 def respond(message, history, system_message, max_tokens, temperature, top_p):
-
+    # 🔹 Retrieve relevant info from FAISS
+    retrieved_context = retrieve_relevant_context(message)
+
+    # 🔹 Include retrieved context in the prompt
+    full_prompt = f"{retrieved_context}\n\n{format_alpaca_prompt(message, system_message, history)}"
 
     response = client.text_generation(
-
+        full_prompt,
         max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
     )
 
-    # Extract only the response
+    # ✅ Extract only the response
     cleaned_response = response.split("### Response:")[-1].strip()
 
-    history.append((message, cleaned_response))  # Update
+    history.append((message, cleaned_response))  # ✅ Update history with the new message and response
 
-    yield cleaned_response  # Output only the answer
+    yield cleaned_response  # ✅ Output only the answer
 
-# Gradio Chat Interface
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a
+        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
         gr.Slider(minimum=1, maximum=250, value=128, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature"),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.99, step=0.01, label="Top-p (nucleus sampling)"),