Spaces:
Running
Running
from fastapi import FastAPI, HTTPException, Request | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
from fastapi.responses import JSONResponse | |
from pydantic import BaseModel | |
from playwright.async_api import async_playwright | |
import os | |
import time | |
from urllib.parse import urlparse | |
from typing import Optional, Dict, List | |
import logging | |
import json | |
import base64 | |
from io import BytesIO | |
import aiohttp | |
import traceback | |
import requests | |
from openai import OpenAI | |
from starlette.middleware.base import BaseHTTPMiddleware | |
import uvicorn | |
from collections import defaultdict | |
from PIL import Image | |
from Crypto.PublicKey import RSA | |
from Crypto.Cipher import PKCS1_v1_5 | |
from datetime import datetime | |
from pymongo.mongo_client import MongoClient | |
from pymongo.server_api import ServerApi | |
# MongoDB connection URI, read from the environment.
URI = os.getenv("URI")
# Create one shared client for the process and connect to the server.
client = MongoClient(URI, server_api=ServerApi('1'))
app = FastAPI()
# Maximum number of requests handled at the same time.
MAX_CONNECTIONS = 100
current_connections = 0
# Rate-limit bookkeeping for the optimize-design endpoint.
optimize_design_requests = defaultdict(int)  # number of requests per client IP
optimize_design_timestamps = defaultdict(float)  # time of the first request per client IP
# SECURITY NOTE(review): eval() executes whatever Python expression is stored
# in the WHITELIST environment variable. It is kept so existing deployments
# keep working, but ast.literal_eval or json.loads would be safer if the value
# is always a plain literal -- confirm with the deployment configuration.
# An unset/empty variable now yields an empty list instead of crashing at
# import time (eval(None) raises TypeError); presumably the value is a list.
white_list = eval(os.getenv("WHITELIST") or "[]")
print(logging.INFO, white_list)
class ConnectionLimitMiddleware(BaseHTTPMiddleware):
    """Reject requests with HTTP 503 once MAX_CONNECTIONS are in flight."""

    async def dispatch(self, request: Request, call_next):
        """Count the request in, forward it, and always count it out again."""
        global current_connections

        at_capacity = current_connections >= MAX_CONNECTIONS
        if at_capacity:
            overload_body = {"detail": "已超过最大链接数,请稍后重试"}
            return JSONResponse(content=overload_body, status_code=503)

        current_connections += 1
        try:
            # The finally clause releases the slot even when the handler raises.
            return await call_next(request)
        finally:
            current_connections -= 1
# Register the connection-limiting middleware.
app.add_middleware(ConnectionLimitMiddleware)

# Make sure the screenshot cache directory exists.
CACHE_DIR = "cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Mount the static file directories.
app.mount("/screenshots", StaticFiles(directory=CACHE_DIR), name="screenshots")
app.mount("/static", StaticFiles(directory="static"), name="static")

# API keys and service identifiers, all read from the environment.
# NOTE(review): SECRET_KEY and SEARCH_ENGINE_ID fall back to hard-coded
# defaults when the variables are unset -- confirm that is intended.
SECRET_KEY = os.environ.get("SECRET_KEY", "wangyue")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
OPENAI_API_IMAGE_EDIT_KEY = os.environ.get("OPENAI_API_IMAGE_EDIT_KEY")
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
SEARCH_ENGINE_ID = os.environ.get("SEARCH_ENGINE_ID", "27acb0d55ad504716")
class ScreenshotRequest(BaseModel):
    # Request payload for capture_screenshot.
    # Page address that the headless browser navigates to.
    url: str
    # Viewport size in pixels.
    width: Optional[int] = 1024
    height: Optional[int] = 768
    # Used both as the screenshot type and as the file extension.
    format: Optional[str] = "png"
    # Extra HTTP headers; they override the built-in defaults on key clashes.
    custom_headers: Optional[Dict[str, str]] = {}
