Spaces:
Running
Running
File size: 2,163 Bytes
cfb3fab bf12e57 cfb3fab bf12e57 8397178 cfb3fab bf12e57 cfb3fab bf12e57 a56d89f bf12e57 a56d89f bf12e57 a56d89f bf12e57 a56d89f 8397178 bf12e57 cfb3fab 6b3c2db bf12e57 afcdacc bf12e57 cfb3fab db457c9 cfb3fab bf12e57 7f4fed9 bf12e57 cfb3fab 7f4fed9 cfb3fab 54fea3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoImageProcessor
from PIL import Image
from rembg import remove
import gradio as gr
import spaces
import io
import numpy as np
# Load the Nomic embed model.
# The image processor and the vision encoder come from the same checkpoint,
# so the identifier is named once and reused for both.
NOMIC_VISION_CHECKPOINT = "nomic-ai/nomic-embed-vision-v1.5"

processor = AutoImageProcessor.from_pretrained(NOMIC_VISION_CHECKPOINT)
# NOTE: trust_remote_code=True is passed for the model only (not the
# processor); it allows transformers to run code shipped with the checkpoint.
vision_model = AutoModel.from_pretrained(
    NOMIC_VISION_CHECKPOINT,
    trust_remote_code=True,
)
def focus_on_subject(image: Image.Image) -> Image.Image:
    """
    Strip the background with rembg and crop tightly around what remains.

    Args:
        image (PIL.Image.Image): Input image. A NumPy array is also accepted and is wrapped with Image.fromarray first.

    Returns:
        PIL.Image.Image: RGB image cropped to the bounding box of the background-removed result (uncropped when no bounding box is found).
    """
    source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    source = source.convert("RGB")

    # rembg is handed PNG-encoded bytes; its result is read back below as an
    # encoded image as well.
    with io.BytesIO() as png_buffer:
        source.save(png_buffer, format="PNG")
        cutout_bytes = remove(png_buffer.getvalue())

    cutout = Image.open(io.BytesIO(cutout_bytes)).convert("RGBA")

    # getbbox() yields None when it finds nothing to bound; in that case the
    # background-removed image is returned without cropping.
    bounds = cutout.getbbox()
    if not bounds:
        return cutout.convert("RGB")
    return cutout.crop(bounds).convert("RGB")
def ImgEmbed(image: Image.Image) -> list[float]:
    """
    Isolate the main subject of an image and return its L2-normalized embedding. A gr.Error is raised when no image is supplied.

    Args:
        image (PIL.Image.Image): Input image.

    Returns:
        List[float]: Embedding vector.
    """
    if image is None:
        # The image input may be empty (e.g. the button was pressed before an
        # upload); fail with a readable message instead of an AttributeError
        # from deep inside the preprocessing step.
        raise gr.Error("Please upload an image before requesting embeddings.")

    focused_image = focus_on_subject(image)
    inputs = processor(focused_image, return_tensors="pt")

    # Inference only: no gradients are needed, so skip autograd bookkeeping
    # for the forward pass and the normalization.
    with torch.no_grad():
        img_emb = vision_model(**inputs).last_hidden_state
        # Keep the first token position of each sequence (presumably the CLS
        # token -- confirm against the model card) and scale it to unit
        # L2 norm.
        img_embeddings = F.normalize(img_emb[:, 0], p=2, dim=1)

    # A single image was processed, so return the first (only) row.
    return img_embeddings[0].tolist()
# Gradio UI: one row with two columns -- the image input and trigger button
# in the first, the embedding text output in the second.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Upload Image")
            btn = gr.Button("Get Embeddings")
        with gr.Column():
            # NOTE: a preview of the preprocessed image is currently disabled:
            # pre_img = gr.Image(label="Preprocessed Image")
            out = gr.Text(label="Image Embedding")

    # Clicking the button runs ImgEmbed on the uploaded image and shows the
    # returned vector in the text output.
    btn.click(fn=ImgEmbed, inputs=[img], outputs=[out])

if __name__ == "__main__":
    demo.launch(mcp_server=True)
|