InferBench / benchmark /metrics /image_reward.py
nifleisch
feat: add core logic for project
2c50826
raw
history blame
631 Bytes
import os
import tempfile
from typing import Dict
import ImageReward as RM
from PIL import Image
class ImageRewardMetric:
    """Score how well an image matches a text prompt using ImageReward.

    The ``ImageReward-v1.0`` model is loaded once in the constructor and
    reused for every call to :meth:`compute_score`.
    """

    def __init__(self):
        self.model = RM.load("ImageReward-v1.0")

    @property
    def name(self) -> str:
        """Return the identifier of this metric."""
        return "image_reward"

    def compute_score(
        self,
        image: "Image.Image",
        prompt: str,
    ) -> Dict[str, float]:
        """Compute the ImageReward score of ``image`` for ``prompt``.

        The image is written to a temporary PNG file because the model is
        called with a list of file paths. The temporary file is removed
        before returning, including when saving or scoring raises.

        Args:
            image: Image to evaluate; only its ``save(path)`` method is used.
            prompt: Text prompt the image is scored against.

        Returns:
            A dict with the single key ``"image_reward"`` mapped to whatever
            ``self.model.score`` returns for the one-element path list
            (presumably a float -- confirm against the ImageReward API).

        Raises:
            Any exception raised by ``image.save`` or ``self.model.score`` is
            propagated after the temporary file has been cleaned up.
        """
        # delete=False: the file is reopened by path below, so it must outlive
        # this handle. Only the path is kept; the handle is closed right away
        # so that reopening by name also works on platforms that forbid
        # opening a file that is still held open.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name

        try:
            image.save(tmp_path)
            score = self.model.score(prompt, [tmp_path])
        finally:
            # Always remove the file, otherwise a failing save/score call
            # would leak one PNG per invocation.
            os.unlink(tmp_path)

        return {"image_reward": score}