VietCat committed on
Commit
c0af0d6
·
1 Parent(s): 937d274

adjust generation time

Browse files
Files changed (1) hide show
  1. app.py +15 -14
app.py CHANGED
@@ -52,6 +52,7 @@ model.eval()
52
  model = torch.quantization.quantize_dynamic(
53
  model, {torch.nn.Linear}, dtype=torch.qint8
54
  )
 
55
 
56
  # Print device and memory info for debugging
57
  print(f"Device: {device}")
@@ -62,18 +63,20 @@ def clean_text(text):
62
  """Clean generated text by removing non-alphabetic characters and incomplete sentences."""
63
  text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
64
  text = re.sub(r'\s+', ' ', text).strip()
65
- # Keep only complete sentences (ending with punctuation)
66
  sentences = re.split(r'(?<=[.!?])\s+', text)
67
  complete_sentences = [s for s in sentences if re.search(r'[.!?]$', s)]
68
  if complete_sentences:
69
- text = ' '.join(complete_sentences)
70
- else:
71
- # Fallback: Keep until last valid word
72
- words = text.split()
73
- text = ' '.join(words[:-1]) if len(words) > 1 else text
74
- return text
 
 
75
 
76
- def generate_text(prompt, max_length=40, temperature=1.0):
77
  try:
78
  start_time = time.time()
79
  print_system_resources()
@@ -91,11 +94,10 @@ def generate_text(prompt, max_length=40, temperature=1.0):
91
  outputs = model.generate(
92
  input_ids=inputs["input_ids"],
93
  attention_mask=inputs["attention_mask"],
94
- max_new_tokens=10, # Reduce for speed
95
- min_length=5,
96
  do_sample=True,
97
- top_k=20, # Reduce for speed
98
- top_p=0.85, # Prioritize high-quality tokens
99
  temperature=temperature,
100
  pad_token_id=tokenizer.pad_token_id
101
  )
@@ -107,7 +109,6 @@ def generate_text(prompt, max_length=40, temperature=1.0):
107
  print(f"Generation time: {elapsed_time:.2f} seconds")
108
  # Clear memory cache
109
  gc.collect()
110
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
111
  return cleaned_text
112
  except Exception as e:
113
  return f"Error generating text: {e}"
@@ -121,7 +122,7 @@ demo = gr.Interface(
121
  placeholder="Viết gì đó bằng tiếng Việt...",
122
  value="Hôm nay là một ngày đẹp trời" # Default text
123
  ),
124
- gr.Slider(20, 100, value=40, step=10, label="Độ dài tối đa"),
125
  gr.Slider(0.7, 1.0, value=1.0, step=0.1, label="Nhiệt độ (Temperature)")
126
  ],
127
  outputs="text",
 
52
  model = torch.quantization.quantize_dynamic(
53
  model, {torch.nn.Linear}, dtype=torch.qint8
54
  )
55
+ print(f"Model quantized: {model.__class__.__name__}")
56
 
57
  # Print device and memory info for debugging
58
  print(f"Device: {device}")
 
63
  """Clean generated text by removing non-alphabetic characters and incomplete sentences."""
64
  text = re.sub(r'[^\w\s.,!?àáâãèéêìíòóôõùúýăđĩũơưạảấầẩẫậắằẳẵặẹẻẽếềểễệỉịọỏốồổỗộớờởỡợụủứừửữựỳỵỷỹ]', '', text)
65
  text = re.sub(r'\s+', ' ', text).strip()
66
+ # Keep only complete sentences (ending with .!?)
67
  sentences = re.split(r'(?<=[.!?])\s+', text)
68
  complete_sentences = [s for s in sentences if re.search(r'[.!?]$', s)]
69
  if complete_sentences:
70
+ return ' '.join(complete_sentences)
71
+ # Fallback: Cut at last comma if present
72
+ last_comma = text.rfind(',')
73
+ if last_comma != -1:
74
+ return text[:last_comma].strip()
75
+ # Fallback: Cut at last valid word
76
+ words = text.split()
77
+ return ' '.join(words[:-1]) if len(words) > 1 else text
78
 
79
+ def generate_text(prompt, max_length=30, temperature=1.0):
80
  try:
81
  start_time = time.time()
82
  print_system_resources()
 
94
  outputs = model.generate(
95
  input_ids=inputs["input_ids"],
96
  attention_mask=inputs["attention_mask"],
97
+ max_new_tokens=8, # Reduce for speed
98
+ min_length=3,
99
  do_sample=True,
100
+ top_k=15, # Reduce for speed
 
101
  temperature=temperature,
102
  pad_token_id=tokenizer.pad_token_id
103
  )
 
109
  print(f"Generation time: {elapsed_time:.2f} seconds")
110
  # Clear memory cache
111
  gc.collect()
 
112
  return cleaned_text
113
  except Exception as e:
114
  return f"Error generating text: {e}"
 
122
  placeholder="Viết gì đó bằng tiếng Việt...",
123
  value="Hôm nay là một ngày đẹp trời" # Default text
124
  ),
125
+ gr.Slider(20, 100, value=30, step=10, label="Độ dài tối đa"),
126
  gr.Slider(0.7, 1.0, value=1.0, step=0.1, label="Nhiệt độ (Temperature)")
127
  ],
128
  outputs="text",