VietCat commited on
Commit
2cbdbe8
·
1 Parent(s): 767e943

improve response time

Browse files
Files changed (1) hide show
  1. app.py +26 -18
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import time
3
  import warnings
 
4
  warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
5
 
6
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -8,6 +9,13 @@ import torch
8
  import gradio as gr
9
  import psutil
10
 
 
 
 
 
 
 
 
11
  # Load model and tokenizer
12
  model_id = "NlpHUST/gpt2-vietnamese"
13
  try:
@@ -28,23 +36,20 @@ model.to(device)
28
  model.eval()
29
 
30
  # Print device and memory info for debugging
31
- print(f"---------- Info -----------")
32
  print(f"Device: {device}")
33
  print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
34
-
35
- def print_system_resources():
36
- cpu_percent = psutil.cpu_percent(interval=1)
37
- memory = psutil.virtual_memory()
38
- print(f"CPU usage: {cpu_percent}%")
39
- print(f"Memory usage: {memory.percent}% ({memory.used/1e9:.2f}/{memory.total/1e9:.2f} GB)")
40
-
41
- # Call before generation
42
  print_system_resources()
43
- print(f"--------------------------")
44
 
45
- def generate_text(prompt, max_length=50, temperature=1.0):
 
 
 
 
 
 
46
  try:
47
  start_time = time.time()
 
48
  # Encode input with attention mask
49
  inputs = tokenizer(
50
  prompt,
@@ -58,18 +63,21 @@ def generate_text(prompt, max_length=50, temperature=1.0):
58
  outputs = model.generate(
59
  input_ids=inputs["input_ids"],
60
  attention_mask=inputs["attention_mask"],
61
- max_new_tokens=30, # Limit new tokens to reduce computation
 
62
  temperature=temperature,
63
- do_sample=True,
64
- num_beams=3, # Use beam search for faster generation
65
- no_repeat_ngram_size=2, # Prevent repetitive phrases
66
  pad_token_id=tokenizer.pad_token_id,
67
- early_stopping=True # Stop when generation is complete
68
  )
69
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
70
  elapsed_time = time.time() - start_time
 
71
  print(f"Generation time: {elapsed_time:.2f} seconds")
72
- return generated_text
73
  except Exception as e:
74
  return f"Error generating text: {e}"
75
 
@@ -79,7 +87,7 @@ demo = gr.Interface(
79
  inputs=[
80
  gr.Textbox(label="Nhập văn bản đầu vào", placeholder="Viết gì đó bằng tiếng Việt..."),
81
  gr.Slider(20, 100, value=50, step=10, label="Độ dài tối đa"),
82
- gr.Slider(0.5, 1.5, value=1.0, step=0.1, label="Nhiệt độ (Temperature)")
83
  ],
84
  outputs="text",
85
  title="Sinh văn bản tiếng Việt",
 
1
  import os
2
  import time
3
  import warnings
4
+ import re
5
  warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
6
 
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
9
  import gradio as gr
10
  import psutil
11
 
12
# Print system resources for debugging
def print_system_resources(interval=1):
    """Print current CPU and memory utilization to stdout.

    Parameters
    ----------
    interval : float or None, default 1
        Seconds that ``psutil.cpu_percent`` blocks while sampling CPU
        usage.  Pass 0 or None for a non-blocking reading (usage since
        the previous call).  NOTE(review): the default 1 s sample adds a
        full second of latency per call; callers on the generation hot
        path should pass ``interval=None`` to keep response time down.
    """
    # cpu_percent with a positive interval sleeps for that long to sample.
    cpu_percent = psutil.cpu_percent(interval=interval)
    memory = psutil.virtual_memory()
    print(f"CPU usage: {cpu_percent}%")
    print(f"Memory usage: {memory.percent}% ({memory.used/1e9:.2f}/{memory.total/1e9:.2f} GB)")
18
+
19
  # Load model and tokenizer
20
  model_id = "NlpHUST/gpt2-vietnamese"
21
  try:
 
36
  model.eval()
37
 
38
  # Print device and memory info for debugging
 
39
  print(f"Device: {device}")
40
  print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
 
 
 
 
 
 
 
 
41
  print_system_resources()
 
42
 
43
def clean_text(text):
    """Normalize generated model output for display.

    Drops every character that is not a word character (letters, digits,
    underscore -- Unicode-aware, so Vietnamese letters survive),
    whitespace, or basic punctuation (. , ! ?), then collapses runs of
    whitespace into single spaces and trims the ends.
    """
    without_symbols = re.sub(r'[^\w\s.,!?]', '', text)
    return re.sub(r'\s+', ' ', without_symbols).strip()
48
+
49
+ def generate_text(prompt, max_length=50, temperature=0.7):
50
  try:
51
  start_time = time.time()
52
+ print_system_resources() # Print resources before generation
53
  # Encode input with attention mask
54
  inputs = tokenizer(
55
  prompt,
 
63
  outputs = model.generate(
64
  input_ids=inputs["input_ids"],
65
  attention_mask=inputs["attention_mask"],
66
+ max_new_tokens=30,
67
+ min_length=10, # Ensure minimum output length
68
  temperature=temperature,
69
+ do_sample=False, # Use greedy decoding for consistency
70
+ num_beams=3, # Use beam search for better quality
71
+ no_repeat_ngram_size=2,
72
  pad_token_id=tokenizer.pad_token_id,
73
+ early_stopping=True
74
  )
75
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
76
+ cleaned_text = clean_text(generated_text)
77
  elapsed_time = time.time() - start_time
78
+ print_system_resources()
79
  print(f"Generation time: {elapsed_time:.2f} seconds")
80
+ return cleaned_text
81
  except Exception as e:
82
  return f"Error generating text: {e}"
83
 
 
87
  inputs=[
88
  gr.Textbox(label="Nhập văn bản đầu vào", placeholder="Viết gì đó bằng tiếng Việt..."),
89
  gr.Slider(20, 100, value=50, step=10, label="Độ dài tối đa"),
90
+ gr.Slider(0.5, 1.0, value=0.7, step=0.1, label="Nhiệt độ (Temperature)")
91
  ],
92
  outputs="text",
93
  title="Sinh văn bản tiếng Việt",