class AnalysisRequest(BaseModel):
    # Request payload for analyze_feedback.
    # Feedback text to analyse.
    text: str
    # Optional image reference, forwarded as `image_url` to the model.
    image_data: Optional[str] = None
    # Model identifier passed to the OpenAI Responses API.
    request_model_id: str = 'gpt-4.1-mini'
class OptimizationRequest(BaseModel):
    # Request payload for optimize_design.
    # Free text; its first whitespace-separated token is used as the design
    # type when the text holds more than one token.
    text: str
    # Base64 image, optionally carrying a "...base64," prefix.
    image_data: str
    # Suggestions that are joined into the image-edit prompt.
    suggestions: List[str]
    # Image model identifier.
    request_model_id: str = 'gpt-image-1'
    # Caller-supplied OpenAI API key (empty string when absent).
    openai_key: str = ''
    # Encrypted user key (empty string when absent).
    user_key: str = ''
class CaseStudyRequest(BaseModel):
    # Request payload for analyze_case_study.
    # Sent verbatim as the user message.
    user_input: str
    # Model identifier passed to the OpenAI Responses API.
    request_model_id: str = 'gpt-4.1-mini'
class TextOptimizationRequest(BaseModel):
    # Request payload for optimize_text.
    # The feedback the user is replying to.
    original_feedback: str
    # The user's draft reply that should be polished.
    user_input: str
    # Model identifier passed to the OpenAI Responses API.
    request_model_id: str = 'gpt-4.1-mini'
class SearchRequest(BaseModel):
    # Request payload for search_design_examples.
    # Search terms; " UI设计" is appended before querying.
    query: str
    # Forwarded as the `num` parameter of the Google Custom Search call.
    num_results: Optional[int] = 2
async def capture_screenshot(request: ScreenshotRequest):
    """Render a web page in headless Chromium and save a screenshot to the cache.

    Args:
        request: URL, viewport size, image format and optional extra headers.

    Returns:
        JSONResponse with ``success``, ``imageUrl`` (path under /screenshots)
        and ``filename``.

    Raises:
        HTTPException: 400 when the URL is missing or is not http/https,
            500 when rendering or saving the screenshot fails.
    """
    try:
        if not request.url:
            raise HTTPException(status_code=400, detail="需要提供URL参数")

        parsed_url = urlparse(request.url)
        # SECURITY: the URL comes from the caller. Only allow web schemes so
        # that e.g. file:// cannot be used to capture local files.
        # NOTE(review): internal hosts are still reachable (SSRF); consider an
        # allow/deny list if this service runs inside a private network.
        if parsed_url.scheme not in ("http", "https"):
            raise HTTPException(status_code=400, detail="仅支持 http/https 协议的URL")

        # Build a unique file name from the host name and a millisecond timestamp.
        domain = parsed_url.netloc.replace(".", "_")
        timestamp = int(time.time() * 1000)
        filename = f"{domain}_{timestamp}.{request.format}"
        filepath = os.path.join(CACHE_DIR, filename)
        print(f"开始为 {request.url} 生成截图...")

        # Default request headers.
        default_headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36",
            "Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8",
            "Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
            "Connection": "keep-alive",
            "Cache-Control": "max-age=0",
            "Sec-Fetch-Dest": "document",
            "Sec-Fetch-Mode": "navigate",
            "Sec-Fetch-Site": "none",
            "Sec-Fetch-User": "?1",
            "Upgrade-Insecure-Requests": "1"
        }
        # Custom headers take precedence over the defaults.
        headers = {**default_headers, **(request.custom_headers or {})}

        async with async_playwright() as p:
            browser = await p.chromium.launch()
            try:
                page = await browser.new_page()
                # Set the viewport size.
                await page.set_viewport_size({
                    "width": request.width,
                    "height": request.height
                })
                # Apply the request headers.
                await page.set_extra_http_headers(headers)
                # Load the page and wait until the network is idle.
                await page.goto(request.url, wait_until="networkidle")
                # Write the screenshot to the cache directory.
                await page.screenshot(path=filepath, type=request.format)
            finally:
                # Close the browser even when navigation or capture fails.
                await browser.close()

        print(f"截图完成: {filepath}")
        # Return the URL under which the screenshot is served.
        screenshot_url = f"/screenshots/{filename}"
        return JSONResponse({
            "success": True,
            "imageUrl": screenshot_url,
            "filename": filename
        })
    except HTTPException:
        # Keep deliberate client errors (400) instead of turning them into 500.
        raise
    except Exception as e:
        logging.error(f"截图过程中出错: {str(e)}")
        raise HTTPException(status_code=500, detail=f"截图生成失败: {str(e)}")
