Saad0KH's picture
Update app.py
549db45 verified
raw
history blame
3.87 kB
from flask import Flask, request, jsonify
from PIL import Image
import base64
from io import BytesIO
import numpy as np
import insightface
import onnxruntime as ort
import huggingface_hub
from SegCloth import segment_clothing
from transparent_background import Remover
import threading
import logging
# Flask application object; the route decorators further down register on it.
app = Flask(__name__)

# Emit INFO-level (and above) records through the root logger.
logging.basicConfig(level=logging.INFO)

# Nothing is loaded at import time: both names stay None until load_model()
# assigns the detector to them.
model = detector = None
def load_model():
    """Fetch the detector weights and populate the ``model``/``detector`` globals.

    Downloads ``models/scrfd_person_2.5g.onnx`` from the
    ``public-data/insightface`` Hugging Face repository, opens it in an ONNX
    Runtime session and wraps that session in an insightface ``RetinaFace``
    object. Both module-level names end up referring to the same object.
    """
    global model, detector

    weights_path = huggingface_hub.hf_hub_download(
        "public-data/insightface", "models/scrfd_person_2.5g.onnx"
    )

    session_options = ort.SessionOptions()
    session_options.intra_op_num_threads = 8
    session_options.inter_op_num_threads = 8

    # NOTE(review): ONNX Runtime treats this list as a priority order, so the
    # CPU provider is preferred over CUDA here -- confirm that is intended.
    onnx_session = ort.InferenceSession(
        weights_path,
        sess_options=session_options,
        providers=["CPUExecutionProvider", "CUDAExecutionProvider"],
    )

    model = insightface.model_zoo.retinaface.RetinaFace(
        model_file=weights_path, session=onnx_session
    )
    # First argument is insightface's ctx_id; presumably -1 selects CPU --
    # verify against the insightface version in use.
    model.prepare(-1, nms_thresh=0.5, input_size=(640, 640))
    detector = model

    logging.info("Model loaded successfully.")
def decode_image_from_base64(image_data):
    """Turn a base64-encoded image payload into an RGB ``PIL.Image.Image``.

    Args:
        image_data: Base64 text (or bytes) of an encoded image file.

    Returns:
        The decoded image, converted to mode ``"RGB"``.
    """
    raw_bytes = base64.b64decode(image_data)
    byte_stream = BytesIO(raw_bytes)
    return Image.open(byte_stream).convert("RGB")
def encode_image_to_base64(image):
    """Serialize a PIL image and return the file bytes as a base64 string.

    Opaque images are written as JPEG (smaller payloads). JPEG has no alpha
    channel and Pillow raises ``OSError`` when asked to write an alpha mode as
    JPEG, so images that carry transparency -- e.g. the output of a
    background-removal step -- are written as PNG instead, which keeps the
    transparency intact.

    Args:
        image: A ``PIL.Image.Image``.

    Returns:
        str: Base64 text of the encoded JPEG or PNG file.
    """
    has_alpha = image.mode in ("RGBA", "LA", "PA") or (
        image.mode == "P" and "transparency" in image.info
    )

    if has_alpha:
        image_format = "PNG"
        if image.mode != "RGBA":
            # Normalise every alpha-carrying mode to one PNG-writable mode.
            image = image.convert("RGBA")
    else:
        image_format = "JPEG"
        if image.mode not in ("1", "L", "RGB", "CMYK"):
            # Remaining modes (palette, integer, float, ...) are not
            # writable as JPEG; fall back to plain RGB.
            image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format=image_format)
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
# Shared background remover; created on first use by _get_remover().
_remover = None


def _get_remover():
    """Return the shared ``Remover`` instance, constructing it on first use."""
    global _remover
    if _remover is None:
        _remover = Remover()
    return _remover


