thunnai commited on
Commit
a223259
·
1 Parent(s): 07bc70a

Refactor model initialization and GPU handling in generate function

Browse files

- Simplify model initialization logic
- Ensure model is initialized with correct device
- Remove redundant GPU logging
- Improve device selection for model loading

Files changed (1) hide show
  1. webui.py +9 -5
webui.py CHANGED
@@ -50,14 +50,15 @@ def generate(text,
50
  """Generate audio from text."""
51
 
52
  global MODEL
53
- model = MODEL
54
 
55
- if model is None:
56
- raise RuntimeError("Model not initialized. Please ensure the model is loaded before generating audio.")
 
 
 
57
 
58
  # if gpu available, move model to gpu
59
  if torch.cuda.is_available():
60
- print("Moving model to GPU")
61
  model = model.to("cuda")
62
 
63
  with torch.no_grad():
@@ -117,9 +118,12 @@ def build_ui(model_dir, device=0):
117
 
118
  global MODEL
119
 
120
- # Initialize model
 
121
  if MODEL is None:
122
  MODEL = initialize_model(model_dir, device=device)
 
 
123
 
124
  # Define callback function for voice cloning
125
  def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):
 
50
  """Generate audio from text."""
51
 
52
  global MODEL
 
53
 
54
+ # Initialize model if not already done
55
+ if MODEL is None:
56
+ MODEL = initialize_model(device="cuda" if torch.cuda.is_available() else "cpu")
57
+
58
+ model = MODEL
59
 
60
  # if gpu available, move model to gpu
61
  if torch.cuda.is_available():
 
62
  model = model.to("cuda")
63
 
64
  with torch.no_grad():
 
118
 
119
  global MODEL
120
 
121
+ # Initialize model with proper device handling
122
+ device = "cuda" if torch.cuda.is_available() and device != "cpu" else "cpu"
123
  if MODEL is None:
124
  MODEL = initialize_model(model_dir, device=device)
125
+ if device == "cuda":
126
+ MODEL = MODEL.to(device)
127
 
128
  # Define callback function for voice cloning
129
  def voice_clone(text, prompt_text, prompt_wav_upload, prompt_wav_record):