async def health_check():
    """Return a status marker plus the current and maximum connection counts."""
    snapshot = dict(
        status="ok",
        current_connections=current_connections,
        max_connections=MAX_CONNECTIONS,
    )
    return snapshot
def index() -> FileResponse:
    """Return static/index.html as an HTML file response."""
    index_page = "static/index.html"
    return FileResponse(path=index_page, media_type="text/html")
async def analyze_feedback(request: AnalysisRequest):
    """Ask the model to analyse feedback text (and an optional image).

    Builds a prompt that requests a mood rating, three suggestions and a
    search phrase, then forwards it through call_openai_api.

    Args:
        request: Feedback text, optional image reference and model id.

    Returns:
        JSONResponse wrapping the value returned by call_openai_api.

    Raises:
        HTTPException: the exception raised by call_openai_api is passed
            through unchanged; any other failure becomes a 500.
    """
    try:
        context = "这是用户上传的设计图。" if request.image_data else ''
        context += "以下是老板对设计的反馈内容:\n" + request.text
        context += '\n请根据老板的反馈分析情绪值(用emoji表示),并结合给出的设计图稿,给出三个具体的修改建议,最后给出合适的搜索关键词,以获取合适的参考设计案例。每个建议应该包含一个标题和详细描述。首先你需要对老板的情绪进行解读,使用"情绪值:"开头并分为五类:1. 非常满意-😊😊😊 2. 比较满意-🙂🙂🙂 3. 一般般-😐😐😐 4. 不太满意-🙁🙁🙁 5. 非常不满意-😠😠😠,然后在下一行用一句话分析老板的情绪,以"情绪分析:"开头。随后,请以"修改建议:\n"开头,并以有序列表分三行说明三个具体建议,并以"\n\n"分隔,比如:"1. 提高对比度:xxx\n\n 2. ...\n\n 3. ...\n\n"。记得结合图片进行分析和提出修改建议。最后,你需要使用网页搜索来获取合适的参考UI设计案例,来为修改当前案例提供参考。请在新的一行以"搜索内容:"开头,给出合适的搜索内容,以获取合适的参考设计案例,注意,你只能搜索UI界面或虚拟形象设计相关的案例。'

        # The image part (when present) goes before the text part.
        user_content = []
        if request.image_data:
            user_content.append({"type": "input_image", "image_url": request.image_data})
        user_content.append({"type": "input_text", "text": context})

        # Call the OpenAI API for the analysis.
        response = await call_openai_api(
            system_prompt='你是一位专业的设计顾问,擅长分析客户反馈,提取关键信息,并提供专业建议,你需要严格遵循用户给出的输出格式要求。',
            user_content=user_content,
            request_model_id=request.request_model_id
        )
        return JSONResponse(response)
    except HTTPException:
        # Preserve the status code chosen further down the call chain.
        raise
    except Exception as e:
        # The traceback is logged server-side only; returning it to the
        # client would expose internal details.
        logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
        raise HTTPException(status_code=500, detail=f'Error: {e}')
