bestie / .ipynb_checkpoints /main-checkpoint.py
Bryan Bimantaka (Monash University)
add cache
d774ace
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain_community.document_loaders import TextLoader
from huggingface_hub import InferenceClient
import transformers
from sentence_transformers import SentenceTransformer
from datasets import Dataset, Features, Value, Sequence
import pandas as pd
import faiss
import os
import torch
import gradio as gr
# Sentence-embedding model used to embed both the corpus chunks and user queries.
ST_MODEL = "LazarusNLP/all-indo-e5-small-v4"
# Causal chat model that generates the answers.
BASE_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
# Directory scanned for *.txt knowledge-base files.
DOMAIN_DATA_DIR = "./data"
# Number of retrieved chunks that are placed into each prompt.
TOP_K = 1
# System prompt (in Indonesian). Roughly: "You are an assistant at an electricity
# provider (PLN) who helps answer questions about 'sexual harassment' in
# Indonesian. Answer briefly, using the context to answer the question, in
# Indonesian."
SYS_MSG = (
    "\n"
    "Kamu adalah asisten dalam sebuah perusahaan penyedia listrik (PLN) yang membantu menjawab pertanyaan seputar 'sexual harassment' dalam Bahasa Indonesia.\n"
    "Jawab dengan singkat menggunakan konteks untuk menjawab pertanyaan dalam Bahasa Indonesia.\n"
)
from huggingface_hub import login

# Authenticate with the Hugging Face Hub (required for gated repositories)
# using a token taken from the environment.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
if hf_token:
    login(token=hf_token)
# When no token is set, login() would fall back to an interactive prompt, which
# blocks a headless server. Skip it and rely on any cached/HF_TOKEN credentials.
# Collect every .txt file in DOMAIN_DATA_DIR. os.listdir returns entries in
# arbitrary order, so sort them to build the corpus (and hence the FAISS index)
# in the same order on every run.
domain_data = [
    os.path.join(DOMAIN_DATA_DIR, f)
    for f in sorted(os.listdir(DOMAIN_DATA_DIR))
    if f.endswith('.txt')
]

# Load each file as LangChain documents and gather them into one list.
pages = []
for file in domain_data:
    # Read explicitly as UTF-8 instead of the platform's locale encoding.
    pages.extend(TextLoader(file, encoding="utf-8").load())
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Chunk the loaded pages on blank lines (paragraph breaks) only. Because no
# finer separator is configured, a paragraph longer than chunk_size cannot be
# split further and is kept as a single, oversized chunk.
splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n"],
    chunk_size=300,
    chunk_overlap=64,
)
documents = splitter.split_documents(pages)
# Plain-text chunks that make up the retrieval corpus.
content = [doc.page_content.strip() for doc in documents]
if not content:
    # Fail early with a clear message: an index built over zero rows would
    # otherwise only surface later as an obscure error at query time.
    raise RuntimeError(
        f"No text chunks were produced from the .txt files in {DOMAIN_DATA_DIR!r}"
    )

ST = SentenceTransformer(ST_MODEL)
embeddings = ST.encode(content)

features = Features({
    'text': Value('string'),
    'embeddings': Sequence(Value('float32')),
})
# One row per chunk: the raw text plus its embedding as a plain list of floats
# (ndarray.tolist() on the 2-D embedding matrix yields one list per row).
data = {'text': content, 'embeddings': embeddings.tolist()}
dataset = Dataset.from_dict(data, features=features)
# Attach a FAISS index over the embedding column for nearest-neighbour search.
dataset.add_faiss_index(column='embeddings')
def retrieve(query, top_k=3):
    """Find the corpus chunks closest to *query* in the FAISS-indexed dataset.

    Args:
        query: Free-text question; it is embedded with the module-level ``ST``.
        top_k: Number of neighbours to request from the index.

    Returns:
        A ``(scores, texts)`` pair: the scores reported by
        ``dataset.get_nearest_examples`` and the matching ``text`` values.
    """
    # Encode as a one-element batch and hand it straight to the index search.
    query_vectors = ST.encode([query])
    nearest = dataset.get_nearest_examples('embeddings', query_vectors, k=top_k)
    return nearest[0], nearest[1]['text']
# The tokenizer does not depend on the quantization settings, so load it first.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# Load the weights in 4-bit NF4 with nested ("double") quantization to lower GPU
# memory usage; computation is carried out in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the layers on available devices
    torch_dtype=torch.bfloat16,
)
def format_prompt(prompt, retrieved_documents, k=None):
    """Build the user-turn text from a question and retrieved context passages.

    The result is ``"Pertanyaan:<prompt>\\nKonteks:"`` followed by each of the
    first *k* documents, each terminated by a newline.

    Args:
        prompt: The user's question.
        retrieved_documents: Iterable of context strings, best match first.
        k: Maximum number of documents to include. ``None`` includes all of
            them; values <= 0 include none. If fewer than *k* documents are
            available, all available ones are used (no ``IndexError``).

    Returns:
        The formatted prompt string.
    """
    from itertools import islice

    limit = None if k is None else max(k, 0)
    context = "".join(f"{doc}\n" for doc in islice(retrieved_documents, limit))
    return f"Pertanyaan:{prompt}\nKonteks:{context}"
def chat_function(message, history, max_new_tokens=256, temperature=0.6):
    """Answer one chat message with retrieval-augmented generation.

    Retrieves the TOP_K most similar knowledge-base chunks for *message*,
    combines them with the question via ``format_prompt``, renders the system
    and user turns through the tokenizer's chat template and samples a reply
    from the module-level ``model``.

    Args:
        message: The user's latest message.
        history: Previous turns supplied by ``gr.ChatInterface``; unused, so
            each message is answered independently.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Base sampling temperature; 0.1 is added to it before
            sampling (the default therefore samples at 0.7).

    Returns:
        The generated reply text, without the prompt.
    """
    _, retrieved_doc = retrieve(message, TOP_K)
    formatted_prompt = format_prompt(message, retrieved_doc, TOP_K)
    messages = [
        {"role": "system", "content": SYS_MSG},
        {"role": "user", "content": formatted_prompt},
    ]

    # Generate with the module-level tokenizer/model directly: the name
    # `pipeline` in this module is only the transformers factory function,
    # not a constructed pipeline, so it has no `.tokenizer` to use.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Llama 3.x ends an assistant turn with <|eot_id|>, so stop on it as well
    # as on the tokenizer's generic EOS token.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    ]
    # The rendered chat template already starts with the BOS token, so do not
    # let the tokenizer prepend a second one.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        # Reuse EOS for padding; the tokenizer may not define a pad token.
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=temperature + 0.1,
        top_p=0.9,
    )
    # Keep only the newly generated tokens that follow the prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(output_ids[0][prompt_length:], skip_special_tokens=True)
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# demo = gr.ChatInterface(
# respond,
# textbox=gr.Textbox(placeholder="Enter message here", container=False, scale = 7),
# )
demo = gr.ChatInterface(
chat_function,
textbox=gr.Textbox(placeholder="Enter message here", container=False, scale = 7),
chatbot=gr.Chatbot(height=400),
)
if __name__ == "__main__":
demo.launch(share=True)