File size: 1,711 Bytes
4649ee7
 
 
 
 
 
 
 
 
c14636b
4649ee7
c14636b
4649ee7
 
 
 
 
 
fd11e35
 
4649ee7
15f1cce
4649ee7
 
06b4881
4649ee7
 
 
c14636b
 
4649ee7
 
 
 
c14636b
 
4649ee7
 
 
 
02a8a8a
0b48e62
 
 
 
 
 
 
4649ee7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
import torch

from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

# Model setup: use the GPU when torch reports CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the PEFT adapter checkpoint, then move the resulting model to `device`.
model = AutoPeftModelForCausalLM.from_pretrained(
    "Someman/bloomz-560m-fine-tuned-adapters_v2.0"
)
model = model.to(device)

# The tokenizer is loaded from the "bigscience/bloom-560m" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

# Task labels: offered in the UI dropdown and matched against in `generate`.
TITLE_GENERATION = "Title Generation"
SUMMARIZATION = "Summarization"


def generate_output(prompt, input, kwargs):
    """Generate a single-line completion for ``input`` followed by ``prompt``.

    Args:
        prompt: Instruction text appended directly after ``input``.
        input: Source text the instruction refers to.
        kwargs: Mapping of keyword arguments forwarded to ``model.generate``.

    Returns:
        The first line of the decoded new tokens, stripped of surrounding
        whitespace.
    """
    encoded = tokenizer(input + prompt, return_tensors="pt").to(device)
    sequences = model.generate(**encoded, **kwargs)

    # Slice off the first `prompt_length` positions so that only the tokens
    # located after the encoded input are decoded.
    prompt_length = encoded.input_ids.shape[1]
    completions = tokenizer.batch_decode(
        sequences[:, prompt_length:], skip_special_tokens=True
    )

    # Only the text before the first newline of the first sequence is kept.
    first_line, _, _ = completions[0].partition("\n")
    return first_line.strip()


def summarization(input: str):
    """Produce a summary of ``input`` via ``generate_output``.

    Args:
        input: Text to summarize.

    Returns:
        The result of ``generate_output`` for the summary prompt, with
        generation capped at 50 new tokens.
    """
    # NOTE(review): "\\n" is a literal backslash followed by "n", not a
    # newline character -- presumably matches the fine-tuning prompt format;
    # confirm before changing.
    return generate_output(
        " \\nSummary in the same language as the doc:",
        input,
        {"max_new_tokens": 50},
    )


def title_generation(input: str):
    """Produce a title for ``input`` via ``generate_output``.

    Args:
        input: Article text to generate a title for.

    Returns:
        The result of ``generate_output`` for the title prompt, with
        generation capped at 50 new tokens.
    """
    # NOTE(review): each "\\n" is a literal backslash followed by "n", not a
    # newline character -- presumably matches the fine-tuning prompt format;
    # confirm before changing.
    return generate_output(
        "\\n\\nGive me a good title for the article above.",
        input,
        {"max_new_tokens": 50},
    )


def generate(task: str, input: str) -> str:
    """Run the selected task on ``input`` and return the resulting text.

    The input is validated before any model call: it must contain more than
    20 characters once leading and trailing whitespace is removed. Missing
    (``None``), empty, or whitespace-only input is rejected with a hint
    message instead of raising or being sent to the model.

    Args:
        task: Task label; expected to be ``SUMMARIZATION`` or
            ``TITLE_GENERATION``.
        input: Text to process. It is forwarded to the task function
            unmodified once it passes validation.

    Returns:
        The task function's output, ``"Enter something meaningful."`` when
        the input fails validation, or ``"Wow! Very Dangerous!"`` when
        ``task`` is not a known label.
    """
    min_chars = 20

    # Measure the stripped text so padding with blanks cannot satisfy the
    # length requirement; the None check avoids a TypeError from len().
    if input is None or len(input.strip()) <= min_chars:
        return "Enter something meaningful."

    if task == SUMMARIZATION:
        return summarization(input)
    if task == TITLE_GENERATION:
        return title_generation(input)

    # Reached for any task value other than the two known labels.
    return "Wow! Very Dangerous!"


# Two-input UI wired to `generate`: a task selector and a free-text area;
# the function's return value is rendered as text.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Dropdown(
            choices=[SUMMARIZATION, TITLE_GENERATION],
            label="Task",
            info="Will add more task later!",
        ),
        gr.TextArea(),
    ],
    outputs="text",
)
# Launch the interface when this module is executed.
demo.launch()