VietCat commited on
Commit
5655ff3
·
1 Parent(s): 19763d1

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +17 -19
app.py CHANGED
@@ -24,6 +24,7 @@ def print_system_resources():
24
  print(f"CPU usage: {cpu_percent}%")
25
  print(f"Memory usage: {min(memory.used / (mem_limit * 1e9) * 100, 100):.1f}% ({memory.used/1e9:.2f}/{mem_limit:.2f} GB)")
26
  print(f"Active processes: {len(psutil.pids())}")
 
27
 
28
  # Print Gradio version for debugging
29
  print(f"Gradio version: {gr.__version__}")
@@ -47,12 +48,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
  model.to(device)
48
  model.eval()
49
 
50
- # Apply quantization to reduce memory and speed up
51
- model = torch.quantization.quantize_dynamic(
52
- model, {torch.nn.Linear}, dtype=torch.qint8
53
- )
54
- print(f"Model quantized: {model.__class__.__name__}")
55
-
56
  # Print device and memory info for debugging
57
  print(f"Device: {device}")
58
  print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
@@ -64,16 +59,18 @@ def clean_text(text):
64
  text = re.sub(r'\s+', ' ', text).strip()
65
  return text
66
 
67
- def generate_text(prompt, temperature=0.3, max_new_tokens=35, top_k=15, max_length=40):
68
  try:
69
  start_time = time.time()
70
  # Debug memory before generation
71
  print("Memory before generation:")
72
  print_system_resources()
73
  # Fixed parameters
74
- repetition_penalty = 1.2
 
 
75
  # Log parameters
76
- print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, repetition_penalty={repetition_penalty}")
77
  # Encode input with attention mask
78
  encode_time = time.time()
79
  inputs = tokenizer(
@@ -85,8 +82,8 @@ def generate_text(prompt, temperature=0.3, max_new_tokens=35, top_k=15, max_leng
85
  ).to(device)
86
  print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
87
  print(f"Input tokens: {len(inputs['input_ids'][0])}")
88
- # Define EOS token IDs for '.', '!', '?'
89
- eos_token_ids = [tokenizer.encode(s)[0] for s in ['.', '!', '?']]
90
  print(f"EOS token IDs: {eos_token_ids}")
91
  # Generate text
92
  gen_time = time.time()
@@ -94,13 +91,14 @@ def generate_text(prompt, temperature=0.3, max_new_tokens=35, top_k=15, max_leng
94
  input_ids=inputs["input_ids"],
95
  attention_mask=inputs["attention_mask"],
96
  max_new_tokens=int(max_new_tokens),
97
- min_length=3,
98
  do_sample=True,
99
  top_k=int(top_k),
 
100
  temperature=temperature,
101
  repetition_penalty=repetition_penalty,
102
  pad_token_id=tokenizer.pad_token_id,
103
- eos_token_id=eos_token_ids
104
  )
105
  gen_duration = time.time() - gen_time
106
  print(f"Generation time: {gen_duration:.2f} seconds")
@@ -130,16 +128,16 @@ demo = gr.Interface(
130
  gr.Textbox(
131
  label="Nhập văn bản đầu vào",
132
  placeholder="Viết gì đó bằng tiếng Việt...",
133
- value="Hôm nay là một ngày đẹp trời" # Default text
134
  ),
135
- gr.Slider(0.1, 0.7, value=0.3, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho tốc độ nhanh và mạch lạc, 0.4-0.7 cho đa dạng hơn)"),
136
- gr.Slider(25, 100, value=35, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho tốc độ nhanh, 75-100 cho câu dài hơn)"),
137
- gr.Slider(5, 30, value=15, step=5, label="Top K (top_k, 5-15 cho văn bản mạch lạc, 20-30 cho đa dạng hơn)"),
138
- gr.Slider(30, 100, value=40, step=5, label="Độ dài tối đa (max_length, 30-50 cho tốc độ nhanh, 75-100 cho câu dài hơn)")
139
  ],
140
  outputs="text",
141
  title="Sinh văn bản tiếng Việt",
142
- description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.1-0.3, max_new_tokens 25-50, top_k 5-15, và max_length 30-50 để đạt thời gian <2 giây và văn bản mạch lạc. Dùng temperature 0.4-0.7, max_new_tokens 75-100, top_k 20-30, max_length 75-100 cho câu dàiđa dạng hơn.",
143
  allow_flagging="never"
144
  )
145
 
 
24
  print(f"CPU usage: {cpu_percent}%")
