shot-categorizer-v0 / README.md
sayakpaul HF Staff
metadata
license: mit
language:
  - en
base_model:
  - microsoft/Florence-2-large

Shot Categorizer 🎬

A shot categorization model fine-tuned from microsoft/Florence-2-large. Given an image, it predicts shot metadata (color, lighting, lighting type, and composition), which can then be used to curate datasets.

Training configuration:

  • Batch size: 16
  • Gradient accumulation steps: 4
  • Learning rate: 1e-6
  • Epochs: 20
  • Max grad norm: 1.0
  • Hardware: 8x H100 GPUs

Training used FP16 mixed precision with DeepSpeed ZeRO stage 2. The vision tower of the model was kept frozen during training.
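
The settings above map fairly directly onto a DeepSpeed configuration. The following is a minimal sketch, not the configuration actually used to train this model. The key names are standard DeepSpeed config fields, but treating the batch size of 16 as a per-GPU micro-batch is an assumption, and so is the effective batch size derived from it.

```python
import json

# Values taken from the training configuration listed above.
BATCH_SIZE = 16        # ASSUMPTION: per GPU; the README does not say
GRAD_ACCUM_STEPS = 4
MAX_GRAD_NORM = 1.0
NUM_GPUS = 8

# Hypothetical DeepSpeed config mirroring those settings.
ds_config = {
    "train_micro_batch_size_per_gpu": BATCH_SIZE,
    "gradient_accumulation_steps": GRAD_ACCUM_STEPS,
    "gradient_clipping": MAX_GRAD_NORM,
    "fp16": {"enabled": True},
    "zero_optimization": {"stage": 2},
}

# Samples per optimizer step, valid only under the per-GPU assumption.
effective_batch_size = BATCH_SIZE * GRAD_ACCUM_STEPS * NUM_GPUS

print(json.dumps(ds_config, indent=2))
print("effective batch size:", effective_batch_size)
```

If 16 is instead the global batch size, the per-GPU micro-batch and the effective batch size would both be smaller than this sketch suggests.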

Inference

from transformers import AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
import requests


folder_path = "diffusers-internal-dev/shot-categorizer-v0"
model = (
    AutoModelForCausalLM.from_pretrained(folder_path, torch_dtype=torch.float16, trust_remote_code=True)
    .to("cuda")
    .eval()
)
processor = AutoProcessor.from_pretrained(folder_path, trust_remote_code=True)

prompts = ["<COLOR>", "<LIGHTING>", "<LIGHTING_TYPE>", "<COMPOSITION>"]
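# Assumes the example image is served from the Hugging Face Hub under the same repo id.
url = "https://huggingface.co/diffusers-internal-dev/shot-categorizer-v0/resolve/main/assets/image_3.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")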

# inference_mode() already disables gradient tracking, so no_grad() is not needed.
with torch.inference_mode():
    for prompt in prompts:
        inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            early_stopping=False,
            do_sample=False,
            num_beams=3,
        )
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        parsed_answer = processor.post_process_generation(
            generated_text, task=prompt, image_size=(image.width, image.height)
        )
        print(parsed_answer)

For the example image, this should print:

{'<COLOR>': 'Cool, Saturated, Cyan, Blue'}
{'<LIGHTING>': 'Soft light, Low contrast'}
{'<LIGHTING_TYPE>': 'Daylight, Sunny'}
{'<COMPOSITION>': 'Left heavy'}
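
For dataset curation, the comma-separated strings are easier to work with as lists of tags. The helpers below are a hypothetical sketch: `to_tags` and `matches` are illustrative names, not part of the model or processor API, and they assume each `parsed_answer` has the single-key shape shown above.

```python
def to_tags(parsed_answers):
    """Merge per-prompt outputs into {category: [tag, ...]}."""
    tags = {}
    for answer in parsed_answers:
        for task, text in answer.items():
            # "<COLOR>" -> "COLOR"; "Cool, Saturated" -> ["Cool", "Saturated"]
            category = task.strip("<>")
            tags[category] = [t.strip() for t in text.split(",") if t.strip()]
    return tags


def matches(tags, category, wanted):
    """Return True if `wanted` is one of the tags predicted for `category`."""
    return wanted in tags.get(category, [])


# Outputs copied from the example above.
outputs = [
    {"<COLOR>": "Cool, Saturated, Cyan, Blue"},
    {"<LIGHTING>": "Soft light, Low contrast"},
    {"<LIGHTING_TYPE>": "Daylight, Sunny"},
    {"<COMPOSITION>": "Left heavy"},
]

shot_tags = to_tags(outputs)
print(shot_tags)
print(matches(shot_tags, "LIGHTING_TYPE", "Daylight"))
```

A curation script could then keep or drop images by testing `matches` against whichever categories matter for the target dataset.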