JinhuaL1ANG committed on
Commit
c7710a2
·
verified ·
1 Parent(s): ffc11cc

Update src/demo/utils.py

Browse files
Files changed (1) hide show
  1. src/demo/utils.py +19 -8
src/demo/utils.py CHANGED
@@ -44,12 +44,16 @@ def process_audio(model, audio, config):
44
 
45
 
46
  def get_spec_pil(model, audio, config):
47
- fbank, spec_plot = process_audio(model, audio, config)
48
- buf = io.BytesIO()
49
- spec_plot.figure.savefig(buf, format='png')
50
- buf.seek(0)
51
- pil_spec = Image.open(buf)
52
- plt.close()
 
 
 
 
53
  return fbank, pil_spec
54
 
55
 
@@ -74,7 +78,8 @@ def get_mask_region(img):
74
  # Use the channel of opacity as mask
75
  mask = pil_to_tensor(layers[0])[-1,:,:] # RGBA
76
  mask = mask.permute(1, 0) # (F, T) -> (T, F)
77
-
 
78
  mask = (mask > 0).float()
79
 
80
  # Rescale mask to spectrum size
@@ -85,6 +90,8 @@ def get_mask_region(img):
85
  def get_mask_regions(img):
86
  def _prepare_mask(m):
87
  m = m.permute(1, 0)
 
 
88
  m = (m > 0).float()
89
  m = F.interpolate(m.unsqueeze(0).unsqueeze(0), SPEC_RES).squeeze()
90
  return m
@@ -112,7 +119,8 @@ def update_reference_spec(ref_spec_pil_ori, mask_src, dt, df, resize_scale_t, re
112
  if mask_src is not None:
113
  mask_ref = get_edit_mask(
114
  mask_src, dx=df, dy=dt,
115
- resize_scale_x=resize_scale_f, resize_scale_y=resize_scale_t,
 
116
  )
117
  mask_ref = mask_ref.float() # match the PIL format, channel last
118
  mask_ref_pil = F.interpolate(mask_ref.unsqueeze(0).unsqueeze(0), DESPLAY_RES).squeeze()
@@ -121,6 +129,9 @@ def update_reference_spec(ref_spec_pil_ori, mask_src, dt, df, resize_scale_t, re
121
  if mask_ref_pil.ndim > 2:
122
  mask_ref_pil = mask_ref_pil.squeeze()
123
  mask_ref_pil = mask_ref_pil.permute(1, 0)
 
 
 
124
 
125
  # Convert to PIL
126
  mask_ref_pil = to_pil_image(mask_ref_pil).convert("L")
 
44
 
45
 
46
  def get_spec_pil(model, audio, config):
47
+ try:
48
+ fbank, spec_plot = process_audio(model, audio, config)
49
+ buf = io.BytesIO()
50
+ spec_plot.figure.savefig(buf, format='png')
51
+ buf.seek(0)
52
+ pil_spec = Image.open(buf)
53
+ plt.close()
54
+ except:
55
+ print("Warning: the streaming is not ready. Please repeate uploading again.")
56
+ fbank, pil_spec = None, None
57
  return fbank, pil_spec
58
 
59
 
 
78
  # Use the channel of opacity as mask
79
  mask = pil_to_tensor(layers[0])[-1,:,:] # RGBA
80
  mask = mask.permute(1, 0) # (F, T) -> (T, F)
81
+ # Flip the freq axis so the origin point is at the top left
82
+ mask = mask.flip(1)
83
  mask = (mask > 0).float()
84
 
85
  # Rescale mask to spectrum size
 
90
  def get_mask_regions(img):
91
  def _prepare_mask(m):
92
  m = m.permute(1, 0)
93
+ # Flip the freq axis so the origin point is at the top left
94
+ m = m.flip(1)
95
  m = (m > 0).float()
96
  m = F.interpolate(m.unsqueeze(0).unsqueeze(0), SPEC_RES).squeeze()
97
  return m
 
119
  if mask_src is not None:
120
  mask_ref = get_edit_mask(
121
  mask_src, dx=df, dy=dt,
122
+ resize_scale_x=resize_scale_f,
123
+ resize_scale_y=resize_scale_t,
124
  )
125
  mask_ref = mask_ref.float() # match the PIL format, channel last
126
  mask_ref_pil = F.interpolate(mask_ref.unsqueeze(0).unsqueeze(0), DESPLAY_RES).squeeze()
 
129
  if mask_ref_pil.ndim > 2:
130
  mask_ref_pil = mask_ref_pil.squeeze()
131
  mask_ref_pil = mask_ref_pil.permute(1, 0)
132
+ # De-flip freq axis to match pil imshow style
133
+ mask_ref_pil = mask_ref_pil.flip(0)
134
+ mask_ref_pil = mask_ref_pil * 0.5 # for transparency
135
 
136
  # Convert to PIL
137
  mask_ref_pil = to_pil_image(mask_ref_pil).convert("L")