Futuresony committed · Commit ccf2073 · verified · 1 Parent(s): a23d611

Update app.py

Files changed (1):
  1. app.py +33 -45
app.py CHANGED
@@ -2,88 +2,76 @@ import gradio as gr
 import os
 import faiss
 import torch
-from huggingface_hub import InferenceClient, hf_hub_download
+import json
+import numpy as np
+from huggingface_hub import hf_hub_download, InferenceClient
 from sentence_transformers import SentenceTransformer
-import logging
-
-# Set up logging
-logging.basicConfig(level=logging.INFO)

 # Hugging Face Credentials
-HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"  # Your model repo
-HF_FAISS_REPO = "Futuresony/future_ai_12_10_2024.gguf"  # Your FAISS repo
-HF_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')  # API token from env
+HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"
+HF_FAISS_FILE = "asa_faiss.index"
+api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')

-# Load FAISS Index
-faiss_index_path = hf_hub_download(
-    repo_id=HF_FAISS_REPO,
-    filename="asa_faiss.index",
-    repo_type="model",
-    token=HF_TOKEN
-)
-faiss_index = faiss.read_index(faiss_index_path)
+# Load the FAISS index from Hugging Face
+faiss_local_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FAISS_FILE, repo_type="model", token=api_token)
+index = faiss.read_index(faiss_local_path)

-# Load Sentence Transformer for embedding queries
-embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+# Load the same embedding model used for FAISS
+embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

-# Hugging Face Model Client
-client = InferenceClient(
-    model=HF_REPO,
-    token=HF_TOKEN
-)
+# Hugging Face model client
+client = InferenceClient(model=HF_REPO, token=api_token)
+
+def retrieve_relevant_context(query, top_k=3):
+    """Retrieve the most relevant text chunks from FAISS."""
+    query_embedding = embedding_model.encode([query]).astype(np.float32)
+    distances, indices = index.search(query_embedding, top_k)
+
+    retrieved_texts = []
+    for idx in indices[0]:  # Get the closest matches
+        if idx != -1:  # Valid match
+            retrieved_texts.append(f"Relevant info: {idx}")
+
+    return "\n".join(retrieved_texts) if retrieved_texts else "No relevant info found."

-# Function to retrieve relevant context from FAISS
-def retrieve_context(query, top_k=3):
-    """Retrieve relevant past knowledge using FAISS"""
-    query_embedding = embed_model.encode([query], convert_to_tensor=True).cpu().numpy()
-    distances, indices = faiss_index.search(query_embedding, top_k)
-
-    # Convert indices to retrieved text (simulate as FAISS only returns IDs)
-    retrieved_context = "\n".join([f"Context {i+1}: Retrieved data for index {idx}" for i, idx in enumerate(indices[0])])
-    return retrieved_context

-# Function to format input in Alpaca style
 def format_alpaca_prompt(user_input, system_prompt, history):
     """Formats input in Alpaca/LLaMA style"""
-    retrieved_context = retrieve_context(user_input)  # Retrieve past knowledge
     history_str = "\n".join([f"### Instruction:\n{h[0]}\n### Response:\n{h[1]}" for h in history])
-
     prompt = f"""{system_prompt}
 {history_str}

 ### Instruction:
 {user_input}

-### Retrieved Context:
-{retrieved_context}
-
 ### Response:
 """
     return prompt

-# Chatbot response function
 def respond(message, history, system_message, max_tokens, temperature, top_p):
-    formatted_prompt = format_alpaca_prompt(message, system_message, history)
+    # 🔹 Retrieve relevant info from FAISS
+    retrieved_context = retrieve_relevant_context(message)
+
+    # 🔹 Include retrieved context in the prompt
+    full_prompt = f"{retrieved_context}\n\n{format_alpaca_prompt(message, system_message, history)}"

     response = client.text_generation(
-        formatted_prompt,
+        full_prompt,
         max_new_tokens=max_tokens,
         temperature=temperature,
         top_p=top_p,
     )

     # Extract only the response
     cleaned_response = response.split("### Response:")[-1].strip()

-    history.append((message, cleaned_response))  # Update chat history
+    history.append((message, cleaned_response))  # Update history with the new message and response

     yield cleaned_response  # Output only the answer

-# Gradio Chat Interface
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a helpful AI.", label="System message"),
+        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
         gr.Slider(minimum=1, maximum=250, value=128, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature"),
         gr.Slider(minimum=0.1, maximum=1.0, value=0.99, step=0.01, label="Top-p (nucleus sampling)"),
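
Note on the new retrieval helper: retrieve_relevant_context still returns only FAISS row IDs (f"Relevant info: {idx}"), so the prompt never actually contains retrieved text. Below is a minimal sketch of how those IDs could be mapped back to real chunks, assuming the chunks that were embedded into asa_faiss.index are also stored in the repo as a JSON list. The file name asa_chunks.json and its layout are assumptions, not part of this commit; HF_REPO, api_token, index, and embedding_model are the objects already defined above in app.py.

import json
import numpy as np
from huggingface_hub import hf_hub_download

# Hypothetical chunk store: a JSON array of strings saved in the same
# order the chunks were embedded and added to asa_faiss.index.
chunks_path = hf_hub_download(repo_id=HF_REPO, filename="asa_chunks.json",
                              repo_type="model", token=api_token)
with open(chunks_path, "r", encoding="utf-8") as f:
    chunks = json.load(f)

def retrieve_relevant_context(query, top_k=3):
    """Return the actual chunk texts closest to the query."""
    query_embedding = embedding_model.encode([query]).astype(np.float32)
    _, indices = index.search(query_embedding, top_k)
    # Map FAISS row IDs back to text; -1 means "no match" for that slot.
    texts = [chunks[i] for i in indices[0] if 0 <= i < len(chunks)]
    return "\n".join(texts) if texts else "No relevant info found."

With a mapping like this, the rest of respond() works unchanged and the model sees real retrieved text instead of bare index IDs; the import json added in this commit suggests a chunk file of this kind is the intended next step.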