lighthouse-emnlp2024 commited on
Commit
6d412a7
·
1 Parent(s): 1d740ea

Fix app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -22
app.py CHANGED
@@ -46,24 +46,10 @@ def load_pretrained_weights():
46
  )
47
  )
48
  for file_url in tqdm(file_urls):
49
- if not os.path.exists("gradio_demo/weights/" + os.path.basename(file_url)):
50
- command = "wget -P gradio_demo/weights/ {}".format(file_url)
51
  subprocess.run(command, shell=True)
52
 
53
- # Slowfast weights
54
- if not os.path.exists("SLOWFAST_8x8_R50.pkl"):
55
- subprocess.run(
56
- "wget https://dl.fbaipublicfiles.com/pyslowfast/model_zoo/kinetics400/SLOWFAST_8x8_R50.pkl",
57
- shell=True,
58
- )
59
-
60
- # PANNs weights
61
- if not os.path.exists("Cnn14_mAP=0.431.pth"):
62
- subprocess.run(
63
- "wget https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth",
64
- shell=True,
65
- )
66
-
67
  return file_urls
68
 
69
 
@@ -79,7 +65,7 @@ Model initialization
79
  """
80
  load_pretrained_weights()
81
  model = CGDETRPredictor(
82
- "gradio_demo/weights/clip_cg_detr_qvhighlight.ckpt",
83
  device=device,
84
  feature_name="clip",
85
  slowfast_path=None,
@@ -147,11 +133,9 @@ def model_load(radio, video):
147
  raise gr.Error("Select from the models")
148
 
149
  model = model_class(
150
- "gradio_demo/weights/{}_{}_qvhighlight.ckpt".format(feature, model_name),
151
  device=device,
152
  feature_name="{}".format(feature),
153
- slowfast_path="SLOWFAST_8x8_R50.pkl",
154
- pann_path="Cnn14_mAP=0.431.pth",
155
  )
156
 
157
  load_finished_msg = "Model loaded: {}".format(radio)
@@ -222,7 +206,7 @@ def predict(textbox, line, gallery):
222
  for i, (second, score) in enumerate(
223
  zip(highlighted_seconds, highlighted_scores)
224
  ):
225
- output_path = "gradio_demo/highlight_frames/highlight_{}.png".format(i)
226
  (
227
  ffmpeg.input(loaded_video_path, ss=second)
228
  .output(output_path, vframes=1, qscale=2)
@@ -341,7 +325,7 @@ def main():
341
  ],
342
  )
343
 
344
- demo.launch()
345
 
346
 
347
  if __name__ == "__main__":
 
46
  )
47
  )
48
  for file_url in tqdm(file_urls):
49
+ if not os.path.exists("weights/" + os.path.basename(file_url)):
50
+ command = "wget -P weights/ {}".format(file_url)
51
  subprocess.run(command, shell=True)
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  return file_urls
54
 
55
 
 
65
  """
66
  load_pretrained_weights()
67
  model = CGDETRPredictor(
68
+ "weights/clip_cg_detr_qvhighlight.ckpt",
69
  device=device,
70
  feature_name="clip",
71
  slowfast_path=None,
 
133
  raise gr.Error("Select from the models")
134
 
135
  model = model_class(
136
+ "weights/{}_{}_qvhighlight.ckpt".format(feature, model_name),
137
  device=device,
138
  feature_name="{}".format(feature),
 
 
139
  )
140
 
141
  load_finished_msg = "Model loaded: {}".format(radio)
 
206
  for i, (second, score) in enumerate(
207
  zip(highlighted_seconds, highlighted_scores)
208
  ):
209
+ output_path = "highlight_frames/highlight_{}.png".format(i)
210
  (
211
  ffmpeg.input(loaded_video_path, ss=second)
212
  .output(output_path, vframes=1, qscale=2)
 
325
  ],
326
  )
327
 
328
+ demo.launch(share=True, server_name="0.0.0.0")
329
 
330
 
331
  if __name__ == "__main__":