Tonic committed on
Commit
b75aa1a
·
verified ·
1 Parent(s): 2b25213

add temperature

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -19,18 +19,19 @@ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, us
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
21
 
22
- def generate_text(prompt, max_length, repetition_penalty):
23
  # Tokenize the input and create attention mask
24
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
25
  input_ids = inputs.input_ids
26
  attention_mask = inputs.attention_mask
27
 
28
- # Generate the text using the model, with the attention mask
29
  outputs = model.generate(
30
  input_ids,
31
  attention_mask=attention_mask,
32
  max_length=max_length,
33
  repetition_penalty=repetition_penalty,
 
34
  pad_token_id=tokenizer.eos_token_id
35
  )
36
 
@@ -48,12 +49,13 @@ with gr.Blocks() as demo:
48
  prompt = gr.Textbox(label="Enter your prompt here:", placeholder="Today I planned to ...")
49
  max_length = gr.Slider(label="Max Length", minimum=70, maximum=1000, step=50, value=200)
50
  repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2)
 
51
  submit_button = gr.Button("Generate")
52
 
53
  with gr.Column():
54
  output = gr.Textbox(label="✒️Inkuba.4B:")
55
 
56
- submit_button.click(generate_text, inputs=[prompt, max_length, repetition_penalty], outputs=output)
57
 
58
  # Launch the demo
59
- demo.launch()
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
21
 
22
+ def generate_text(prompt, max_length, repetition_penalty, temperature):
23
  # Tokenize the input and create attention mask
24
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
25
  input_ids = inputs.input_ids
26
  attention_mask = inputs.attention_mask
27
 
28
+ # Generate the text using the model, with the attention mask and temperature
29
  outputs = model.generate(
30
  input_ids,
31
  attention_mask=attention_mask,
32
  max_length=max_length,
33
  repetition_penalty=repetition_penalty,
34
+ temperature=temperature,
35
  pad_token_id=tokenizer.eos_token_id
36
  )
37
 
 
49
  prompt = gr.Textbox(label="Enter your prompt here:", placeholder="Today I planned to ...")
50
  max_length = gr.Slider(label="Max Length", minimum=70, maximum=1000, step=50, value=200)
51
  repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.1, value=1.2)
52
+ temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.5) # Added slider for temperature
53
  submit_button = gr.Button("Generate")
54
 
55
  with gr.Column():
56
  output = gr.Textbox(label="✒️Inkuba.4B:")
57
 
58
+ submit_button.click(generate_text, inputs=[prompt, max_length, repetition_penalty, temperature], outputs=output)
59
 
60
  # Launch the demo
61
+ demo.launch()