# aai / tabs/images/models.py
# Author: barreloflube
# Commit e63972b: Refactor import statement in models.py to use double quotes for string interpolation
# File size: 2.49 kB
from typing import Any, Dict, List, Optional

import gradio as gr
from PIL import Image
from pydantic import BaseModel, field_validator, model_validator

from config import Config as appConfig


class ControlNetReq(BaseModel):
    """ControlNet settings attached to an image-generation request.

    The three lists are parallel. NOTE(review): presumably index ``i`` of each
    list describes the same ControlNet; the lengths are not checked here, so
    confirm against the consumer.
    """

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # Arbitrary types are required because ``control_images`` holds PIL images.
    model_config = {"arbitrary_types_allowed": True}

    controlnets: List[str]  # ["canny", "tile", "depth", "scribble"]
    control_images: List[Image.Image]
    controlnet_conditioning_scale: List[float]


class BaseReq(BaseModel):
    """Parameters shared by every image-generation request.

    Subclasses add the inputs specific to img2img and inpainting. Once all
    fields are validated, ``check_model`` rejects options that the selected
    model's loader does not support.
    """

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # Kept so subclasses can declare non-Pydantic field types (PIL images).
    model_config = {"arbitrary_types_allowed": True}

    model: str = ""  # compared against ``repo_id`` in appConfig.IMAGES_MODELS
    prompt: str = ""
    negative_prompt: Optional[str] = None
    fast_generation: Optional[bool] = True
    loras: Optional[list] = []
    embeddings: Optional[list] = None
    resize_mode: Optional[str] = "resize_and_fill"  # resize_only, crop_and_resize, resize_and_fill
    scheduler: Optional[str] = "euler_fl"
    height: int = 1024
    width: int = 1024
    num_images_per_prompt: int = 1
    num_inference_steps: int = 8
    clip_skip: Optional[int] = None
    guidance_scale: float = 3.5
    seed: Optional[int] = 0
    refiner: bool = False
    vae: bool = True
    controlnet_config: Optional[ControlNetReq] = None
    custom_addons: Optional[Dict[Any, Any]] = None

    # A model-level (not field-level) validator is required: this check
    # compares ``model`` with several other fields. A multi-field
    # ``field_validator`` receives a single field's value, not a dict.
    @model_validator(mode="after")
    def check_model(self) -> "BaseReq":
        """Reject request options that Flux-loader models do not support.

        Every entry of ``appConfig.IMAGES_MODELS`` whose ``repo_id`` equals
        ``self.model`` and whose ``loader`` is ``"flux"`` is checked. Entries
        are read with ``.get``; presumably they are dicts (confirm in config).

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: if a matching Flux model is requested together with a
                negative prompt, embeddings, a clip-skip value, or a
                ``"scribble"`` ControlNet. Pydantic reports this as a
                ``ValidationError``, which subclasses ``ValueError``.
        """
        for entry in appConfig.IMAGES_MODELS:
            if entry.get("repo_id") != self.model or entry.get("loader") != "flux":
                continue
            # Truthiness checks mirror the original intent: empty/zero values
            # are treated as "not set".
            if self.negative_prompt:
                raise ValueError("Negative prompt is not supported for Flux models.")
            if self.embeddings:
                raise ValueError("Embeddings are not supported for Flux models.")
            if self.clip_skip:
                raise ValueError("Clip skip is not supported for Flux models.")
            if self.controlnet_config and "scribble" in self.controlnet_config.controlnets:
                raise ValueError("Scribble is not supported for Flux models.")
        return self


class BaseImg2ImgReq(BaseReq):
    """Image-to-image request: every ``BaseReq`` field plus a source image."""

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # Arbitrary types are required for the PIL ``image`` field.
    model_config = {"arbitrary_types_allowed": True}

    image: Image.Image
    # NOTE(review): presumably the denoising strength applied to ``image``;
    # no range is enforced here.
    strength: float = 1.0


class BaseInpaintReq(BaseImg2ImgReq):
    """Inpainting request: every ``BaseImg2ImgReq`` field plus a mask image."""

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # Arbitrary types are required for the PIL ``mask_image`` field.
    model_config = {"arbitrary_types_allowed": True}

    mask_image: Image.Image