VietCat committed on
Commit
937d274
·
1 Parent(s): abe4cc3

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import time
3
  import warnings
4
  import re
 
5
  warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
6
 
7
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
@@ -16,9 +17,9 @@ def print_system_resources():
16
  # Get container memory limit (for Docker)
17
  try:
18
  with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
19
- mem_limit = int(f.read().strip()) / 1e9 # Convert to GB
20
  except:
21
- mem_limit = 16.0 # Fallback for HFS free (16GB)
22
  print(f"Total physical memory (psutil): {memory.total/1e9:.2f} GB")
23
  print(f"Container memory limit: {mem_limit:.2f} GB")
24
  print(f"CPU usage: {cpu_percent}%")
@@ -67,12 +68,12 @@ def clean_text(text):
67
  if complete_sentences:
68
  text = ' '.join(complete_sentences)
69
  else:
70
- # Fallback: Keep until last valid word if no complete sentence
71
  words = text.split()
72
  text = ' '.join(words[:-1]) if len(words) > 1 else text
73
  return text
74
 
75
- def generate_text(prompt, max_length=50, temperature=0.9):
76
  try:
77
  start_time = time.time()
78
  print_system_resources()
@@ -84,18 +85,18 @@ def generate_text(prompt, max_length=50, temperature=0.9):
84
  truncation=True,
85
  max_length=max_length
86
  ).to(device)
 
87
 
88
  # Generate text
89
  outputs = model.generate(
90
  input_ids=inputs["input_ids"],
91
  attention_mask=inputs["attention_mask"],
92
- max_new_tokens=15, # Reduce for speed
93
- min_length=10,
94
  do_sample=True,
95
- top_k=30, # Reduce for speed
96
- top_p=0.9,
97
  temperature=temperature,
98
- no_repeat_ngram_size=2,
99
  pad_token_id=tokenizer.pad_token_id
100
  )
101
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -104,6 +105,9 @@ def generate_text(prompt, max_length=50, temperature=0.9):
104
  cleaned_text = clean_text(generated_text)
105
  elapsed_time = time.time() - start_time
106
  print(f"Generation time: {elapsed_time:.2f} seconds")
 
 
 
107
  return cleaned_text
108
  except Exception as e:
109
  return f"Error generating text: {e}"
@@ -117,8 +121,8 @@ demo = gr.Interface(
117
  placeholder="Viết gì đó bằng tiếng Việt...",
118
  value="Hôm nay là một ngày đẹp trời" # Default text
119
  ),
120
- gr.Slider(20, 100, value=50, step=10, label="Độ dài tối đa"),
121
- gr.Slider(0.7, 1.0, value=0.9, step=0.1, label="Nhiệt độ (Temperature)")
122
  ],
123
  outputs="text",
124
  title="Sinh văn bản tiếng Việt",
 
2
  import time
3
  import warnings
4
  import re
5
+ import gc
6
  warnings.filterwarnings("ignore", category=UserWarning, module="torch._utils")
7
 
8
  from transformers import GPT2LMHeadModel, GPT2Tokenizer
 
17
  # Get container memory limit (for Docker)
18
  try:
19
  with open('/sys/fs/cgroup/memory/memory.limit_in_bytes', 'r') as f:
20
+ mem_limit = min(int(f.read().strip()) / 1e9, 16.0) # Cap at 16GB for HFS free
21
  except:
22
+ mem_limit = 16.0 # Fallback for HFS free
23
  print(f"Total physical memory (psutil): {memory.total/1e9:.2f} GB")
24
  print(f"Container memory limit: {mem_limit:.2f} GB")
25
  print(f"CPU usage: {cpu_percent}%")
 
68
  if complete_sentences:
69
  text = ' '.join(complete_sentences)
70
  else:
71
+ # Fallback: Keep until last valid word
72
  words = text.split()
73
  text = ' '.join(words[:-1]) if len(words) > 1 else text
74
  return text
75
 
76
+ def generate_text(prompt, max_length=40, temperature=1.0):
77
  try:
78
  start_time = time.time()
79
  print_system_resources()
 
85
  truncation=True,
86
  max_length=max_length
87
  ).to(device)
88
+ print(f"Input tokens: {len(inputs['input_ids'][0])}")
89
 
90
  # Generate text
91
  outputs = model.generate(
92
  input_ids=inputs["input_ids"],
93
  attention_mask=inputs["attention_mask"],
94
+ max_new_tokens=10, # Reduce for speed
95
+ min_length=5,
96
  do_sample=True,
97
+ top_k=20, # Reduce for speed
98
+ top_p=0.85, # Prioritize high-quality tokens
99
  temperature=temperature,
 
100
  pad_token_id=tokenizer.pad_token_id
101
  )
102
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
105
  cleaned_text = clean_text(generated_text)
106
  elapsed_time = time.time() - start_time
107
  print(f"Generation time: {elapsed_time:.2f} seconds")
108
+ # Clear memory cache
109
+ gc.collect()
110
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
111
  return cleaned_text
112
  except Exception as e:
113
  return f"Error generating text: {e}"
 
121
  placeholder="Viết gì đó bằng tiếng Việt...",
122
  value="Hôm nay là một ngày đẹp trời" # Default text
123
  ),
124
+ gr.Slider(20, 100, value=40, step=10, label="Độ dài tối đa"),
125
+ gr.Slider(0.7, 1.0, value=1.0, step=0.1, label="Nhiệt độ (Temperature)")
126
  ],
127
  outputs="text",
128
  title="Sinh văn bản tiếng Việt",