File size: 5,818 Bytes
20593ca
f1fd55f
 
 
20593ca
7de892f
de2aa9b
3cd966d
6ce3893
caa1229
85d769e
6ce3893
b2c8d9d
bedd188
20593ca
 
 
 
b2c8d9d
 
 
6ce3893
ba4d1a9
6b518fa
20593ca
d82698b
 
cefe7bc
f7ecb0c
 
d82698b
03eb2f0
f7ecb0c
20593ca
03eb2f0
20593ca
a63ebbb
a3504d6
5e5187b
 
a63ebbb
6ce3893
20593ca
 
6ce3893
85d769e
6ce3893
dded11c
a63ebbb
20593ca
 
a63ebbb
20593ca
 
 
dded11c
20593ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15a2bfc
20593ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7119068
20593ca
 
f1fd55f
20593ca
 
 
 
 
 
a63ebbb
20593ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4ae210
20593ca
 
a63ebbb
20593ca
 
 
 
 
 
 
 
ce6434e
a63ebbb
20593ca
927e8af
339d0f6
 
 
927e8af
a3504d6
a63ebbb
20593ca
 
 
 
 
 
a63ebbb
 
 
 
a3504d6
a63ebbb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# NOTE: WebM output is still being tested.
# Known issue: WebM export is much slower than MP4 export. The slowdown appears
# to come from remover.process(img, type='rgba'); the cause is not yet known.
# remover.process(img, type=color), which the MP4 path uses, is still fast.
# Help diagnosing this is welcome.

import spaces
import gradio as gr
import cv2
import numpy as np
import time
import random
from PIL import Image
import torch
import re
import os
import shutil
import subprocess
import tempfile

# Replace torch.jit.script with an identity function, so calling it (or using
# it as a decorator) simply returns its argument unchanged. The patch is
# applied ahead of the transparent_background import that follows.
# NOTE(review): presumably this keeps transparent_background from compiling
# its modules with TorchScript -- confirm why the patch is needed.
torch.jit.script = lambda f: f

from transparent_background import Remover

# Value passed as `duration` to `spaces.GPU` below. The in-loop deadline is
# derived from the same constant so the two can no longer drift apart (the
# guard used to be hard-coded to 20 minutes while the decorator asked for 90 s).
_GPU_DURATION_SECONDS = 90
# Stop reading frames this many seconds before the requested duration elapses
# so that whatever has been processed so far can still be returned.
_TIMEOUT_MARGIN_SECONDS = 5

# Matches "rgb(r, g, b)" and "rgba(r, g, b, a)" with flexible whitespace; only
# the three colour channels are captured.
_CSS_RGB_RE = re.compile(r'rgba?\(\s*([\d.]+)\s*,\s*([\d.]+)\s*,\s*([\d.]+)')


def _parse_color(color):
    """Convert a colour-picker value into the "[r, g, b]" string form.

    Accepts "#RRGGBB", "#RGB", "rgb(...)" and "rgba(...)" (the alpha channel
    is ignored). Any other value is returned unchanged so that it is passed
    to `Remover.process` exactly as the caller supplied it.

    Raises:
        ValueError: if a value starting with '#' does not contain valid
            hexadecimal digits.
    """
    text = str(color)
    if text.startswith('#'):
        digits = text.lstrip('#')
        if len(digits) == 3:
            # Expand the CSS shorthand "#0F0" to "00FF00".
            digits = ''.join(ch * 2 for ch in digits)
        return str([int(digits[i:i + 2], 16) for i in (0, 2, 4)])
    match = _CSS_RGB_RE.match(text)
    if match:
        return str([int(float(channel)) for channel in match.groups()])
    return color


def _iter_frames(cap, total_frames, deadline, progress):
    """Yield frames of `cap` as RGB PIL images until EOF or `deadline`.

    Progress is reported before each frame is yielded. When the container
    does not report a usable frame count, progress is reported as a step
    count with an unknown total instead of dividing by zero.
    """
    processed_frames = 0
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        if time.time() >= deadline:
            print("GPU Timeout is coming")
            break

        processed_frames += 1
        print(f"Processing frame {processed_frames}")
        if total_frames > 0:
            progress(
                min(processed_frames / total_frames, 1.0),
                desc=f"Processing frame {processed_frames}/{total_frames}",
            )
        else:
            progress((processed_frames, None), desc=f"Processing frame {processed_frames}")

        yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).convert('RGB')


def _write_mp4(frames, remover, color, fps, mp4_path):
    """Composite every frame onto `color` and write the result to `mp4_path`."""
    writer = None
    try:
        for img in frames:
            if writer is None:
                # The writer is created lazily because the frame size is only
                # known once the first frame has been decoded.
                writer = cv2.VideoWriter(mp4_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, img.size)

            out = remover.process(img, type=color)
            writer.write(cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR))
    finally:
        if writer is not None:
            writer.release()

    if writer is None:
        raise gr.Error("No frames could be processed from the input video.")
    return mp4_path