def decrypt_user_key(encrypted_key: str) -> dict:
    """Decrypt a base64-encoded, RSA PKCS#1 v1.5 encrypted JSON payload.

    SECRET_KEY is imported as the RSA key. The function never raises: every
    failure is logged and reported as an empty dict.

    Args:
        encrypted_key: Base64 text of the encrypted payload.

    Returns:
        The decoded JSON object, or ``{}`` when decryption or parsing fails
        or when the payload is not a JSON object.
    """
    try:
        key = RSA.import_key(SECRET_KEY)
        cipher = PKCS1_v1_5.new(key)
        # Decode the base64 encrypted text.
        encrypted_bytes = base64.b64decode(encrypted_key)
        # decrypt() hands back the sentinel (None) when the padding is wrong.
        decrypted_bytes = cipher.decrypt(encrypted_bytes, None)
        if decrypted_bytes is None:
            logging.error('Error decrypting user key: decryption failed')
            return {}
        decrypted_text = decrypted_bytes.decode('utf-8')
    except Exception as e:
        logging.error(f'Error decrypting user key: {e}')
        return {}

    try:
        data = json.loads(decrypted_text)
    except Exception as e:
        logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
        return {}

    # Callers index the result by key, so anything but an object is invalid.
    if not isinstance(data, dict):
        logging.error('Error decrypting user key: payload is not a JSON object')
        return {}
    return data
def search_user_key(user_key: str) -> Optional[dict]:
    """Look up a user key in the ``user_info.user_keys`` collection.

    Uses the module-level MongoDB client instead of opening (and never
    closing) a new connection on every call.

    Args:
        user_key: The key text to search for.

    Returns:
        The first matching document limited to ``user_key``, ``user_name``,
        ``credit`` and ``expiration``; ``None`` when nothing matches or the
        query fails (failures are logged).
    """
    # NOTE(review): `$search` with the `text` operator is analyzer-based, not
    # an exact comparison -- confirm it cannot return another user's key.
    pipeline = [
        {
            '$search': {
                'index': "user_key",
                'text': {
                    'query': user_key,
                    'path': {
                        'wildcard': "*"
                    }
                }
            }
        },
        {
            '$limit': 3
        },
        {
            '$project': {
                '_id': 0,
                'user_key': 1,
                'user_name': 1,
                'credit': 1,
                'expiration': 1
            }
        }
    ]
    try:
        # Get the user_info database and user_keys collection.
        user_keys_collection = client.user_info['user_keys']
        search_res = user_keys_collection.aggregate(pipeline)
        # Only the first result is used; an empty result is reported as None.
        result = next(search_res, None)
        return result if result else None
    except Exception as e:
        logging.error(f'Error searching user key: {e}')
        return None
