# bestie / app.py
import os

import faiss  # backend for datasets' add_faiss_index; must be installed even though not called directly
import gradio as gr
from datasets import Dataset, Features, Sequence, Value
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from sentence_transformers import SentenceTransformer
ST_MODEL = "LazarusNLP/all-indo-e5-small-v4"        # sentence-embedding model used for retrieval
BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # chat model served via the HF Inference API
DOMAIN_DATA_DIR = "./data"                          # directory of .txt knowledge-base files
# System prompt (kept in Indonesian since the bot answers in Indonesian).
# English translation: "You are an assistant at an electricity provider company
# (PLN) who helps answer questions about 'sexual harassment' in Indonesian.
# Answer briefly, using the context to answer the question in Indonesian."
SYS_MSG = """
Kamu adalah asisten dalam sebuah perusahaan penyedia listrik (PLN) yang membantu menjawab pertanyaan seputar 'sexual harassment' dalam Bahasa Indonesia.
Jawab dengan singkat menggunakan konteks untuk menjawab pertanyaan dalam Bahasa Indonesia.
"""
TOP_K = 1  # number of chunks retrieved per query
# Load every .txt file in the data directory into LangChain documents
domain_data = [os.path.join(DOMAIN_DATA_DIR, f) for f in os.listdir(DOMAIN_DATA_DIR) if f.endswith('.txt')]

pages = []
for file in domain_data:
    text_loader = TextLoader(file)
    file_pages = text_loader.load()
    pages.extend(file_pages)
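# Each entry in pages is a LangChain Document: the file text is in
# .page_content and the source path in .metadata['source']. Optional guard
# (an addition, not in the original script):
assert pages, f"no .txt files found in {DOMAIN_DATA_DIR}"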
# Split the documents into overlapping chunks for embedding
splitter = RecursiveCharacterTextSplitter(
    chunk_size=300,
    chunk_overlap=64,
    separators=["\n\n"]  # split on blank lines (paragraph breaks) only
)
documents = splitter.split_documents(pages)
content = [doc.page_content.strip() for doc in documents]
# Embed every chunk with the sentence-transformer model
ST = SentenceTransformer(ST_MODEL)
embeddings = ST.encode(content)
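# ST.encode(content) returns a (num_chunks, dim) numpy array, one row per chunk.
# Optional sanity check (an addition, not in the original script):
assert len(embeddings) == len(content), "expected one embedding per chunk"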
# Store chunks and vectors in a datasets.Dataset and attach a FAISS index
features = Features({
    'text': Value('string'),
    'embeddings': Sequence(Value('float32'))
})
data = {'text': content, 'embeddings': [embedding.tolist() for embedding in embeddings]}
dataset = Dataset.from_dict(data, features=features)
dataset.add_faiss_index(column='embeddings')
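# Optional: persist the FAISS index so a restart can skip re-encoding. A minimal
# sketch using datasets' built-in helpers ('index.faiss' is an example filename):
# dataset.save_faiss_index('embeddings', 'index.faiss')
# ...and later: dataset.load_faiss_index('embeddings', 'index.faiss')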
def retrieve(query, top_k=3):
    """Return (scores, texts) of the top_k chunks nearest to the query."""
    query_embedding = ST.encode(query)  # 1-D vector for a single query string
    scores, retrieved_examples = dataset.get_nearest_examples('embeddings', query_embedding, k=top_k)
    return scores, retrieved_examples['text']
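# Example usage (hypothetical query; with the default flat index the scores are
# L2 distances, so lower means closer):
# scores, docs = retrieve("Apa itu pelecehan seksual?", top_k=2)
# print(scores, docs)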
client = InferenceClient(BASE_MODEL)  # serverless HF Inference API client for the chat model
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=512,
    temperature=0.4,
    top_p=0.9,
):
    # Retrieve the TOP_K most relevant chunks for the user's query
    _, retrieved_docs = retrieve(message, top_k=TOP_K)

    # Format the retrieved chunks as numbered context
    context = "\n".join([f"Dokumen {i+1}: {doc}" for i, doc in enumerate(retrieved_docs)])
    print(f"Feed:\n{context}")  # debug: log the context fed to the model

    # Rebuild the chat history in the OpenAI-style message format
    messages = [{"role": "system", "content": SYS_MSG}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    # Append the current user message together with the retrieved context
    user_context = f"{message}\nKonteks:\n{context}"
    messages.append({"role": "user", "content": user_context})
    # Stream the completion, yielding the partial response as tokens arrive.
    # The loop variable is named chunk so it does not shadow the message parameter.
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # print(chunk)  # uncomment to debug raw stream chunks
        token = chunk.choices[0].delta.content
        if token:  # the final chunk may carry no content
            response += token
            yield response
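# Example (hypothetical; Gradio normally drives this generator itself):
# for partial in respond("Bagaimana cara melaporkan pelecehan?", history=[]):
#     pass
# print(partial)  # final accumulated answer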
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
)
if __name__ == "__main__":
    demo.launch(share=True)
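# A minimal requirements.txt sketch for this Space (an assumption; pin versions
# as needed):
#   gradio
#   huggingface_hub
#   datasets
#   faiss-cpu
#   langchain
#   langchain-community
#   sentence-transformers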