rahulshah63 commited on
Commit
9203178
·
1 Parent(s): aa3d0ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -7
app.py CHANGED
@@ -7,19 +7,29 @@ import matplotlib.pyplot as plt
7
  device="cpu"
8
  bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
9
  processor = bundle.get_text_processor()
10
- tacotron2 = bundle.get_tacotron2().to(device)
 
 
 
 
 
 
11
 
12
  # Load Weights and bias of nepali text
13
  checkpoint_path = os.path.join(os.getcwd(), 'model_E45.ckpt')
14
- checkpoint = torch.load(checkpoint_path, map_location=device)
 
 
 
 
15
 
16
- tts_state_dict_stage1 = {key.replace("conv.", ""): value for key, value in checkpoint.items()}
17
- tts_state_dict_stage2 = {key.replace("linear_layer.", ""): value for key, value in tts_state_dict_stage1.items()}
18
 
19
- # decoder.attention_layer.location_layer.location_weight --> decoder.attention_layer.location_layer.location_conv.weight
20
- tts_state_dict_stage3 = {key.replace("location_weight", "location_conv.weight"): value for key, value in tts_state_dict_stage2.items()}
21
 
22
- tacotron2.load_state_dict(tts_state_dict_stage3)
23
 
24
  # Workaround to load model mapped on GPU
25
  # https://stackoverflow.com/a/61840832
 
7
  device="cpu"
8
  bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
9
  processor = bundle.get_text_processor()
10
+ # tacotron2 = bundle.get_tacotron2().to(device)
11
+
12
+ tacotron2 = torch.hub.load(
13
+ "NVIDIA/DeepLearningExamples:torchhub",
14
+ "nvidia_tacotron2",
15
+ pretrained=False,
16
+ )
17
 
18
  # Load Weights and bias of nepali text
19
  checkpoint_path = os.path.join(os.getcwd(), 'model_E45.ckpt')
20
+ state_dict = torch.load(checkpoint_path, map_location=device)
21
+
22
+ tacotron2.load_state_dict(state_dict)
23
+ tacotron2 = tacotron2.to(device)
24
+ tacotron2.eval()
25
 
26
+ # tts_state_dict_stage1 = {key.replace("conv.", ""): value for key, value in checkpoint.items()}
27
+ # tts_state_dict_stage2 = {key.replace("linear_layer.", ""): value for key, value in tts_state_dict_stage1.items()}
28
 
29
+ # # decoder.attention_layer.location_layer.location_weight --> decoder.attention_layer.location_layer.location_conv.weight
30
+ # tts_state_dict_stage3 = {key.replace("location_weight", "location_conv.weight"): value for key, value in tts_state_dict_stage2.items()}
31
 
32
+ # tacotron2.load_state_dict(tts_state_dict_stage3)
33
 
34
  # Workaround to load model mapped on GPU
35
  # https://stackoverflow.com/a/61840832