async def optimize_design(request: OptimizationRequest):
    """Generate an edited design image from feedback suggestions.

    The caller must supply either their own OpenAI key or a valid user key
    with remaining credit. A credit is deducted after a successful
    generation whenever a user key was supplied.

    Args:
        request: Text, base64 image, suggestions, model id and credentials.

    Returns:
        JSONResponse wrapping the value returned by call_openai_image_api.

    Raises:
        HTTPException: 400 for an invalid user key, 503 when the credit is
            used up or no credential is supplied, 500 for other failures.
    """
    try:
        # Validate the user key, if one was supplied.
        user_key = None
        if request.user_key:
            user_info = decrypt_user_key(request.user_key)
            if not user_info:
                raise HTTPException(
                    status_code=400,
                    detail="无效的用户密钥"
                )
            try:
                credit_data = {
                    "user_key": request.user_key,
                    "user_name": user_info['user_id'],
                    "credit": user_info['credit'],
                    "expiration": datetime.fromisoformat(user_info['expiration'])
                }
            except (KeyError, TypeError, ValueError) as exc:
                # A payload with missing or malformed fields is an invalid
                # key, not a server error.
                raise HTTPException(
                    status_code=400,
                    detail="无效的用户密钥"
                ) from exc
            query_res = search_user_key(request.user_key)
            user_info_db = client.user_info
            user_keys_collection = user_info_db['user_keys']
            if not query_res:
                # First use of this key: store the credit record.
                user_keys_collection.insert_many([credit_data])
            else:
                # Known key: the stored record is authoritative.
                credit_data = query_res
            # NOTE(review): `expiration` is stored but never compared with
            # the current time -- confirm whether expiry should be enforced.
            if credit_data['credit'] < 1:
                raise HTTPException(
                    status_code=503,
                    detail="当前user-key额度已用尽。"
                )
            # Mark the key as accepted. This assignment was missing, so a
            # valid key was never honoured and credits were never deducted.
            user_key = request.user_key

        # Require either a caller-supplied OpenAI key or an accepted user key.
        if not request.openai_key and not user_key:
            raise HTTPException(
                status_code=503,
                detail="当前用户无生图权限,请点击'想使用自己的OpenAI API Key?'输入您的OpenAI API Key或联系@wangyue161并添加白名单user-key后重试。"
            )

        # Extract the design type (first token, only when there are several).
        words = request.text.split()
        design_type = f"设计类型:{words[0]}\n" if len(words) > 1 else ""
        # Build the image generation prompt.
        prompt = f"{design_type}基于以下设计反馈优化UI设计: {', '.join(request.suggestions)}"

        # Call the OpenAI image edit API.
        response = await call_openai_image_api(
            image_data=request.image_data,
            prompt=prompt,
            request_model_id=request.request_model_id,
            openai_key=request.openai_key
        )

        if user_key:
            # Deduct one credit for the successful generation.
            # NOTE(review): this also charges callers who supplied their own
            # OpenAI key together with a user key -- confirm that is intended.
            user_keys_collection = client.user_info['user_keys']
            user_keys_collection.update_one(
                {"user_key": request.user_key},
                {"$inc": {"credit": -1}}
            )
        return JSONResponse(response)
    except HTTPException:
        raise
    except Exception as e:
        # The traceback is logged server-side only; returning it to the
        # client would expose internal details.
        logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
        raise HTTPException(status_code=500, detail=f'Error: {e}')
async def optimize_text(request: TextOptimizationRequest):
    """Ask the model to rewrite a draft reply in a more polite, professional tone.

    Args:
        request: Original feedback, the user's draft reply and the model id.

    Returns:
        JSONResponse wrapping the value returned by call_openai_api.

    Raises:
        HTTPException: the exception raised by call_openai_api is passed
            through unchanged; any other failure becomes a 500.
    """
    try:
        prompt_text = f"原始反馈内容:{request.original_feedback}\n\n我想回复:{request.user_input}\n\n请优化我的回复内容,使其更加礼貌、专业,同时保持原始意思,增加一些共情和专业术语。"
        response = await call_openai_api(
            system_prompt="你是一个专业的文案优化助手,擅长将简单直接的反馈转换为礼貌、专业且保持原意的表达方式。",
            user_content=[{
                "type": "input_text",
                "text": prompt_text
            }],
            request_model_id=request.request_model_id
        )
        return JSONResponse(response)
    except HTTPException:
        # Preserve the status code chosen further down the call chain.
        raise
    except Exception as e:
        # The traceback is logged server-side only; returning it to the
        # client would expose internal details.
        logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
        raise HTTPException(status_code=500, detail=f'Error: {e}')
async def analyze_case_study(request: CaseStudyRequest):
    """Send the user's input to the model with a case-study system prompt.

    Args:
        request: The user input and the model id.

    Returns:
        JSONResponse wrapping the value returned by call_openai_api.

    Raises:
        HTTPException: the exception raised by call_openai_api is passed
            through unchanged; any other failure becomes a 500.
    """
    try:
        response = await call_openai_api(
            system_prompt="你是一个专业的案例分析助手,你需要根据用户需求进行案例分析。",
            user_content=[{
                "type": "input_text",
                "text": request.user_input
            }],
            request_model_id=request.request_model_id
        )
        return JSONResponse(response)
    except HTTPException:
        # Preserve the status code chosen further down the call chain.
        raise
    except Exception as e:
        # The traceback is logged server-side only; returning it to the
        # client would expose internal details.
        logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
        raise HTTPException(status_code=500, detail=f'Error: {e}')