25
  print(f"Memory usage: {min(memory.used / (mem_limit * 1e9) * 100, 100):.1f}% ({memory.used/1e9:.2f}/{mem_limit:.2f} GB)")
26
  print(f"Active processes: {len(psutil.pids())}")
27
+ print(f"Python memory (torch): {torch.cuda.memory_allocated()/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
28
 
29
  # Print Gradio version for debugging
30
  print(f"Gradio version: {gr.__version__}")
 
48
  model.to(device)
49
  model.eval()
50
 
 
 
 
 
 
 
51
  # Print device and memory info for debugging
52
  print(f"Device: {device}")
53
  print(f"Memory allocated: {torch.cuda.memory_allocated(device)/1e9:.2f} GB" if torch.cuda.is_available() else "CPU only")
 
59
  text = re.sub(r'\s+', ' ', text).strip()
60
  return text
61
 
62
+ def generate_text(prompt, temperature=0.9, max_new_tokens=100, top_k=30, max_length=100):
63
  try:
64
  start_time = time.time()
65
  # Debug memory before generation
66
  print("Memory before generation:")
67
  print_system_resources()
68
  # Fixed parameters
69
+ repetition_penalty = 1.0
70
+ top_p = 0.9
71
+ min_length = 50 # Force longer output
72
  # Log parameters
73
+ print(f"Parameters: max_length={max_length}, temperature={temperature}, max_new_tokens={max_new_tokens}, top_k={top_k}, top_p={top_p}, repetition_penalty={repetition_penalty}, min_length={min_length}")
74
  # Encode input with attention mask
75
  encode_time = time.time()
76
  inputs = tokenizer(
 
82
  ).to(device)
83
  print(f"Encoding time: {time.time() - encode_time:.2f} seconds")
84
  print(f"Input tokens: {len(inputs['input_ids'][0])}")
85
+ # Avoid stopping at '.', '!', '?' to allow longer text
86
+ eos_token_ids = [] # Disable EOS tokens
87
  print(f"EOS token IDs: {eos_token_ids}")
88
  # Generate text
89
  gen_time = time.time()
 
91
  input_ids=inputs["input_ids"],
92
  attention_mask=inputs["attention_mask"],
93
  max_new_tokens=int(max_new_tokens),
94
+ min_length=min_length,
95
  do_sample=True,
96
  top_k=int(top_k),
97
+ top_p=top_p,
98
  temperature=temperature,
99
  repetition_penalty=repetition_penalty,
100
  pad_token_id=tokenizer.pad_token_id,
101
+ eos_token_id=eos_token_ids if eos_token_ids else None
102
  )
103
  gen_duration = time.time() - gen_time
104
  print(f"Generation time: {gen_duration:.2f} seconds")
 
128
  gr.Textbox(
129
  label="Nhập văn bản đầu vào",
130
  placeholder="Viết gì đó bằng tiếng Việt...",
131
+ value="Hôm nay là một ngày đẹp trời, hãy kể một câu chuyện dài về trải nghiệm của bạn." # Updated prompt
132
  ),
133
+ gr.Slider(0.1, 0.9, value=0.9, step=0.1, label="Nhiệt độ (Temperature, 0.1-0.3 cho mạch lạc, 0.7-0.9 cho câu dài và sáng tạo)"),
134
+ gr.Slider(25, 100, value=100, step=5, label="Số token mới tối đa (max_new_tokens, 25-50 cho câu ngắn, 75-100 cho câu dài)"),
135
+ gr.Slider(5, 30, value=30, step=5, label="Top K (top_k, 5-10 cho mạch lạc, 20-30 cho đa dạng và dài hơn)"),
136
+ gr.Slider(30, 100, value=100, step=5, label="Độ dài tối đa (max_length, 30-50 cho câu ngắn, 75-100 cho câu dài)")
137
  ],
138
  outputs="text",
139
  title="Sinh văn bản tiếng Việt",
140
+ description="Dùng mô hình GPT-2 Vietnamese từ NlpHUST để sinh văn bản tiếng Việt. Chọn temperature 0.7-0.9, max_new_tokens 75-100, top_k 20-30, và max_length 75-100 để tạo câu dài và đa dạng. Dùng temperature 0.1-0.3, max_new_tokens 25-50, top_k 5-10, và max_length 30-50 để đạt thời gian <2 giây và văn bản mạch lạc. Lưu ý: Để tạo câu dài, dùng prompt yêu cầu câu chuyện hoặc mô tả chi tiết, thử temperature=0.9, top_k=30, min_length=50.",
141
  allow_flagging="never"
142
  )
143