from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
from datasets import Dataset, Features, Value, Sequence
import faiss  # required at runtime by datasets' add_faiss_index
import os
import torch
import gradio as gr

ST_MODEL = "LazarusNLP/all-indo-e5-small-v4"
BASE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"
DOMAIN_DATA_DIR = "./data"
CACHE_DIR = "./cache"
# System prompt, kept in Indonesian since the assistant answers in Indonesian.
# English gloss: "You are an assistant at an electricity provider (PLN) who helps
# answer questions about sexual harassment in Indonesian. Answer briefly, using
# the context, in Indonesian."
SYS_MSG = """
Kamu adalah asisten dalam sebuah perusahaan penyedia listrik (PLN) yang membantu menjawab pertanyaan seputar 'sexual harassment' dalam Bahasa Indonesia.
Jawab dengan singkat menggunakan konteks untuk menjawab pertanyaan dalam Bahasa Indonesia.
"""

# Hugging Face Hub authentication.
# The API token is read from an environment variable (e.g. a Spaces secret).
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
login(token=hf_token)

# ----------------------------------------------------------------------------------------------------------
# RAG PROCESS
TOP_K = 1  # number of retrieved chunks passed to the LLM as context

# Load every .txt file in the data directory as LangChain documents.
domain_data = [os.path.join(DOMAIN_DATA_DIR, f) for f in os.listdir(DOMAIN_DATA_DIR) if f.endswith('.txt')]
pages = []

for file in domain_data:
    text_loader = TextLoader(file)
    file_pages = text_loader.load()
    pages.extend(file_pages)

# Split the documents into overlapping chunks on paragraph boundaries.
# Note: with "\n\n" as the only separator the splitter cannot recurse to finer
# separators, so a single paragraph longer than chunk_size stays oversized.
splitter = RecursiveCharacterTextSplitter(
    chunk_size=300,
    chunk_overlap=64,
    separators=["\n\n"]
)

documents = splitter.split_documents(pages)
content = [doc.page_content.strip() for doc in documents]

ST = SentenceTransformer(ST_MODEL)
embeddings = ST.encode(content)
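# `embeddings` is a 2-D numpy array of shape (num_chunks, embedding_dim).
# The dimension is presumably 384, judging by the "small" E5 architecture the
# model name suggests (an assumption, not verified here).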

features = Features({
    'text': Value('string'),
    'embeddings': Sequence(Value('float32'))
})

data = {'text': content, 'embeddings': [embedding.tolist() for embedding in embeddings]}
dataset = Dataset.from_dict(data, features=features)

# Build an in-memory FAISS index over the embedding column for nearest-neighbour search.
dataset.add_faiss_index(column='embeddings')

def retrieve(query, top_k=3):
    """Embed the query and return the scores and texts of the top_k nearest chunks."""
    query_embedding = ST.encode(query)  # 1-D vector, as get_nearest_examples expects
    scores, retrieved_examples = dataset.get_nearest_examples('embeddings', query_embedding, k=top_k)
    return scores, retrieved_examples['text']
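
# A minimal smoke test for the retriever (hypothetical query; uncomment to run):
# scores, docs = retrieve("Apa yang dimaksud dengan pelecehan seksual?", top_k=2)
# for score, doc in zip(scores, docs):
#     print(f"{score:.3f}  {doc[:80]}")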

# END RAG
# ----------------------------------------------------------------------------------------------------------

# LLM
# Use 4-bit quantization to lower GPU memory usage.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # store weights in 4-bit
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",             # 4-bit NormalFloat data type
    bnb_4bit_compute_dtype=torch.bfloat16  # dtype used for computation
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, cache_dir=CACHE_DIR)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    # device_map="auto",
    quantization_config=bnb_config,
    cache_dir=CACHE_DIR
)

# The generation calls below expect a pipeline object, but the script only imported
# the `pipeline` factory function; wrap the quantized model in one here:
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)

def format_prompt(prompt, retrieved_documents, k):
    """Build the user prompt: the question followed by the top-k retrieved chunks as context."""
    PROMPT = f"Pertanyaan:{prompt}\nKonteks:"
    for idx in range(k):
        PROMPT += f"{retrieved_documents[idx]}\n"
    return PROMPT
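
# For illustration, with TOP_K == 1 the formatted prompt has the shape:
#   Pertanyaan:<user question>
#   Konteks:<retrieved chunk>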

def chat_function(message, history, max_new_tokens=256, temperature=0.6):
    # `history` is required by gr.ChatInterface but unused here: context comes
    # from retrieval, not from the conversation so far.
    _, retrieved_docs = retrieve(message, TOP_K)
    formatted_prompt = format_prompt(message, retrieved_docs, TOP_K)

    messages = [{"role": "system", "content": SYS_MSG},
                {"role": "user", "content": formatted_prompt}]
    prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True)
    print(f"Prompt: {prompt}\n")
    # Llama 3 models also emit <|eot_id|> to end a turn, so stop on either token.
    terminators = [
        pipe.tokenizer.eos_token_id,
        pipe.tokenizer.convert_tokens_to_ids("<|eot_id|>")]
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        do_sample=True,
        temperature=temperature,
        top_p=0.9)
    # The pipeline echoes the prompt; return only the newly generated text.
    return outputs[0]["generated_text"][len(prompt):]
    
demo = gr.ChatInterface(
    chat_function,
    textbox=gr.Textbox(placeholder="Enter message here", container=False, scale=7),
    chatbot=gr.Chatbot(height=400),
)
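
# Note: share=True makes Gradio create a temporary public link in addition to the
# local server; it can be dropped when the host already exposes the app (e.g. a
# Hugging Face Space).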

if __name__ == "__main__":
    demo.launch(share=True)