async def search_design_examples(request: SearchRequest):
    """Search Google Custom Search for design examples and attach an image to each.

    For every hit the image comes from the result's page map when available,
    otherwise from a freshly captured screenshot, otherwise from a default
    placeholder.

    Args:
        request: Search terms and the number of results to request.

    Returns:
        JSONResponse ``{"items": [...]}`` where each item has ``title``,
        ``link``, ``snippet`` and ``image``.

    Raises:
        HTTPException: the upstream status when the search call does not
            return 200, 500 for any other failure.
    """
    default_image = "https://img.freepik.com/free-vector/gradient-ui-ux-background_23-2149052117.jpg"
    try:
        # Build the search query.
        search_query = f"{request.query} UI设计"
        # Call the Google Custom Search API.
        async with aiohttp.ClientSession() as session:
            async with session.get(
                "https://customsearch.googleapis.com/customsearch/v1",
                params={
                    "key": GOOGLE_API_KEY,
                    "q": search_query,
                    "cx": SEARCH_ENGINE_ID,
                    "num": request.num_results
                }
            ) as response:
                if response.status != 200:
                    raise HTTPException(status_code=response.status, detail="Google Search API调用失败")
                search_data = await response.json()

        if not search_data.get("items"):
            return JSONResponse({"items": []})

        # Process the search results.
        results = []
        for item in search_data["items"]:
            result = {
                # str.replace() is literal, so the former "</?b>" pattern
                # never matched; strip both tag forms explicitly.
                "title": item["title"].replace("<b>", "").replace("</b>", ""),
                "link": item["link"],
                "snippet": item.get("snippet", ""),
                "image": None
            }
            # Try to take the image URL from the page map (image first,
            # thumbnail only when no image entry exists).
            if "pagemap" in item:
                pagemap = item["pagemap"]
                for source in ("cse_image", "cse_thumbnail"):
                    if source in pagemap:
                        result["image"] = pagemap[source][0]["src"]
                        break
            # Without an image, fall back to the screenshot service.
            if not result["image"]:
                try:
                    screenshot_response = await capture_screenshot(ScreenshotRequest(
                        url=result["link"],
                        width=1024,
                        height=768,
                        format="png"
                    ))
                    # capture_screenshot returns a JSONResponse, so decode its
                    # body; the former dict-only check never matched.
                    if isinstance(screenshot_response, JSONResponse):
                        screenshot_response = json.loads(screenshot_response.body)
                    if isinstance(screenshot_response, dict) and "imageUrl" in screenshot_response:
                        result["image"] = screenshot_response["imageUrl"]
                except Exception as e:
                    print(f"获取截图失败: {str(e)}")
                # Use the default image whenever no screenshot URL was obtained.
                if not result["image"]:
                    result["image"] = default_image
            results.append(result)
        return JSONResponse({"items": results})
    except HTTPException:
        # Keep the upstream status instead of flattening it to 500.
        raise
    except Exception as e:
        # The traceback is logged server-side only; returning it to the
        # client would expose internal details.
        logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
        raise HTTPException(status_code=500, detail=f'Error: {e}')
async def call_openai_api(system_prompt: str, user_content: List[Dict], request_model_id='gpt-4.1-nano'):
    """POST a system + user message pair to the OpenAI Responses API.

    Args:
        system_prompt: Text of the system message.
        user_content: Content parts of the user message.
        request_model_id: Model identifier sent as ``model``.

    Returns:
        The decoded JSON body of the API reply.

    Raises:
        HTTPException: carrying the upstream status code and error body when
            the API does not answer with 200.
    """
    headers = {
        "Authorization": f"Bearer {OPENAI_API_KEY}",
        "Content-Type": "application/json"
    }
    data = {
        "model": request_model_id,
        "input": [
            {
                "role": "system",
                "content": [{"type": "input_text", "text": system_prompt}]
            },
            {
                "role": "user",
                "content": user_content
            }
        ]
    }
    async with aiohttp.ClientSession() as session:
        async with session.post("https://api.openai.com/v1/responses", headers=headers, json=data) as response:
            if response.status != 200:
                try:
                    resp = await response.json()
                except (aiohttp.ContentTypeError, ValueError):
                    # Error pages are not always JSON; fall back to the raw
                    # text so the real upstream status is still reported.
                    resp = await response.text()
                logging.error(f'response: {resp}')
                raise HTTPException(status_code=response.status, detail=f"OpenAI API调用失败, response: {resp}")
            return await response.json()
