Spaces:
Running
Running
import torch | |
import torch.nn.functional as F | |
from transformers import AutoModel, AutoImageProcessor | |
from PIL import Image | |
from rembg import remove | |
import gradio as gr | |
import spaces | |
import io | |
import numpy as np | |
# Load the Nomic vision embedding model and its matching image processor once,
# at import time, so every request reuses the same objects.
MODEL_ID = "nomic-ai/nomic-embed-vision-v1.5"

processor = AutoImageProcessor.from_pretrained(MODEL_ID)
# NOTE: trust_remote_code=True allows transformers to execute model code that
# ships with the Hub repository; only keep it for repositories you trust.
vision_model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
def focus_on_subject(image: Image.Image) -> Image.Image:
    """
    Remove the background and crop to the main object using rembg.

    Args:
        image (PIL.Image.Image): Input image. A NumPy array is also accepted
            and is converted to a PIL image first.

    Returns:
        PIL.Image.Image: RGB image cropped to the bounding box of the pixels
        rembg kept (non-zero alpha). If rembg kept nothing, the full-size
        result is returned uncropped.

    Raises:
        ValueError: If ``image`` is None.
    """
    if image is None:
        raise ValueError("focus_on_subject() requires an image, got None")
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # rembg accepts a PIL image and returns a PIL image, so there is no need
    # to encode to PNG bytes and decode the result again.
    cutout = remove(image).convert("RGBA")

    # Measure the bounding box on the alpha channel explicitly. Calling
    # getbbox() on the RGBA image itself only looks at alpha on newer Pillow
    # releases; older ones consider every band, which can inflate the box
    # when transparent pixels still carry colour values.
    bbox = cutout.getchannel("A").getbbox()
    cropped = cutout.crop(bbox) if bbox else cutout
    return cropped.convert("RGB")
def ImgEmbed(image: Image.Image) -> list[float]:
    """
    Isolate the main subject of an image and return its normalized embedding.

    The image is passed through focus_on_subject (background removal and
    crop), preprocessed, and embedded with the vision model.

    Args:
        image (PIL.Image.Image): Input image.

    Returns:
        List[float]: L2-normalized embedding vector.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # The UI can trigger this handler before anything is uploaded; report
        # that to the user instead of failing on an AttributeError downstream.
        raise gr.Error("Please upload an image first.")

    focused_image = focus_on_subject(image)
    inputs = processor(focused_image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        hidden_states = vision_model(**inputs).last_hidden_state

    # Use the hidden state at position 0 of each sequence as the image
    # representation (presumably the CLS token - confirm against the model
    # card), then scale it to unit L2 norm.
    embedding = F.normalize(hidden_states[:, 0], p=2, dim=1)
    return embedding[0].tolist()
# Gradio UI: the first column holds the image upload and the trigger button,
# the second column shows the resulting embedding as text.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image")
            embed_button = gr.Button("Get Embeddings")
        with gr.Column():
            embedding_output = gr.Text(label="Image Embedding")

    # Event listeners are attached while the Blocks context is still open.
    embed_button.click(
        fn=ImgEmbed,
        inputs=[image_input],
        outputs=[embedding_output],
    )

if __name__ == "__main__":
    # mcp_server=True is passed through to Gradio's launch().
    demo.launch(mcp_server=True)