Spaces:
Running
Running
Junhui Ji
commited on
Commit
·
6e95ff1
1
Parent(s):
876d086
update whitelist, dalle
Browse files- main.py +101 -7
- static/script.js +3 -3
main.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from fastapi import FastAPI, HTTPException
|
2 |
from fastapi.staticfiles import StaticFiles
|
3 |
from fastapi.responses import FileResponse
|
4 |
from fastapi.responses import JSONResponse
|
@@ -16,10 +16,41 @@ import aiohttp
|
|
16 |
import traceback
|
17 |
import requests
|
18 |
from openai import OpenAI
|
|
|
|
|
|
|
|
|
19 |
|
20 |
|
21 |
app = FastAPI()
|
22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
# 确保缓存目录存在
|
24 |
CACHE_DIR = "cache"
|
25 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
@@ -138,7 +169,11 @@ async def capture_screenshot(request: ScreenshotRequest):
|
|
138 |
|
139 |
@app.get("/health")
|
140 |
async def health_check():
|
141 |
-
return {
|
|
|
|
|
|
|
|
|
142 |
|
143 |
|
144 |
@app.get("/")
|
@@ -169,10 +204,43 @@ async def analyze_feedback(request: AnalysisRequest):
|
|
169 |
raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
|
170 |
|
171 |
@app.post("/api/optimize-design")
|
172 |
-
async def optimize_design(request: OptimizationRequest):
|
173 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
174 |
# 构建图像生成提示词
|
175 |
-
prompt = f"基于以下设计反馈优化UI设计: {', '.join(request.suggestions)}"
|
176 |
|
177 |
# 处理图片数据
|
178 |
image_data = request.image_data
|
@@ -185,9 +253,15 @@ async def optimize_design(request: OptimizationRequest):
|
|
185 |
)
|
186 |
|
187 |
return JSONResponse(response)
|
|
|
|
|
188 |
except Exception as e:
|
189 |
logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
|
190 |
raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
|
|
|
|
|
|
|
|
|
191 |
|
192 |
@app.post("/api/optimize-text")
|
193 |
async def optimize_text(request: TextOptimizationRequest):
|
@@ -324,11 +398,25 @@ async def call_openai_image_api(image_data: str, prompt: str, request_model_id='
|
|
324 |
if image_data and 'base64,' in image_data:
|
325 |
image_data = image_data.split('base64,')[1]
|
326 |
|
327 |
-
logging.log(logging.
|
328 |
|
329 |
# 将base64图片数据转换为文件对象
|
330 |
image_bytes = base64.b64decode(image_data)
|
331 |
image_file = BytesIO(image_bytes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
332 |
image_file.name = "original-design.png" # 设置文件名,与JS代码一致
|
333 |
|
334 |
# 创建OpenAI客户端
|
@@ -338,7 +426,7 @@ async def call_openai_image_api(image_data: str, prompt: str, request_model_id='
|
|
338 |
response = client.images.edit(
|
339 |
model=request_model_id,
|
340 |
image=image_file,
|
341 |
-
prompt=prompt
|
342 |
)
|
343 |
|
344 |
# 获取生成的图片数据
|
@@ -364,4 +452,10 @@ async def call_openai_image_api(image_data: str, prompt: str, request_model_id='
|
|
364 |
|
365 |
if __name__ == "__main__":
|
366 |
import uvicorn
|
367 |
-
uvicorn.run(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, HTTPException, Request
|
2 |
from fastapi.staticfiles import StaticFiles
|
3 |
from fastapi.responses import FileResponse
|
4 |
from fastapi.responses import JSONResponse
|
|
|
16 |
import traceback
|
17 |
import requests
|
18 |
from openai import OpenAI
|
19 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
20 |
+
import uvicorn
|
21 |
+
from collections import defaultdict
|
22 |
+
from PIL import Image
|
23 |
|
24 |
|
25 |
app = FastAPI()
|
26 |
|
27 |
+
# 设置最大连接数
|
28 |
+
MAX_CONNECTIONS = 10
|
29 |
+
current_connections = 0
|
30 |
+
|
31 |
+
# 设置优化设计接口的访问限制
|
32 |
+
optimize_design_requests = defaultdict(int) # 记录每个IP的请求次数
|
33 |
+
optimize_design_timestamps = defaultdict(float) # 记录每个IP的首次请求时间
|
34 |
+
white_list = eval(os.getenv("WHITELIST"))
|
35 |
+
|
36 |
+
class ConnectionLimitMiddleware(BaseHTTPMiddleware):
|
37 |
+
async def dispatch(self, request: Request, call_next):
|
38 |
+
global current_connections
|
39 |
+
if current_connections >= MAX_CONNECTIONS:
|
40 |
+
return JSONResponse(
|
41 |
+
status_code=503,
|
42 |
+
content={"detail": "已超过最大链接数,请稍后重试"}
|
43 |
+
)
|
44 |
+
current_connections += 1
|
45 |
+
try:
|
46 |
+
response = await call_next(request)
|
47 |
+
return response
|
48 |
+
finally:
|
49 |
+
current_connections -= 1
|
50 |
+
|
51 |
+
# 添加中间件
|
52 |
+
app.add_middleware(ConnectionLimitMiddleware)
|
53 |
+
|
54 |
# 确保缓存目录存在
|
55 |
CACHE_DIR = "cache"
|
56 |
os.makedirs(CACHE_DIR, exist_ok=True)
|
|
|
169 |
|
170 |
@app.get("/health")
|
171 |
async def health_check():
|
172 |
+
return {
|
173 |
+
"status": "ok",
|
174 |
+
"current_connections": current_connections,
|
175 |
+
"max_connections": MAX_CONNECTIONS
|
176 |
+
}
|
177 |
|
178 |
|
179 |
@app.get("/")
|
|
|
204 |
raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
|
205 |
|
206 |
@app.post("/api/optimize-design")
|
207 |
+
async def optimize_design(request: OptimizationRequest, client_ip: str = None):
|
208 |
try:
|
209 |
+
# 获取客户端IP(如果未提供,使用默认值)
|
210 |
+
if client_ip is None or client_ip not in white_list:
|
211 |
+
raise HTTPException(
|
212 |
+
status_code=503,
|
213 |
+
detail="当前用户无生图权限,请联系@王月(Phoebe)添加白名单后重试。"
|
214 |
+
)
|
215 |
+
|
216 |
+
user_rate_limit = white_list[client_ip]
|
217 |
+
|
218 |
+
current_time = time.time()
|
219 |
+
|
220 |
+
# 检查是否需要重置计数器(超过24小时)
|
221 |
+
if current_time - optimize_design_timestamps[client_ip] > 3600*24:
|
222 |
+
optimize_design_requests[client_ip] = 0
|
223 |
+
optimize_design_timestamps[client_ip] = current_time
|
224 |
+
|
225 |
+
# 如果是首次请求,记录时间戳
|
226 |
+
if optimize_design_requests[client_ip] == 0:
|
227 |
+
optimize_design_timestamps[client_ip] = current_time
|
228 |
+
|
229 |
+
# 检查是否超过限制
|
230 |
+
if optimize_design_requests[client_ip] >= user_rate_limit:
|
231 |
+
raise HTTPException(
|
232 |
+
status_code=503,
|
233 |
+
detail="用户当日改图接口访问已达上限,请24小时后重试"
|
234 |
+
)
|
235 |
+
|
236 |
+
# 增加请求计数
|
237 |
+
optimize_design_requests[client_ip] += 1
|
238 |
+
|
239 |
+
# 提取设计类型
|
240 |
+
design_type = f"设计类型:{request.text.split()[0]}\n" if len(request.text.split()) > 1 else ""
|
241 |
+
|
242 |
# 构建图像生成提示词
|
243 |
+
prompt = f"{design_type}基于以下设计反馈优化UI设计: {', '.join(request.suggestions)}"
|
244 |
|
245 |
# 处理图片数据
|
246 |
image_data = request.image_data
|
|
|
253 |
)
|
254 |
|
255 |
return JSONResponse(response)
|
256 |
+
except HTTPException as he:
|
257 |
+
raise he
|
258 |
except Exception as e:
|
259 |
logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
|
260 |
raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
|
261 |
+
finally:
|
262 |
+
# 如果发生异常,减少请求计数
|
263 |
+
if 'he' in locals() and isinstance(he, HTTPException):
|
264 |
+
optimize_design_requests[client_ip] -= 1
|
265 |
|
266 |
@app.post("/api/optimize-text")
|
267 |
async def optimize_text(request: TextOptimizationRequest):
|
|
|
398 |
if image_data and 'base64,' in image_data:
|
399 |
image_data = image_data.split('base64,')[1]
|
400 |
|
401 |
+
logging.log(logging.DEBUG, f"Processing image data (first 100 chars): {image_data[:100]}")
|
402 |
|
403 |
# 将base64图片数据转换为文件对象
|
404 |
image_bytes = base64.b64decode(image_data)
|
405 |
image_file = BytesIO(image_bytes)
|
406 |
+
|
407 |
+
# 如果是dall-e-2模型,需要将图片调整为800x800
|
408 |
+
if request_model_id == 'dall-e-2':
|
409 |
+
# 打开图片
|
410 |
+
img = Image.open(image_file)
|
411 |
+
# 调整图片大小为800x800,使用LANCZOS重采样方法以获得更好的质量
|
412 |
+
img = img.resize((800, 800), Image.Resampling.LANCZOS)
|
413 |
+
# 创建新的BytesIO对象
|
414 |
+
image_file = BytesIO()
|
415 |
+
# 保存调整后的图片
|
416 |
+
img.save(image_file, format='PNG')
|
417 |
+
# 将文件指针移到开始位置
|
418 |
+
image_file.seek(0)
|
419 |
+
|
420 |
image_file.name = "original-design.png" # 设置文件名,与JS代码一致
|
421 |
|
422 |
# 创建OpenAI客户端
|
|
|
426 |
response = client.images.edit(
|
427 |
model=request_model_id,
|
428 |
image=image_file,
|
429 |
+
prompt=prompt
|
430 |
)
|
431 |
|
432 |
# 获取生成的图片数据
|
|
|
452 |
|
453 |
if __name__ == "__main__":
|
454 |
import uvicorn
|
455 |
+
uvicorn.run(
|
456 |
+
app,
|
457 |
+
host="0.0.0.0",
|
458 |
+
port=7860,
|
459 |
+
limit_concurrency=MAX_CONNECTIONS,
|
460 |
+
limit_max_requests=0 # 0 means no limit on total requests
|
461 |
+
)
|
static/script.js
CHANGED
@@ -762,8 +762,8 @@ document.addEventListener('DOMContentLoaded', function() {
|
|
762 |
}
|
763 |
|
764 |
// 构建提示词
|
765 |
-
const
|
766 |
-
|
767 |
// 调用后端API
|
768 |
const response = await fetch(`${BASE_URL}api/optimize-design`, {
|
769 |
method: 'POST',
|
@@ -771,7 +771,7 @@ document.addEventListener('DOMContentLoaded', function() {
|
|
771 |
'Content-Type': 'application/json'
|
772 |
},
|
773 |
body: JSON.stringify({
|
774 |
-
text:
|
775 |
image_data: uploadedImage,
|
776 |
suggestions: suggestions,
|
777 |
request_model_id: 'dall-e-2'
|
|
|
762 |
}
|
763 |
|
764 |
// 构建提示词
|
765 |
+
const uploadedText = sessionStorage.getItem('uploadedText') || '';
|
766 |
+
|
767 |
// 调用后端API
|
768 |
const response = await fetch(`${BASE_URL}api/optimize-design`, {
|
769 |
method: 'POST',
|
|
|
771 |
'Content-Type': 'application/json'
|
772 |
},
|
773 |
body: JSON.stringify({
|
774 |
+
text: uploadedText,
|
775 |
image_data: uploadedImage,
|
776 |
suggestions: suggestions,
|
777 |
request_model_id: 'dall-e-2'
|