import os

import faiss  # imported explicitly: datasets' add_faiss_index needs it installed
import gradio as gr
import torch
from datasets import Dataset, Features, Sequence, Value
from huggingface_hub import login
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

ST_MODEL = "LazarusNLP/all-indo-e5-small-v4"
BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
DOMAIN_DATA_DIR = "./data"

# System prompt, in Indonesian: "You are an assistant at an electricity company
# (PLN) who helps answer questions about sexual harassment in Indonesian.
# Answer briefly, using the context, in Indonesian."
SYS_MSG = """
Kamu adalah asisten dalam sebuah perusahaan penyedia listrik (PLN) yang membantu menjawab pertanyaan seputar 'sexual harassment' dalam Bahasa Indonesia.
Jawab dengan singkat menggunakan konteks untuk menjawab pertanyaan dalam Bahasa Indonesia.
"""

TOP_K = 1

# Llama 3.1 is a gated model on the Hugging Face Hub, so authenticate first.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
login(token=hf_token)
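
# Optional hardening (sketch): fail fast when the token is missing instead of
# letting login() fall back to an interactive prompt. If used, place it
# before the login() call above:
# if hf_token is None:
#     raise RuntimeError("Set HUGGINGFACEHUB_API_TOKEN before running this app.")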

# Load every .txt file in the domain-data directory.
domain_data = [os.path.join(DOMAIN_DATA_DIR, f) for f in os.listdir(DOMAIN_DATA_DIR) if f.endswith('.txt')]
pages = []

for file in domain_data:
    text_loader = TextLoader(file)
    file_pages = text_loader.load()
    pages.extend(file_pages)

splitter = RecursiveCharacterTextSplitter(
    chunk_size=300,
    chunk_overlap=64,
    separators=["\n\n"],
)
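
# Note: with "\n\n" as the only separator, a passage containing no blank lines
# is kept whole even if it exceeds chunk_size. If that matters for your corpus,
# the splitter's usual fallback chain is an option (sketch):
# splitter = RecursiveCharacterTextSplitter(
#     chunk_size=300, chunk_overlap=64, separators=["\n\n", "\n", " ", ""]
# )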

documents = splitter.split_documents(pages)
content = [doc.page_content.strip() for doc in documents]

# Embed every chunk with the Indonesian sentence-embedding model.
ST = SentenceTransformer(ST_MODEL)
embeddings = ST.encode(content)

# Wrap the chunks and their embeddings in a `datasets.Dataset` so that a
# FAISS index can be attached for nearest-neighbour search.
features = Features({
    'text': Value('string'),
    'embeddings': Sequence(Value('float32'))
})

data = {'text': content, 'embeddings': [embedding.tolist() for embedding in embeddings]}
dataset = Dataset.from_dict(data, features=features)

dataset.add_faiss_index(column='embeddings')
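
# By default this builds a flat L2 index, so the scores returned by retrieval
# are distances (smaller is closer). For cosine similarity instead, one option
# (sketch) is to normalize the embeddings and switch to inner product:
# embeddings = ST.encode(content, normalize_embeddings=True)
# dataset.add_faiss_index(column='embeddings', metric_type=faiss.METRIC_INNER_PRODUCT)
# The index can also be persisted to skip rebuilding on every start:
# dataset.save_faiss_index('embeddings', 'domain.faiss')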

def retrieve(query, top_k=3):
    """Return FAISS scores and the `top_k` most similar chunks for `query`."""
    query_embedding = ST.encode([query])
    scores, retrieved_examples = dataset.get_nearest_examples('embeddings', query_embedding, k=top_k)
    return scores, retrieved_examples['text']
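
# Quick smoke test (hypothetical query; uncomment to verify retrieval works):
# scores, docs = retrieve("Apa yang dimaksud dengan pelecehan seksual?", top_k=1)
# print(scores, docs)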

# Quantize the 8B model to 4-bit NF4 so it fits on a single consumer GPU.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
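
# Assumption: the GPU supports bfloat16 (Ampere or newer). On older cards,
# float16 is the usual compute dtype, e.g.:
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True, bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.float16,
# )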

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=bnb_config,
)

# The code below calls a text-generation pipeline object, so build one from
# the quantized model and its tokenizer.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

def format_prompt(prompt, retrieved_documents, k):
    """Append the k retrieved chunks as context so the model can ground its answer."""
    PROMPT = f"Pertanyaan:{prompt}\nKonteks:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents[idx]}\n"
    return PROMPT
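
# For example, format_prompt("Apa itu catcalling?", ["Catcalling adalah ..."], 1)
# yields "Pertanyaan:Apa itu catcalling?\nKonteks:Catcalling adalah ...\n".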

def chat_function(message, history, max_new_tokens=256, temperature=0.6):
    """Retrieve context for the user's message and generate a grounded answer."""
    scores, retrieved_doc = retrieve(message, TOP_K)
    formatted_prompt = format_prompt(message, retrieved_doc, TOP_K)

    messages = [{"role": "system", "content": SYS_MSG},
                {"role": "user", "content": formatted_prompt}]
    prompt = generator.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Llama 3 marks the end of an assistant turn with <|eot_id|>.
    terminators = [
        generator.tokenizer.eos_token_id,
        generator.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    outputs = generator(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=0.9,
    )
    # The pipeline returns the prompt plus the completion; strip the prompt.
    return outputs[0]["generated_text"][len(prompt):]
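
# Direct call for testing outside Gradio (hypothetical question):
# print(chat_function("Bagaimana cara melaporkan pelecehan di tempat kerja?", []))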

# For information on how to customize the ChatInterface, see the Gradio docs:
# https://www.gradio.app/docs/chatinterface
demo = gr.ChatInterface(
    chat_function,
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    chatbot=gr.Chatbot(height=400),
)

if __name__ == "__main__":
    demo.launch(share=True)