File size: 3,011 Bytes
f0e607c
416c352
bf2110c
 
 
 
f728e8f
8bb7427
416c352
 
 
bf2110c
 
 
 
416c352
bf2110c
 
 
 
 
 
 
 
 
 
 
 
 
416c352
bf2110c
 
416c352
f0e607c
bf2110c
 
 
 
 
 
 
 
 
 
 
f0e607c
 
bf2110c
f0e607c
bf2110c
f0e607c
 
 
 
 
 
bf2110c
 
 
f0e607c
 
 
 
bf2110c
be18c9a
f0e607c
 
be18c9a
 
 
 
 
 
f0e607c
bf2110c
f0e607c
be18c9a
bf2110c
be18c9a
bf2110c
f0e607c
bf2110c
f0e607c
 
 
bf2110c
f0e607c
cdd8645
f0e607c
 
 
 
 
bf2110c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import gradio as gr
import os
import faiss
import torch
from huggingface_hub import InferenceClient, hf_hub_download
from sentence_transformers import SentenceTransformer
import logging

# Set up logging — configure the root logger once at import time so the
# download/load steps below emit INFO progress messages.
logging.basicConfig(level=logging.INFO)

# Hugging Face Credentials
# NOTE(review): HF_REPO and HF_FAISS_REPO are the same repo id — the FAISS
# index appears to live alongside the generation model; confirm this is
# intentional rather than a copy-paste slip.
HF_REPO = "Futuresony/future_ai_12_10_2024.gguf"  # Your model repo
HF_FAISS_REPO = "Futuresony/future_ai_12_10_2024.gguf"  # Your FAISS repo
HF_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')  # API token from env; None if unset

# Load FAISS Index — download the prebuilt index file from the Hub (cached
# locally by huggingface_hub) and deserialize it for similarity search.
faiss_index_path = hf_hub_download(
    repo_id=HF_FAISS_REPO, 
    filename="asa_faiss.index",
    repo_type="model", 
    token=HF_TOKEN
)
faiss_index = faiss.read_index(faiss_index_path)

# Load Sentence Transformer for embedding queries.
# NOTE(review): the query embedder must be the same model that produced the
# vectors stored in asa_faiss.index, or search results are meaningless — confirm.
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# Hugging Face Model Client — remote text-generation endpoint for HF_REPO.
client = InferenceClient(
    model=HF_REPO,
    token=HF_TOKEN
)

# Function to retrieve relevant context from FAISS
def retrieve_context(query, top_k=3):
    """Return a newline-joined block of placeholder context lines for the
    ``top_k`` nearest FAISS neighbors of *query*.

    Args:
        query: Free-text user query to embed and search with.
        top_k: Number of nearest neighbors to request from the index.

    Returns:
        str: One "Context N: ..." line per valid hit; empty string when the
        index returns no valid neighbors.

    NOTE(review): the index only yields vector IDs — the "Retrieved data for
    index {idx}" text is a simulated stand-in until an ID -> document mapping
    exists.
    """
    # encode() returns a 2-D float32 numpy array, which is exactly what
    # faiss expects — no tensor -> cpu -> numpy round trip needed.
    query_embedding = embed_model.encode([query], convert_to_numpy=True)
    distances, indices = faiss_index.search(query_embedding, top_k)

    # FAISS pads the result with -1 when the index holds fewer than top_k
    # vectors; filter those out so we never fabricate "index -1" context.
    retrieved_context = "\n".join(
        f"Context {i+1}: Retrieved data for index {idx}"
        for i, idx in enumerate(indices[0])
        if idx != -1
    )
    return retrieved_context

# Function to format input in Alpaca style
def format_alpaca_prompt(user_input, system_prompt, history):
    """Assemble the full Alpaca/LLaMA-style prompt for one turn.

    Combines the system prompt, prior (instruction, response) pairs from
    *history*, the new *user_input*, and FAISS-retrieved context, ending
    with an open "### Response:" header for the model to complete.
    """
    context_block = retrieve_context(user_input)  # past knowledge lookup

    # Render each prior turn as an Instruction/Response pair.
    turns = []
    for past_input, past_reply in history:
        turns.append(f"### Instruction:\n{past_input}\n### Response:\n{past_reply}")
    history_str = "\n".join(turns)

    prompt = f"""{system_prompt}
{history_str}

### Instruction:
{user_input}

### Retrieved Context:
{context_block}

### Response:
"""
    return prompt

# Chatbot response function
def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate a single chat reply for gr.ChatInterface.

    Args:
        message: Latest user message.
        history: Prior (user, assistant) pairs supplied by ChatInterface.
        system_message: System prompt from the UI textbox.
        max_tokens / temperature / top_p: Sampling controls from the sliders.

    Yields:
        str: The cleaned model answer (single yield; no token streaming).
    """
    formatted_prompt = format_alpaca_prompt(message, system_message, history)

    response = client.text_generation(
        formatted_prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    # The model may echo the prompt scaffold; keep only the text after the
    # final "### Response:" marker.
    cleaned_response = response.split("### Response:")[-1].strip()

    # Bug fix: gr.ChatInterface appends (message, reply) to the history it
    # manages on its own, so the previous manual history.append() here
    # duplicated every turn in subsequent prompts. Do not mutate history.
    yield cleaned_response  # Output only the answer

# Gradio Chat Interface — wires respond() to a chat UI; the extra widgets
# below are passed to respond() positionally after (message, history).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI.", label="System message"),
        gr.Slider(minimum=1, maximum=250, value=128, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.9, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.99, step=0.01, label="Top-p (nucleus sampling)"),
    ],
)

# Launch the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()