async def call_openai_image_api(image_data: str, prompt: str, request_model_id='gpt-image-1', openai_key=''):
    """Edit an image with the OpenAI image edit API.

    Args:
        image_data: Base64 image, optionally carrying a "...base64," prefix.
        prompt: Instruction text for the edit.
        request_model_id: Image model identifier.
        openai_key: Caller-supplied API key; OPENAI_API_IMAGE_EDIT_KEY is
            used when it is empty.

    Returns:
        ``{"data": [{"url": ..., "b64_json": ...}]}`` where ``url`` is a PNG
        data URL built from the base64 result (or the remote URL when the API
        returned no base64 data).

    Raises:
        HTTPException: 500 for any failure while preparing the image or
            calling the API.
    """
    # Local imports keep this block self-contained.
    import asyncio
    import functools

    try:
        # Strip the data-URL prefix, if present, to get the bare base64 data.
        if image_data and 'base64,' in image_data:
            image_data = image_data.split('base64,', 1)[1]
        logging.log(logging.DEBUG, f"Processing image data (first 100 chars): {image_data[:100]}")

        # Turn the base64 data into a file-like object.
        image_bytes = base64.b64decode(image_data)
        image_file = BytesIO(image_bytes)

        is_dalle2 = request_model_id == 'dall-e-2'
        if is_dalle2:
            # For dall-e-2 the image is resized to 800x800; LANCZOS
            # resampling is used for better quality.
            img = Image.open(image_file)
            img = img.resize((800, 800), Image.Resampling.LANCZOS)
            image_file = BytesIO()
            img.save(image_file, format='PNG')
            # Rewind so the upload starts at the beginning of the buffer.
            image_file.seek(0)
        image_file.name = "original-design.png"  # file name sent with the upload

        # Named openai_client so it does not shadow the module-level Mongo client.
        openai_client = OpenAI(api_key=openai_key if openai_key else OPENAI_API_IMAGE_EDIT_KEY)

        edit_kwargs = {
            "model": request_model_id,
            "image": image_file,
            "prompt": prompt,
        }
        if is_dalle2:
            # dall-e-2 answers with a URL unless base64 is requested.
            edit_kwargs["response_format"] = "b64_json"

        # The SDK call is synchronous; run it in a worker thread so the
        # event loop keeps serving other requests while the image renders.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, functools.partial(openai_client.images.edit, **edit_kwargs)
        )

        # Extract the generated image.
        if not response.data:
            raise ValueError("No image data returned from API")
        image_result = response.data[0]
        b64_payload = image_result.b64_json
        if b64_payload:
            image_url = f"data:image/png;base64,{b64_payload}"
        else:
            # Avoid emitting "data:image/png;base64,None".
            image_url = getattr(image_result, "url", None)
            if not image_url:
                raise ValueError("No image data returned from API")

        return {
            "data": [{
                "url": image_url,
                "b64_json": b64_payload
            }]
        }
    except Exception as e:
        logging.error(f'Error in call_openai_image_api: {e}, traceback: {traceback.format_exc()}')
        raise HTTPException(
            status_code=500,
            detail=f'Error processing image: {str(e)}'
        )
if __name__ == "__main__":
    import uvicorn

    # Serve on all interfaces; uvicorn's concurrency limit mirrors MAX_CONNECTIONS.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        "limit_concurrency": MAX_CONNECTIONS,
    }
    uvicorn.run(app, **server_options)