"""FastAPI service that predicts a game (and its genre) from an uploaded image.

The GameNet Keras model and its label/genre lookup tables are fetched with
``hf_hub_download`` at import time. ``POST /predict`` accepts an image upload
and returns the top predicted game, its genre, and the model's top score.
"""

import io
import json
import logging
import threading

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from PIL import Image
from pydantic import BaseModel
from tensorflow.keras.applications.efficientnet import preprocess_input
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array

logger = logging.getLogger(__name__)

# ==== CONFIG ====
REPO_ID = "MAS-AI-0000/GameNet-1"
MODEL_FILENAME = "GameNetModel.h5"
# Alternative native Keras format:
# MODEL_FILENAME = "GameNetModel.keras"
LABELS_FILENAME = "label_to_index.json"
GENRE_FILENAME = "game_genre_map.json"
IMG_SIZE = (300, 300)  # input size for EfficientNetB3; passed to PIL resize as (width, height)

# ==== Load assets ====
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)
labels_path = hf_hub_download(repo_id=REPO_ID, filename=LABELS_FILENAME)
genre_path = hf_hub_download(repo_id=REPO_ID, filename=GENRE_FILENAME)

model = load_model(model_path)

# JSON is UTF-8; don't depend on the platform default encoding (e.g. cp1252
# on Windows), which can mis-decode non-ASCII game names.
with open(labels_path, "r", encoding="utf-8") as f:
    label_to_index = json.load(f)

# Invert label -> index. Indices are coerced to int so the lookup with the
# integer argmax below matches whether the JSON stores them as numbers or as
# numeric strings.
index_to_label = {int(v): k for k, v in label_to_index.items()}

with open(genre_path, "r", encoding="utf-8") as f:
    genre_map = json.load(f)

# Inference runs in a worker thread (see predict); this lock keeps calls on
# the shared model serialized, as they were when inference ran on the loop.
_model_lock = threading.Lock()

# ==== FastAPI Setup ====
app = FastAPI()

# Optional: CORS if frontend is on different domain.
# NOTE(review): a wildcard origin together with allow_credentials=True
# effectively permits credentialed cross-origin requests from any site;
# list explicit origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # change this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Response schema
class Prediction(BaseModel):
    """Response body of ``POST /predict``."""

    game: str  # label of the top-scoring class, or "Unknown" if unmapped
    genre: str  # genre_map entry for ``game``, or "Unknown"
    confidence: float  # highest value in the model output (presumably a softmax probability)


class _InvalidImageError(ValueError):
    """Raised when uploaded bytes cannot be decoded as an image."""


def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        _InvalidImageError: the bytes are empty, not a recognised image
            format, truncated/corrupt, or exceed PIL's decompression-bomb limit.
    """
    try:
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as e:
        # PIL's UnidentifiedImageError and "image file is truncated" errors
        # are both OSError subclasses/instances.
        raise _InvalidImageError(str(e)) from e


def _preprocess(img: Image.Image) -> np.ndarray:
    """Resize and normalize an RGB image into a model batch of one."""
    img = img.resize(IMG_SIZE, Image.Resampling.BICUBIC)
    arr = img_to_array(img)
    arr = preprocess_input(arr)  # normalize like in Colab
    return np.expand_dims(arr, axis=0)  # add the batch dimension


def _predict_sync(image_bytes: bytes) -> Prediction:
    """Decode, preprocess and classify one image; blocking (call off the event loop).

    Raises:
        _InvalidImageError: if ``image_bytes`` is not a decodable image.
    """
    batch = _preprocess(_decode_image(image_bytes))

    with _model_lock:
        # verbose=0: don't print a Keras progress bar for every request.
        preds = model.predict(batch, verbose=0)

    class_idx = int(np.argmax(preds))
    confidence = float(np.max(preds))

    game = index_to_label.get(class_idx, "Unknown")
    genre = genre_map.get(game, "Unknown")
    return Prediction(game=game, genre=genre, confidence=confidence)


# Inference route
@app.post("/predict", response_model=Prediction)
async def predict(file: UploadFile = File(...)):
    """Predict the game and genre for an uploaded image.

    Returns:
        A ``Prediction`` on success; a 400 JSON response ``{"error": ...}``
        if the upload is not a decodable image; a 500 JSON response
        ``{"error": str(e)}`` if anything else fails.
    """
    try:
        image_bytes = await file.read()
        # Decoding and model.predict are blocking; running them in the
        # threadpool keeps the event loop free to serve other requests.
        return await run_in_threadpool(_predict_sync, image_bytes)
    except _InvalidImageError as e:
        # Bad input from the client is a 4xx, not a server fault.
        return JSONResponse(content={"error": f"Invalid image: {e}"}, status_code=400)
    except Exception as e:
        logger.exception("Prediction failed for upload %r", file.filename)
        return JSONResponse(content={"error": str(e)}, status_code=500)