def _write_webm(frames, remover, video, fps, webm_path, tmpname):
    """Render RGBA PNG frames and encode them (plus source audio) to WebM."""
    temp_dir = tempfile.mkdtemp(prefix=f"tb_{tmpname}_")
    try:
        frame_idx = 0
        for frame_idx, img in enumerate(frames, start=1):
            # NOTE(review): the file's author reports this call is far slower
            # than `type=color`; presumably the 'rgba' mode does extra work
            # inside transparent_background -- confirm against the library.
            out = remover.process(img, type='rgba')
            out = out.convert('RGBA')
            out.save(os.path.join(temp_dir, f"frame_{frame_idx:06d}.png"), 'PNG')

        if frame_idx == 0:
            raise gr.Error("No frames could be processed from the input video.")

        # Pass the real (possibly fractional) frame rate: rounding 29.97 to 30
        # makes the video drift against the copied audio track.
        fr_str = f"{fps:.6f}"
        pattern = os.path.join(temp_dir, "frame_%06d.png")
        ffmpeg_cmd = [
            "ffmpeg", "-y",
            "-framerate", fr_str,
            "-i", pattern,
            "-i", str(video),
            "-map", "0:v",
            "-map", "1:a?",
            "-c:v", "libvpx-vp9",
            "-pix_fmt", "yuva420p",
            "-auto-alt-ref", "0",
            "-metadata:s:v:0", "alpha_mode=1",
            "-c:a", "libopus",
            "-shortest",
            webm_path
        ]
        print("Running ffmpeg:", " ".join(ffmpeg_cmd))
        try:
            subprocess.run(ffmpeg_cmd, check=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # Surface the failure instead of returning a path to a file that
            # is missing or only partially written.
            print("ffmpeg failed:", e)
            raise gr.Error("ffmpeg failed to encode the WebM output.") from e
        return webm_path
    except Exception as e:
        print("Error during processing:", e)
        raise
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


@spaces.GPU(duration=_GPU_DURATION_SECONDS)
def doo(video, color, mode, out_format, progress=gr.Progress()):
    """Remove the background from every frame of `video`.

    Args:
        video: Path of the input video file.
        color: Background colour from the colour picker ("#RRGGBB", "#RGB",
            "rgb(...)" or "rgba(...)"); other values are forwarded unchanged
            as the `type` argument of `Remover.process`. Only used for MP4.
        mode: 'Fast' builds `Remover(mode='fast')`; anything else builds the
            default `Remover()`.
        out_format: 'mp4' composites each frame onto `color`; any other value
            produces a WebM with an alpha channel via ffmpeg.
        progress: Gradio progress tracker.

    Returns:
        Path of the written video file. If the time budget runs out, the
        frames processed so far are written and that partial file is returned.

    Raises:
        gr.Error: if no frame could be processed or ffmpeg fails.
    """
    # Measure from function entry so model loading counts against the budget.
    start_time = time.time()
    deadline = start_time + _GPU_DURATION_SECONDS - _TIMEOUT_MARGIN_SECONDS

    print(str(color))
    color = _parse_color(color)
    print("Parsed color:", color)

    remover = Remover(mode='fast') if mode == 'Fast' else Remover()

    cap = cv2.VideoCapture(video)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Covers 0, negative and NaN values reported for the frame rate.
            fps = 25.0
        tmpname = random.randint(111111111, 999999999)

        frames = _iter_frames(cap, total_frames, deadline, progress)
        if out_format == 'mp4':
            return _write_mp4(frames, remover, color, fps, str(tmpname) + '.mp4')
        return _write_webm(frames, remover, video, fps, str(tmpname) + '.webm', tmpname)
    finally:
        # Always release the capture, including when processing raises.
        cap.release()

# ---------------------------------------------------------------------------
# Gradio user interface: static texts, example row, and the Interface itself.
# ---------------------------------------------------------------------------
title = "🎞️ Video Background Removal Tool 🎥"
description = (
    "*Please note that if your video file is long (has a high number of frames), "
    "there is a chance that processing break due to GPU timeout. "
    "In this case, consider trying Fast mode.*"
)

# One example row; the columns follow the order of the Interface inputs:
# video path, background colour, mode, output format.
examples = [['./input.mp4', '#00FF00', 'Normal', 'mp4']]

iface = gr.Interface(
    title=title,
    description=description,
    fn=doo,
    inputs=[
        "video",
        gr.ColorPicker(value="#00FF00", label="Background color"),
        gr.components.Radio(
            choices=['Normal', 'Fast'],
            value='Normal',
            label='Select mode',
            info='Normal is more accurate, but takes longer. | Fast has lower accuracy so the process will be faster.',
        ),
        gr.components.Radio(
            choices=['mp4', 'webm'],
            value='mp4',
            label='Output format',
        ),
    ],
    outputs="video",
    examples=examples,
)
# Start the web app as soon as the module is executed.
iface.launch()