File size: 6,772 Bytes
08c9cf1
 
 
 
 
 
 
 
 
 
f91423b
08c9cf1
 
 
 
aec7cfd
dd45c6d
938c288
72681ac
dd45c6d
 
aec7cfd
dd45c6d
aec7cfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ce6e9b
aec7cfd
c0d825f
aec7cfd
 
 
 
 
 
 
 
 
c0d825f
aec7cfd
 
 
 
 
 
 
52d383f
aec7cfd
 
c0d825f
aec7cfd
 
 
31ab05a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aec7cfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import logging
import os
import io
import tempfile
import requests
from PIL import Image
import cv2
from urllib.parse import urlparse
import socket
import ipaddress
import spaces

# Configure root logging at import time (INFO and above, timestamped lines)
# and obtain a logger named after this module.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Resource limits enforced on user-supplied media.
MAX_IMAGE_MB = 50               # max PNG-encoded image size, in MB
MAX_IMAGE_RES = (4096, 4096)    # max (width, height) in pixels
MAX_VIDEO_MB = 50               # max video file size on disk, in MB
MAX_VIDEO_DURATION = 30         # seconds

# File extensions accepted from remote URLs, matched (lower-cased) against the
# extension of the URL path.
_URL_IMAGE_EXTS = (".jpg", ".jpeg", ".png")
_URL_VIDEO_EXTS = (".mp4", ".avi", ".mov")
# Maximum number of redirect hops; every hop is re-checked by _is_public_url.
_URL_MAX_REDIRECTS = 5
_URL_TIMEOUT_S = 10
_URL_CHUNK_BYTES = 64 * 1024


def _is_public_url(url):
    """
    Return True only if *url* is http(s) and its host resolves exclusively to
    globally routable IP addresses.

    Every address returned by DNS is checked, so a hostname resolving to both
    a public and a private address is rejected. The HTTP request performed
    afterwards resolves the host again on its own, so this does not defend
    against DNS rebinding between the check and the request.
    """
    try:
        parsed = urlparse(url)
        host = parsed.hostname
        # .port raises ValueError for malformed/out-of-range ports.
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
    except ValueError:
        return False
    if parsed.scheme not in ("http", "https") or not host:
        return False

    try:
        infos = socket.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
    except (OSError, UnicodeError):  # includes socket.gaierror
        return False
    if not infos:
        return False

    for *_, sockaddr in infos:
        # Strip an IPv6 zone id (e.g. "fe80::1%eth0") before parsing.
        addr_text = str(sockaddr[0]).split("%", 1)[0]
        try:
            ip = ipaddress.ip_address(addr_text)
        except ValueError:
            return False
        # Judge IPv4-mapped IPv6 (::ffff:a.b.c.d) by the embedded IPv4 address.
        if ip.version == 6 and ip.ipv4_mapped is not None:
            ip = ip.ipv4_mapped
        if not ip.is_global:
            return False
    return True


def _open_public_stream(url, headers):
    """
    Issue a streaming GET for *url*, following redirects manually so that
    every hop passes the _is_public_url SSRF guard.

    Returns the open ``requests.Response`` of the final (non-redirect) hop,
    or None if a hop is non-public or the redirect limit is exceeded. The
    caller is responsible for closing the returned response.
    """
    from urllib.parse import urljoin

    current = url
    for _ in range(_URL_MAX_REDIRECTS + 1):
        if not _is_public_url(current):
            logger.warning("Blocked non-public URL request (possible SSRF).")
            return None
        resp = requests.get(
            current,
            headers=headers,
            timeout=_URL_TIMEOUT_S,
            stream=True,
            allow_redirects=False,  # redirects are validated hop by hop below
        )
        if not resp.is_redirect:
            return resp
        # is_redirect implies a Location header is present.
        location = resp.headers["Location"]
        resp.close()
        current = urljoin(current, location)
    logger.warning("Too many redirects while fetching URL.")
    return None


def _read_capped(resp, max_bytes):
    """Read a streamed response body; return None once it exceeds *max_bytes*."""
    declared = resp.headers.get("Content-Length", "")
    if declared.isdigit() and int(declared) > max_bytes:
        return None
    body = bytearray()
    for chunk in resp.iter_content(chunk_size=_URL_CHUNK_BYTES):
        body += chunk
        if len(body) > max_bytes:
            return None
    return bytes(body)


