Update pages/GPT.py
Browse files- pages/GPT.py +12 -11
pages/GPT.py
CHANGED
@@ -19,17 +19,18 @@ def preprocess_text(text_input, tokenizer):
|
|
19 |
prompt = tokenizer.encode(text_input, return_tensors='pt')
|
20 |
|
21 |
def predict_sentiment(model, prompt, temp, num_generate):
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
|
|
33 |
return result
|
34 |
|
35 |
st.title('Text generation with dreambook')
|
|
|
19 |
prompt = tokenizer.encode(text_input, return_tensors='pt')
|
20 |
|
21 |
def predict_sentiment(model, prompt, temp, num_generate):
|
22 |
+
with torch.inference_mode():
|
23 |
+
result = model.generate(
|
24 |
+
input_ids=prompt,
|
25 |
+
max_length=100,
|
26 |
+
num_beams=5,
|
27 |
+
do_sample=True,
|
28 |
+
temperature=float(temp),
|
29 |
+
top_k=50,
|
30 |
+
top_p=0.6,
|
31 |
+
no_repeat_ngram_size=3,
|
32 |
+
num_return_sequences=num_generate,
|
33 |
+
).cpu().numpy()
|
34 |
return result
|
35 |
|
36 |
st.title('Text generation with dreambook')
|