Futuresony committed on
Commit c4efeea · verified · 1 Parent(s): ccf2073

Update app.py

Files changed (1):
  1. app.py +40 -40

app.py CHANGED
@@ -1,58 +1,57 @@
  import gradio as gr
  import os
- import faiss
- import torch
  import json
+ import faiss
  import numpy as np
+ import torch
- from huggingface_hub import hf_hub_download, InferenceClient
  from sentence_transformers import SentenceTransformer
+ from huggingface_hub import InferenceClient, hf_hub_download

- # Hugging Face Credentials
+ # 🔹 Hugging Face Credentials
  HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"
- HF_FAISS_FILE = "asa_faiss.index"
- api_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')
+ HF_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')  # Store your token as an environment variable for security

- # Load the FAISS index from Hugging Face
- faiss_local_path = hf_hub_download(repo_id=HF_REPO, filename=HF_FAISS_FILE, repo_type="model", token=api_token)
- index = faiss.read_index(faiss_local_path)
+ # 🔹 FAISS Index Path
+ FAISS_PATH = "asa_faiss.index"

- # Load the same embedding model used for FAISS
- embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
+ # 🔹 Load Sentence Transformer for Embeddings
+ embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

- # Hugging Face model client
- client = InferenceClient(model=HF_REPO, token=api_token)
+ # 🔹 Load FAISS Index from Hugging Face
+ faiss_local_path = hf_hub_download(HF_REPO, "asa_faiss.index", token=HF_TOKEN)
+ faiss_index = faiss.read_index(faiss_local_path)

- def retrieve_relevant_context(query, top_k=3):
-     """Retrieve the most relevant text chunks from FAISS."""
-     query_embedding = embedding_model.encode([query]).astype(np.float32)
-     distances, indices = index.search(query_embedding, top_k)
+ # 🔹 Initialize Hugging Face Model Client
+ client = InferenceClient(model=HF_REPO, token=HF_TOKEN)
+
+ # 🔹 Retrieve Relevant FAISS Context
+ def retrieve_relevant_context(user_query, top_k=3):
+     query_embedding = embedder.encode([user_query], convert_to_tensor=True).cpu().numpy()
+     distances, indices = faiss_index.search(query_embedding, top_k)

      retrieved_texts = []
-     for idx in indices[0]:  # Get the closest matches
-         if idx != -1:  # Valid match
-             retrieved_texts.append(f"Relevant info: {idx}")
+     for idx in indices[0]:  # Extract top_k results
+         if idx != -1:  # Ensure valid index
+             retrieved_texts.append(f"Example: {idx} → {idx}")  # Customize how retrieved data appears

-     return "\n".join(retrieved_texts) if retrieved_texts else "No relevant info found."
+     return "\n".join(retrieved_texts) if retrieved_texts else "No relevant data found."

- def format_alpaca_prompt(user_input, system_prompt, history):
-     """Formats input in Alpaca/LLaMA style"""
-     history_str = "\n".join([f"### Instruction:\n{h[0]}\n### Response:\n{h[1]}" for h in history])
-     prompt = f"""{system_prompt}
- {history_str}
-
- ### Instruction:
- {user_input}
-
- ### Response:
- """
-     return prompt
+ # 🔹 Format Model Prompt with FAISS Guidance
+ def format_prompt(user_input, system_prompt, history):
+     retrieved_context = retrieve_relevant_context(user_input)
+
+     faiss_instruction = (
+         "Use the following example responses as a guide for formatting and writing style:\n"
+         f"{retrieved_context}\n\n"
+         "### Instruction:\n"
+         f"{user_input}\n\n### Response:"
+     )
+
+     return faiss_instruction

+ # 🔹 Chatbot Response Function
  def respond(message, history, system_message, max_tokens, temperature, top_p):
-     # 🔹 Retrieve relevant info from FAISS
-     retrieved_context = retrieve_relevant_context(message)
-
-     # 🔹 Include retrieved context in the prompt
-     full_prompt = f"{retrieved_context}\n\n{format_alpaca_prompt(message, system_message, history)}"
+     full_prompt = format_prompt(message, system_message, history)

      response = client.text_generation(
          full_prompt,
@@ -61,17 +60,18 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
          top_p=top_p,
      )

-     # ✅ Extract only the response
+     # ✅ Extract only model-generated response
      cleaned_response = response.split("### Response:")[-1].strip()

-     history.append((message, cleaned_response))  # ✅ Update history with the new message and response
+     history.append((message, cleaned_response))  # ✅ Update chat history

-     yield cleaned_response  # ✅ Output only the answer
+     yield cleaned_response  # ✅ Output the response

+ # 🔹 Gradio Chat Interface
  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+         gr.Textbox(value="You are a helpful AI trained to follow FAISS-based writing styles.", label="System message"),
          gr.Slider(minimum=1, maximum=250, value=128, step=1, label="Max new tokens"),
          gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature"),
          gr.Slider(minimum=0.1, maximum=1.0, value=0.99, step=0.01, label="Top-p (nucleus sampling)"),
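A note on the new retrieval helper: FAISS search returns row indices, and the f-string above interpolates idx for both the label and the value, so the prompt only ever sees integer row numbers rather than the indexed text. A minimal sketch of resolving indices back to their source chunks, assuming the chunk strings were exported as a JSON list alongside the index (the asa_texts.json filename and its layout are assumptions, not files confirmed in the repo):

  import json
  from huggingface_hub import hf_hub_download

  # Hypothetical companion file: a JSON list where entry i is the text
  # that was embedded into FAISS row i when the index was built.
  texts_path = hf_hub_download(HF_REPO, "asa_texts.json", token=HF_TOKEN)
  with open(texts_path, encoding="utf-8") as f:
      texts = json.load(f)

  def retrieve_relevant_context(user_query, top_k=3):
      query_embedding = embedder.encode([user_query]).astype("float32")
      distances, indices = faiss_index.search(query_embedding, top_k)
      # Map each valid FAISS row back to its original chunk text
      retrieved = [texts[idx] for idx in indices[0] if idx != -1]
      return "\n".join(retrieved) if retrieved else "No relevant data found."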
 
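For the search to return meaningful neighbors, asa_faiss.index must have been built from the same all-MiniLM-L6-v2 embeddings the app queries with (the removed code noted this pairing explicitly). A sketch of how such an index is typically built, offered as an illustration rather than the repo's actual build script (the chunk data is made up, and IndexFlatL2 is an assumption about the index type):

  import faiss
  from sentence_transformers import SentenceTransformer

  embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

  # Illustrative corpus; in practice these would be the app's reference chunks
  chunks = ["Example response one.", "Example response two."]
  embeddings = embedder.encode(chunks).astype("float32")

  index = faiss.IndexFlatL2(embeddings.shape[1])  # all-MiniLM-L6-v2 embeds to 384 dims
  index.add(embeddings)
  faiss.write_index(index, "asa_faiss.index")

Keeping the chunk list and the index in the same order, produced by the same model, is what makes an idx-to-text lookup like the one sketched above valid.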