def remove_background(image):
    """Run the background remover on ``image`` and return its output.

    A single ``Remover`` is reused across calls. It used to be constructed on
    every call (once per detected person), which presumably reloads the
    underlying model each time.

    Args:
        image: A ``PIL.Image.Image`` or a ``numpy.ndarray``; arrays are
            converted with ``Image.fromarray`` first.

    Returns:
        Whatever ``Remover.process`` returns for the image.

    Raises:
        TypeError: If ``image`` is neither a PIL image nor a NumPy array.
    """
    # Validate before touching the remover so bad input never triggers a load.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    elif not isinstance(image, Image.Image):
        raise TypeError("Unsupported image type")
    return _get_remover().process(image)
def detect_and_segment_persons(image, clothes):
    """Detect people in ``image`` and segment the requested clothing on each.

    Args:
        image: The input picture; it is converted with ``np.array`` and must
            yield a height x width x channel array in RGB order.
        clothes: Clothing labels, forwarded unchanged to ``segment_clothing``.

    Returns:
        list: If the detector finds nothing, a one-element list holding the
        base64-encoded, background-removed full image. Otherwise the
        ``segment_clothing`` results of every usable detection, concatenated
        with ``list.extend`` (presumably encoded images -- verify against
        ``SegCloth``). Boxes that collapse to zero area after rounding and
        clipping are skipped, so the list may be empty.
    """
    rgb = np.array(image)
    bgr = rgb[:, :, ::-1]  # the detector is fed BGR

    if detector is None:
        load_model()  # lazy one-time initialisation

    bboxes, _ = detector.detect(bgr)
    if bboxes.shape[0] == 0:
        return [encode_image_to_base64(remove_background(image))]

    height, width = rgb.shape[:2]
    boxes = np.round(bboxes[:, :4]).astype(int)
    # Clamp x coordinates to the image width and y coordinates to its height.
    boxes[:, [0, 2]] = np.clip(boxes[:, [0, 2]], 0, width)
    boxes[:, [1, 3]] = np.clip(boxes[:, [1, 3]], 0, height)

    all_segmented_images = []
    for x1, y1, x2, y2 in boxes:
        if x2 <= x1 or y2 <= y1:
            # An empty crop cannot be turned into an image; skip it instead
            # of failing the whole request.
            continue
        # Crop straight from the RGB array; no BGR round trip is needed.
        person = Image.fromarray(rgb[y1:y2, x1:x2])
        person_no_background = remove_background(person)
        all_segmented_images.extend(segment_clothing(person_no_background, clothes))
    return all_segmented_images
@app.route('/', methods=['GET'])
def welcome():
    """Serve the plain-text greeting shown at the root URL."""
    greeting = "Welcome to Clothing Segmentation API"
    return greeting
@app.route('/api/detect', methods=['POST'])
def detect():
    """Handle ``POST /api/detect``: segment clothing in a base64 image.

    Expects a JSON object with an ``image`` field holding the base64-encoded
    picture.

    Returns:
        ``{"images": [...]}`` on success; ``{"error": ...}`` with status 400
        when the body is not a JSON object containing ``image``; or
        ``{"error": ...}`` with status 500 when processing fails.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'image' not in payload:
        return jsonify({'error': "Request body must be JSON with an 'image' field"}), 400

    try:
        image = decode_image_from_base64(payload['image'])
        clothes = ["Upper-clothes", "Skirt", "Pants", "Dress"]
        # Call directly. The previous start-then-join worker thread added no
        # concurrency, and exceptions raised inside it never reached this
        # handler, so failures were reported as 200 with an empty list.
        result = detect_and_segment_persons(image, clothes)
        return jsonify({'images': result})
    except Exception as e:
        # logging.exception records the traceback along with the message.
        logging.exception("Error occurred: %s", e)
        return jsonify({'error': str(e)}), 500
if __name__ == "__main__":
    import os

    # Debug mode turns on the Werkzeug interactive debugger, which lets anyone
    # who can reach the server execute arbitrary code. The server binds to all
    # interfaces, so debug stays off unless FLASK_DEBUG is explicitly set.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true")
    app.run(debug=debug_enabled, host="0.0.0.0", port=7860)