Junhui Ji commited on
Commit
6e95ff1
·
1 Parent(s): 876d086

update whitelist, dalle

Browse files
Files changed (2) hide show
  1. main.py +101 -7
  2. static/script.js +3 -3
main.py CHANGED
@@ -1,4 +1,4 @@
1
- from fastapi import FastAPI, HTTPException
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
4
  from fastapi.responses import JSONResponse
@@ -16,10 +16,41 @@ import aiohttp
16
  import traceback
17
  import requests
18
  from openai import OpenAI
 
 
 
 
19
 
20
 
21
  app = FastAPI()
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # 确保缓存目录存在
24
  CACHE_DIR = "cache"
25
  os.makedirs(CACHE_DIR, exist_ok=True)
@@ -138,7 +169,11 @@ async def capture_screenshot(request: ScreenshotRequest):
138
 
139
  @app.get("/health")
140
  async def health_check():
141
- return {"status": "ok"}
 
 
 
 
142
 
143
 
144
  @app.get("/")
@@ -169,10 +204,43 @@ async def analyze_feedback(request: AnalysisRequest):
169
  raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
170
 
171
  @app.post("/api/optimize-design")
172
- async def optimize_design(request: OptimizationRequest):
173
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  # 构建图像生成提示词
175
- prompt = f"基于以下设计反馈优化UI设计: {', '.join(request.suggestions)}"
176
 
177
  # 处理图片数据
178
  image_data = request.image_data
@@ -185,9 +253,15 @@ async def optimize_design(request: OptimizationRequest):
185
  )
186
 
187
  return JSONResponse(response)
 
 
188
  except Exception as e:
189
  logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
190
  raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
 
 
 
 
191
 
192
  @app.post("/api/optimize-text")
193
  async def optimize_text(request: TextOptimizationRequest):