# NOTE(review): this function does no GPU work; the decorator may be present
# only to satisfy the `spaces` GPU runtime's startup requirements — confirm
# before removing it.
@spaces.GPU
def fetch_media_from_url(url):
    """
    Downloads media from a public http(s) URL. Supports images and videos.

    The media type is chosen from the extension of the URL path:
    .jpg/.jpeg/.png are decoded in memory and returned as an RGB PIL image;
    .mp4/.avi/.mov are written to a temporary file whose path is returned.
    The host and every redirect target must resolve only to globally routable
    addresses, and the body is capped at MAX_IMAGE_MB / MAX_VIDEO_MB while
    streaming, so oversized downloads are aborted early.

    Args:
        url (str): Remote URL supplied by the user.

    Returns:
        PIL.Image.Image for images, str path for videos (the temp file is
        created with delete=False and is not removed here), or None on any
        failure.
    """
    logger.info(f"Fetching media from URL: {url}")
    try:
        ext = os.path.splitext(urlparse(url).path)[-1].lower()
        # Reject unsupported types before any network access.
        if ext in _URL_IMAGE_EXTS:
            max_mb = MAX_IMAGE_MB
        elif ext in _URL_VIDEO_EXTS:
            max_mb = MAX_VIDEO_MB
        else:
            logger.warning("Unsupported file type from URL.")
            return None

        resp = _open_public_stream(url, {"User-Agent": "Mozilla/5.0"})
        if resp is None:
            return None
        try:
            body = None
            if resp.status_code == 200:
                body = _read_capped(resp, max_mb * 1024 * 1024)
        finally:
            resp.close()
        if body is None:
            logger.warning("Download failed or file too large.")
            return None

        if ext in _URL_IMAGE_EXTS:
            # Decode straight from memory: no temp file to leak.
            with Image.open(io.BytesIO(body)) as img:
                return img.convert("RGB")

        # Videos are consumed by path, so persist them to a temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_file:
            tmp_file.write(body)
        return tmp_file.name
    except Exception as e:
        logger.error(f"URL fetch failed: {e}")
        return None

# Input Validation Functions
def validate_image(img):
    """
    Validates an image against MAX_IMAGE_RES and MAX_IMAGE_MB.

    The resolution is checked first because it is cheap. The size check
    re-encodes the image as PNG in memory and measures the encoded bytes, so
    it reflects the PNG size rather than the size of the originally uploaded
    file.

    Args:
        img (PIL.Image.Image): Image to validate.

    Returns:
        Tuple[bool, str or None]: (True, None) if valid; (False, reason) otherwise.
    """
    logger.info("Validating uploaded image.")
    try:
        max_w, max_h = MAX_IMAGE_RES
        if img.width > max_w or img.height > max_h:
            # Message is derived from the limit so it cannot drift from it.
            reason = f"Image resolution exceeds {max_w}x{max_h}."
            logger.warning(reason)
            return False, reason

        buffer = io.BytesIO()
        img.save(buffer, format="PNG")
        size_mb = len(buffer.getvalue()) / (1024 * 1024)
        if size_mb > MAX_IMAGE_MB:
            reason = f"Image exceeds {MAX_IMAGE_MB}MB limit."
            logger.warning(reason)
            return False, reason

        logger.info("Image validation passed.")
        return True, None
    except Exception as e:
        logger.error(f"Error validating image: {e}")
        return False, str(e)

