Spaces:
Running
Running
add temperature
Browse files
app.py
CHANGED
@@ -19,18 +19,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, us
|
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
model.to(device)
|
21 |
|
22 |
-
def generate_text(prompt, max_length, repetition_penalty):
|
23 |
# Tokenize the input and create attention mask
|
24 |
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
25 |
input_ids = inputs.input_ids
|
26 |
attention_mask = inputs.attention_mask
|
27 |
|
28 |
-
# Generate the text using the model, with the attention mask
|
29 |
outputs = model.generate(
|
30 |
input_ids,
|
31 |
attention_mask=attention_mask,
|
32 |
max_length=max_length,
|
33 |
repetition_penalty=repetition_penalty,
|
|
|
34 |
pad_token_id=tokenizer.eos_token_id
|
35 |
)
|
36 |
|
@@ -48,12 +49,13 @@ with gr.Blocks() as demo:
|
|
48 |
prompt = gr.Textbox(label="Enter your prompt here:", placeholder="Today I planned to ...")
|
49 |
max_length = gr.Slider(label="Max Length", minimum=70, maximum=1000, step=50, value=200)
|
50 |
repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2)
|
|
|
51 |
submit_button = gr.Button("Generate")
|
52 |
|
53 |
with gr.Column():
|
54 |
output = gr.Textbox(label="✒️Inkuba.4B:")
|
55 |
|
56 |
-
submit_button.click(generate_text, inputs=[prompt, max_length, repetition_penalty], outputs=output)
|
57 |
|
58 |
# Launch the demo
|
59 |
-
demo.launch()
|
|
|
19 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
20 |
model.to(device)
|
21 |
|
22 |
+
def generate_text(prompt, max_length, repetition_penalty, temperature):
|
23 |
# Tokenize the input and create attention mask
|
24 |
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
25 |
input_ids = inputs.input_ids
|
26 |
attention_mask = inputs.attention_mask
|
27 |
|
28 |
+
# Generate the text using the model, with the attention mask and temperature
|
29 |
outputs = model.generate(
|
30 |
input_ids,
|
31 |
attention_mask=attention_mask,
|
32 |
max_length=max_length,
|
33 |
repetition_penalty=repetition_penalty,
|
34 |
+
temperature=temperature,
|
35 |
pad_token_id=tokenizer.eos_token_id
|
36 |
)
|
37 |
|
|
|
49 |
prompt = gr.Textbox(label="Enter your prompt here:", placeholder="Today I planned to ...")
|
50 |
max_length = gr.Slider(label="Max Length", minimum=70, maximum=1000, step=50, value=200)
|
51 |
repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2)
|
52 |
+
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.5) # Added slider for temperature
|
53 |
submit_button = gr.Button("Generate")
|
54 |
|
55 |
with gr.Column():
|
56 |
output = gr.Textbox(label="✒️Inkuba.4B:")
|
57 |
|
58 |
+
submit_button.click(generate_text, inputs=[prompt, max_length, repetition_penalty, temperature], outputs=output)
|
59 |
|
60 |
# Launch the demo
|
61 |
+
demo.launch()
|