# Source: demo/main.py, uploaded by lixiang6 ("Upload 19 files", commit 15f87d2, 727 bytes).
import base64
import io
import os
import threading

import torch
from diffusers import StableDiffusionPipeline
from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel
app = FastAPI()

# Load the pipeline once at import time so every request reuses it.
# Set MODEL_PATH to your own weights (for example a locally modified
# StoryDiffusion pipeline); the default below is only a placeholder meaning
# "your model path" and must be replaced. MODEL_DEVICE defaults to "cuda".
MODEL_PATH = os.environ.get("MODEL_PATH", "你的模型路径")
MODEL_DEVICE = os.environ.get("MODEL_DEVICE", "cuda")
pipe = StableDiffusionPipeline.from_pretrained(MODEL_PATH).to(MODEL_DEVICE)
# Request body accepted by the /generate endpoint.
class PromptInput(BaseModel):
    # Text prompt handed directly to the diffusion pipeline.
    prompt: str
# Serializes calls into the shared module-level pipeline. FastAPI runs plain
# `def` endpoints in a worker threadpool, so without this lock concurrent
# requests would invoke the same pipeline object from several threads at once
# (diffusers pipelines are not thread-safe: their schedulers keep per-call state).
_pipe_lock = threading.Lock()


@app.post("/generate")
def generate_image(data: PromptInput):
    """Generate one image for ``data.prompt`` and return it as base64 PNG.

    Args:
        data: Request body; ``data.prompt`` is passed unchanged to the pipeline.

    Returns:
        ``{"image_base64": <str>}`` holding the first image the pipeline
        produced, PNG-encoded and then base64-encoded as UTF-8 text so a
        front end can display it directly.
    """
    with _pipe_lock:
        image = pipe(data.prompt).images[0]

    # Encoding happens outside the lock so only the pipeline call is serialized.
    with io.BytesIO() as buffered:
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return {"image_base64": img_str}