|
import os

import faiss  # noqa: F401 -- fail fast if FAISS is missing; Dataset.add_faiss_index depends on it
import gradio as gr
from datasets import Dataset, Features, Sequence, Value
from huggingface_hub import InferenceClient
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from sentence_transformers import SentenceTransformer
|
ST_MODEL = "LazarusNLP/all-indo-e5-small-v4"  # Indonesian sentence-embedding model
BASE_MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # generator, served via the HF Inference API
DOMAIN_DATA_DIR = "./data"

# System prompt (Indonesian): "You are an assistant at an electricity provider
# (PLN) helping to answer questions about sexual harassment in Indonesian.
# Answer briefly, using the context to answer the question in Indonesian."
SYS_MSG = """
Kamu adalah asisten dalam sebuah perusahaan penyedia listrik (PLN) yang membantu menjawab pertanyaan seputar 'sexual harassment' dalam Bahasa Indonesia.
Jawab dengan singkat menggunakan konteks untuk menjawab pertanyaan dalam Bahasa Indonesia.
"""

TOP_K = 1  # number of chunks retrieved per query
|
# Load every .txt file in the domain data directory into LangChain documents.
domain_data = [os.path.join(DOMAIN_DATA_DIR, f) for f in os.listdir(DOMAIN_DATA_DIR) if f.endswith('.txt')]

pages = []
for file in domain_data:
    text_loader = TextLoader(file)
    pages.extend(text_loader.load())
|
# Split the documents into overlapping ~300-character chunks. The fallback
# separators let the recursive splitter keep subdividing paragraphs that
# exceed chunk_size; with "\n\n" alone, an oversized paragraph is never split.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=300,
    chunk_overlap=64,
    separators=["\n\n", "\n", " ", ""]
)

documents = splitter.split_documents(pages)
content = [doc.page_content.strip() for doc in documents]
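
# Optional sanity check on the chunking (illustrative only):
# print(f"{len(content)} chunks, longest is {max(len(c) for c in content)} chars")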
|
# Embed every chunk, then store texts and embeddings in a Hugging Face Dataset
# so that a FAISS index can be attached for nearest-neighbour search.
ST = SentenceTransformer(ST_MODEL)
embeddings = ST.encode(content)

features = Features({
    'text': Value('string'),
    'embeddings': Sequence(Value('float32'))
})

data = {'text': content, 'embeddings': [embedding.tolist() for embedding in embeddings]}
dataset = Dataset.from_dict(data, features=features)

dataset.add_faiss_index(column='embeddings')
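
# add_faiss_index builds a flat L2 index by default. E5-family embeddings are
# usually compared by cosine similarity; encoding with
# ST.encode(content, normalize_embeddings=True) would make L2 ranking match
# cosine ranking. The index can also be persisted to skip re-embedding on
# later runs (a sketch; the 'index.faiss' path is an assumption):
# dataset.save_faiss_index('embeddings', 'index.faiss')
# dataset.load_faiss_index('embeddings', 'index.faiss')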
|
def retrieve(query, top_k=3):
    """Return FAISS scores and texts of the top_k chunks nearest to the query."""
    query_embedding = ST.encode(query)  # a single string encodes to a 1-D vector
    scores, retrieved_examples = dataset.get_nearest_examples('embeddings', query_embedding, k=top_k)
    return scores, retrieved_examples['text']
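
# Hypothetical quick check before wiring up the UI (the query is only an
# example; with the default flat L2 index, lower scores mean closer chunks):
# scores, docs = retrieve("Apa yang dimaksud dengan pelecehan seksual?", top_k=2)
# print(scores, docs)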
|
# Generation runs remotely on the Hugging Face Inference API, not locally.
client = InferenceClient(BASE_MODEL)
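
# Meta-Llama-3-8B-Instruct is a gated repository; if anonymous requests are
# rejected, pass a token explicitly (HF_TOKEN is an assumed variable name):
# client = InferenceClient(BASE_MODEL, token=os.environ["HF_TOKEN"])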
|
def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens=512,
    temperature=0.4,
    top_p=0.9,
):
    # Retrieve supporting chunks and format them as numbered context documents.
    _, retrieved_docs = retrieve(message, top_k=TOP_K)
    context = "\n".join(f"Dokumen {i+1}: {doc}" for i, doc in enumerate(retrieved_docs))
    print(f"Feed:\n{context}")

    # Rebuild the conversation in the OpenAI-style chat message format.
    messages = [{"role": "system", "content": SYS_MSG}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    # Attach the retrieved context ("Konteks") to the current question.
    user_context = f"{message}\nKonteks:\n{context}"
    messages.append({"role": "user", "content": user_context})

    # Stream the completion, yielding the growing response so Gradio renders it
    # incrementally. `chunk` avoids shadowing the `message` argument above.
    response = ""
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # print(f"Message:\n{chunk}\n\n")  # uncomment to inspect raw stream chunks
        token = chunk.choices[0].delta.content
        if token:  # the final chunk of a stream may carry no content
            response += token
        yield response
|
""" |
|
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface |
|
""" |
|
demo = gr.ChatInterface( |
|
respond, |
|
textbox=gr.Textbox(placeholder="Enter message here", container=False, scale = 7), |
|
) |
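
# The extra generation parameters of respond() can be exposed in the UI via
# additional_inputs (a sketch; the slider ranges are assumptions):
# demo = gr.ChatInterface(
#     respond,
#     textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
#     additional_inputs=[
#         gr.Slider(1, 2048, value=512, step=1, label="max_tokens"),
#         gr.Slider(0.0, 1.0, value=0.4, step=0.05, label="temperature"),
#         gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p"),
#     ],
# )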
|
if __name__ == "__main__":
    demo.launch(share=True)  # share=True also serves a temporary public URL