@@ -324,11 +398,25 @@ async def call_openai_image_api(image_data: str, prompt: str, request_model_id='
324
  if image_data and 'base64,' in image_data:
325
  image_data = image_data.split('base64,')[1]
326
 
327
- logging.log(logging.INFO, f"Processing image data (first 100 chars): {image_data[:100]}")
328
 
329
  # 将base64图片数据转换为文件对象
330
  image_bytes = base64.b64decode(image_data)
331
  image_file = BytesIO(image_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  image_file.name = "original-design.png" # 设置文件名,与JS代码一致
333
 
334
  # 创建OpenAI客户端
@@ -338,7 +426,7 @@ async def call_openai_image_api(image_data: str, prompt: str, request_model_id='
338
  response = client.images.edit(
339
  model=request_model_id,
340
  image=image_file,
341
- prompt=prompt # 明确要求返回base64格式
342
  )
343
 
344
  # 获取生成的图片数据
@@ -364,4 +452,10 @@ async def call_openai_image_api(image_data: str, prompt: str, request_model_id='
364
 
365
  if __name__ == "__main__":
366
  import uvicorn
367
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Request
2
  from fastapi.staticfiles import StaticFiles
3
  from fastapi.responses import FileResponse
4
  from fastapi.responses import JSONResponse
 
16
  import traceback
17
  import requests
18
  from openai import OpenAI
19
+ from starlette.middleware.base import BaseHTTPMiddleware
20
+ import uvicorn
21
+ from collections import defaultdict
22
+ from PIL import Image
23
 
24
 
25
  app = FastAPI()
26
 
27
+ # 设置最大连接数
28
+ MAX_CONNECTIONS = 10
29
+ current_connections = 0
30
+
31
+ # 设置优化设计接口的访问限制
32
+ optimize_design_requests = defaultdict(int) # 记录每个IP的请求次数
33
+ optimize_design_timestamps = defaultdict(float) # 记录每个IP的首次请求时间
34
+ white_list = eval(os.getenv("WHITELIST"))
35
+
36
+ class ConnectionLimitMiddleware(BaseHTTPMiddleware):
37
+ async def dispatch(self, request: Request, call_next):
38
+ global current_connections
39
+ if current_connections >= MAX_CONNECTIONS:
40
+ return JSONResponse(
41
+ status_code=503,
42
+ content={"detail": "已超过最大链接数,请稍后重试"}
43
+ )
44
+ current_connections += 1
45
+ try:
46
+ response = await call_next(request)
47
+ return response
48
+ finally:
49
+ current_connections -= 1
50
+
51
+ # 添加中间件
52
+ app.add_middleware(ConnectionLimitMiddleware)
53
+
54
  # 确保缓存目录存在
55
  CACHE_DIR = "cache"
56
  os.makedirs(CACHE_DIR, exist_ok=True)
 
169
 
170
  @app.get("/health")
171
  async def health_check():
172
+ return {
173
+ "status": "ok",
174
+ "current_connections": current_connections,
175
+ "max_connections": MAX_CONNECTIONS
176
+ }
177
 
178
 
179
  @app.get("/")
 
204
  raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
205
 
206
  @app.post("/api/optimize-design")
207
+ async def optimize_design(request: OptimizationRequest, client_ip: str = None):
208
  try:
209
+ # 获取客户端IP(如果未提供,使用默认值)
210
+ if client_ip is None or client_ip not in white_list:
211
+ raise HTTPException(
212
+ status_code=503,
213
+ detail="当前用户无生图权限,请联系@王月(Phoebe)添加白名单后重试。"
214
+ )
215
+
216
+ user_rate_limit = white_list[client_ip]
217
+
218
+ current_time = time.time()
219
+
220
+ # 检查是否需要重置计数器(超过24小时)
221
+ if current_time - optimize_design_timestamps[client_ip] > 3600*24:
222
+ optimize_design_requests[client_ip] = 0
223
+ optimize_design_timestamps[client_ip] = current_time
224
+
225
+ # 如果是首次请求,记录时间戳
226
+ if optimize_design_requests[client_ip] == 0:
227
+ optimize_design_timestamps[client_ip] = current_time
228
+
229
+ # 检查是否超过限制
230
+ if optimize_design_requests[client_ip] >= user_rate_limit:
231
+ raise HTTPException(
232
+ status_code=503,
233
+ detail="用户当日改图接口访问已达上限,请24小时后重试"
234
+ )
235
+
236
+ # 增加请求计数
237
+ optimize_design_requests[client_ip] += 1
238
+
239
+ # 提取设计类型
240
+ design_type = f"设计类型:{request.text.split()[0]}\n" if len(request.text.split()) > 1 else ""
241
+
242
  # 构建图像生成提示词
243
+ prompt = f"{design_type}基于以下设计反馈优化UI设计: {', '.join(request.suggestions)}"
244
 
245
  # 处理图片数据
246
  image_data = request.image_data
 
253
  )
254
 
255
  return JSONResponse(response)
256
+ except HTTPException as he:
257
+ raise he
258
  except Exception as e:
259
  logging.error(f'Error: {e}, traceback: {traceback.format_exc()}')
260
  raise HTTPException(status_code=500, detail=f'Error: {e}, traceback: {traceback.format_exc()}')
261
+ finally:
262
+ # 如果发生异常,减少请求计数
263
+ if 'he' in locals() and isinstance(he, HTTPException):
264
+ optimize_design_requests[client_ip] -= 1
265
 
266
  @app.post("/api/optimize-text")
267
  async def optimize_text(request: TextOptimizationRequest):
 
398
  if image_data and 'base64,' in image_data:
399
  image_data = image_data.split('base64,')[1]
400
 
401
+ logging.log(logging.DEBUG, f"Processing image data (first 100 chars): {image_data[:100]}")
402
 
403
  # 将base64图片数据转换为文件对象
404
  image_bytes = base64.b64decode(image_data)
405
  image_file = BytesIO(image_bytes)
406
+
407
+ # 如果是dall-e-2模型,需要将图片调整为800x800
408
+ if request_model_id == 'dall-e-2':
409
+ # 打开图片
410
+ img = Image.open(image_file)
411
+ # 调整图片大小为800x800,使用LANCZOS重采样方法以获得更好的质量
412
+ img = img.resize((800, 800), Image.Resampling.LANCZOS)
413
+ # 创建新的BytesIO对象
414
+ image_file = BytesIO()
415
+ # 保存调整后的图片
416
+ img.save(image_file, format='PNG')
417
+ # 将文件指针移到开始位置
418
+ image_file.seek(0)
419
+
420
  image_file.name = "original-design.png" # 设置文件名,与JS代码一致
421
 
422
  # 创建OpenAI客户端
 
426
  response = client.images.edit(
427
  model=request_model_id,
428
  image=image_file,
429
+ prompt=prompt
430
  )
431
 
432
  # 获取生成的图片数据
 
452
 
453
  if __name__ == "__main__":
454
  import uvicorn
455
+ uvicorn.run(
456
+ app,
457
+ host="0.0.0.0",
458
+ port=7860,
459
+ limit_concurrency=MAX_CONNECTIONS,
460
+ limit_max_requests=0 # 0 means no limit on total requests
461
+ )
static/script.js CHANGED
@@ -762,8 +762,8 @@ document.addEventListener('DOMContentLoaded', function() {
762
  }
763
 
764
  // 构建提示词
765
- const prompt = `基于以下设计反馈优化UI设计: ${suggestions.join(', ')}`;
766
-
767
  // 调用后端API
768
  const response = await fetch(`${BASE_URL}api/optimize-design`, {
769
  method: 'POST',
@@ -771,7 +771,7 @@ document.addEventListener('DOMContentLoaded', function() {
771
  'Content-Type': 'application/json'
772
  },
773
  body: JSON.stringify({
774
- text: analysisResult,
775
  image_data: uploadedImage,
776
  suggestions: suggestions,
777
  request_model_id: 'dall-e-2'
 
762
  }
763
 
764
  // 构建提示词
765
+ const uploadedText = sessionStorage.getItem('uploadedText') || '';
766
+
767
  // 调用后端API
768
  const response = await fetch(`${BASE_URL}api/optimize-design`, {
769
  method: 'POST',
 
771
  'Content-Type': 'application/json'
772
  },
773
  body: JSON.stringify({
774
+ text: uploadedText,
775
  image_data: uploadedImage,
776
  suggestions: suggestions,
777
  request_model_id: 'dall-e-2'