File size: 3,526 Bytes
a1918ea
 
 
 
 
 
db6559c
a1918ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import torch
from transformers import pipeline

# Hugging Face Hub id of a 152M-parameter Azerbaijani Alpaca-style model.
model_name = "eljanmahammadli/AzLlama-152M-Alpaca"
# Text-generation pipeline, loaded once at import time (module-level side effect).
# NOTE(review): float16 weights assume the runtime device supports half
# precision — confirm this works on the deployment target (CPU fp16 can fail).
model = pipeline("text-generation", model=model_name, torch_dtype=torch.float16)
# Local path of the logo image shown in the UI.
logo_path = "AzLlama-logo.webp"


def get_prompt(question):
    """Wrap *question* in the Alpaca-style Azerbaijani instruction template.

    Returns the full prompt string, ending with the "### Cavab:" (answer)
    header that the model is expected to complete.
    """
    instruction_header = "Aşağıda tapşırığı təsvir edən təlimat və əlavə kontekst təmin edən giriş verilmiştir. Sorğunu uyğun şəkildə tamamlayan cavab yazın."
    return (
        f"{instruction_header}\n"
        "\n"
        "### Təlimat:\n"
        f"{question}\n"
        "\n"
        "### Cavab:\n"
    )


def get_answer(llm_output):
    """Extract the model's answer from the raw generated text.

    The text-generation pipeline echoes the prompt, so everything up to the
    first "### Cavab:" header is discarded.  If the model regenerates the
    header, anything after the second occurrence is dropped too (same
    truncation as the original `split(...)[1]`).

    Fix: when the header is missing from the output entirely, return the
    whole stripped text instead of raising IndexError.
    """
    parts = llm_output.split("### Cavab:")
    if len(parts) > 1:
        return parts[1].strip()
    # Header absent (unexpected model output): degrade gracefully.
    return llm_output.strip()


def answer_question(history, temperature, top_p, repetition_penalty, top_k, question):
    """Generate an answer for *question* and append the exchange to *history*.

    Returns ``(updated_history, "")`` so the bound Gradio outputs refresh the
    chat transcript and clear the question input box.
    """
    generation_kwargs = {
        "temperature": temperature,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "max_length": 512,  # Adjust based on your needs
        "do_sample": True,
    }
    raw_output = model(get_prompt(question), **generation_kwargs)[0]
    answer = get_answer(raw_output["generated_text"])
    print(answer)  # console trace of each generated answer
    separator = "\n\n" if history else ""
    updated_history = f"{history}{separator}USER: {question}\nASSISTANT: {answer}\n"
    # Updated transcript plus empty string to clear the input box.
    return updated_history, ""


def send_action(_=None):
    # NOTE(review): calling `Button.click()` with no arguments registers an
    # event listener rather than programmatically triggering the button, so
    # this handler is likely a no-op — confirm against the pinned Gradio
    # version.  `send_button` is the module-level button defined in the
    # Blocks layout below.
    send_button.click()


# Three-column layout: logo/info | chat transcript | sampling controls.
# NOTE(review): fractional `scale` values are accepted by Gradio 3.x but
# Gradio 4 expects integers — confirm the pinned version before upgrading.
with gr.Blocks() as app:
    gr.Markdown("# AzLlama-150M Chatbot\n\n")

    with gr.Row():
        with gr.Column(scale=0.2, min_width=200):
            gr.Markdown("### Model Logo")
            gr.Image(
                value=logo_path,
            )
            # Short model description shown under the logo.
            gr.Markdown(
                "### Model Info\n"
                "This model is a 150M parameter LLaMA2 model trained from scratch on Azerbaijani text. It can be used to generate text based on the given prompt. "
            )
        with gr.Column(scale=0.6):
            gr.Markdown("### Chat with the Assistant")
            history = gr.Textbox(
                label="Chat History", value="", lines=20, interactive=False
            )
            question = gr.Textbox(
                label="Your question",
                placeholder="Type your question and press enter",
            )
            send_button = gr.Button("Send")
        with gr.Column(scale=0.2, min_width=200):
            gr.Markdown("### Model Settings")
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, label="Temperature"
            )
            gr.Markdown(
                "Controls the randomness of predictions. Lower values make the model more deterministic."
            )
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, label="Top P")
            gr.Markdown(
                "Nucleus sampling. Lower values focus on more likely predictions."
            )
            repetition_penalty = gr.Slider(
                minimum=1.0, maximum=2.0, value=1.2, label="Repetition Penalty"
            )
            gr.Markdown(
                "Penalizes repeated words. Higher values discourage repetition."
            )
            top_k = gr.Slider(minimum=0, maximum=100, value=50, label="Top K")
            gr.Markdown("Keeps only the top k predictions. Set to 0 for no limit.")

    # Event wiring: pressing Enter in the question box and clicking "Send"
    # both run the same generation callback.
    # Fix: Enter was previously routed through `send_action`, which called
    # `send_button.click()` with no arguments — that *registers* a listener
    # instead of triggering the button, so Enter submissions silently did
    # nothing.  Bind the submit event to the callback directly instead.
    event_args = dict(
        fn=answer_question,
        inputs=[history, temperature, top_p, repetition_penalty, top_k, question],
        outputs=[history, question],
    )
    question.submit(**event_args)
    send_button.click(**event_args)

app.launch()