def validate_video(path):
    """
    Validates a video file against MAX_VIDEO_MB and MAX_VIDEO_DURATION.

    Duration is computed as frame count / FPS from OpenCV's container
    metadata. A file OpenCV cannot open is rejected; previously it read
    fps=0, got a duration of 0 and passed.

    Args:
        path (str): Path to the video file.

    Returns:
        Tuple[bool, str or None]: (True, None) if valid; (False, reason) otherwise.
    """
    logger.info(f"Validating video file at: {path}")
    try:
        size_mb = os.path.getsize(path) / (1024 * 1024)
        if size_mb > MAX_VIDEO_MB:
            reason = f"Video exceeds {MAX_VIDEO_MB}MB limit."
            logger.warning(reason)
            return False, reason

        cap = cv2.VideoCapture(path)
        try:
            if not cap.isOpened():
                reason = "Unable to open video file."
                logger.warning(reason)
                return False, reason
            fps = cap.get(cv2.CAP_PROP_FPS)
            frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        finally:
            # Release the capture even if reading properties raises.
            cap.release()

        # NOTE(review): FPS / frame-count metadata can be missing for some
        # containers; a non-positive FPS yields duration 0 and passes.
        duration = frames / fps if fps > 0 else 0
        if duration > MAX_VIDEO_DURATION:
            reason = f"Video exceeds {MAX_VIDEO_DURATION} seconds duration limit."
            logger.warning(reason)
            return False, reason

        logger.info("Video validation passed.")
        return True, None
    except Exception as e:
        logger.error(f"Error validating video: {e}")
        return False, str(e)
    
# Input Resolution
_UPLOAD_IMAGE_EXTS = (".jpg", ".jpeg", ".png")
_UPLOAD_VIDEO_EXTS = (".mp4", ".avi", ".mov")


def _classify_uploads(media_upload):
    """
    Split uploaded items into (list of RGB PIL images, list of video paths).

    PIL images are accepted directly; string items are treated as file paths
    and classified by extension. Images that fail to open, unknown extensions
    and other item types are skipped with a warning.
    """
    image_files = []
    video_files = []
    for item in media_upload:
        if isinstance(item, Image.Image):
            image_files.append(item.convert("RGB"))
            continue
        if not isinstance(item, str):
            logger.warning(f"Skipping unsupported upload item of type {type(item).__name__}.")
            continue

        lowered = item.lower()
        if lowered.endswith(_UPLOAD_IMAGE_EXTS):
            try:
                # Context manager closes the underlying file handle.
                with Image.open(item) as img:
                    image_files.append(img.convert("RGB"))
            except Exception as e:
                logger.warning(f"Failed to open image: {item} - {e}")
        elif lowered.endswith(_UPLOAD_VIDEO_EXTS):
            video_files.append(item)
        else:
            logger.warning(f"Skipping file with unsupported extension: {item}")
    return image_files, video_files


def resolve_input(mode, media_upload, url):
    """
    Resolves the media input based on selected mode.
    - If mode is 'Upload', accepts either:
        * 1–5 images (PIL.Image objects or image file paths)
        * OR 1 video file (file path as string)
      Mixing images and videos is rejected.
    - If mode is 'URL', fetches a remote image or video.

    Args:
        mode (str): 'Upload' or 'URL'
        media_upload (List[Union[PIL.Image.Image, str]]): Uploaded media
        url (str): URL to image or video

    Returns:
        List[Union[PIL.Image.Image, str]] or None: images as RGB PIL images,
        a video as its file path; None when the input is missing or invalid.
    """
    try:
        logger.info(f"Resolving input for mode: {mode}")
        logger.info(f"Raw uploaded input: {media_upload}")

        if mode == "Upload":
            if not media_upload:
                logger.warning("No upload detected.")
                return None

            image_files, video_files = _classify_uploads(media_upload)

            # Only one type of input allowed
            if image_files and video_files:
                logger.warning("Mixed media upload not supported (images + video).")
                return None

            if image_files:
                if 1 <= len(image_files) <= 5:
                    logger.info(f"Accepted {len(image_files)} image(s).")
                    return image_files
                logger.warning("Invalid number of images. Must be 1 to 5.")
                return None

            if video_files:
                if len(video_files) == 1:
                    logger.info("Accepted single video upload.")
                    return video_files
                logger.warning("Only one video allowed.")
                return None

            logger.warning("Unsupported upload type.")
            return None

        if mode == "URL":
            if not url:
                logger.warning("URL mode selected but URL is empty.")
                return None
            media = fetch_media_from_url(url)
            if media is not None:
                logger.info("Media successfully fetched from URL.")
                return [media]
            logger.warning("Failed to resolve media from URL.")
            return None

        logger.error(f"Invalid mode selected: {mode}")
        return None

    except Exception as e:
        logger.error(f"Exception in resolve_input(): {e}")
        return None