soiz1 commited on
Commit
5dd0a3e
·
verified ·
1 Parent(s): db59c04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -61
app.py CHANGED
@@ -1,27 +1,65 @@
1
- import cv2
2
- import gradio as gr
3
  import os
 
 
4
  from PIL import Image
5
  import numpy as np
6
  import torch
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
- import gdown
11
- import matplotlib.pyplot as plt
12
  import warnings
13
  warnings.filterwarnings("ignore")
14
- import flask
15
 
16
- os.system("git clone https://github.com/xuebinqin/DIS")
17
- os.system("mv DIS/IS-Net/* .")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # project imports
20
- from data_loader_cache import normalize, im_reader, im_preprocess
21
- from models import *
22
 
23
- # モデルと設定の初期化
24
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  class GOSNormalize(object):
27
  def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
@@ -43,13 +81,16 @@ def load_image(im_path, hypar):
43
 
44
  def build_model(hypar, device):
45
  net = hypar["model"]
46
- if(hypar["model_digit"]=="half"):
 
47
  net.half()
48
  for layer in net.modules():
49
  if isinstance(layer, nn.BatchNorm2d):
50
  layer.float()
 
51
  net.to(device)
52
- if(hypar["restore_model"]!=""):
 
53
  net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
54
  net.to(device)
55
  net.eval()
@@ -57,11 +98,12 @@ def build_model(hypar, device):
57
 
58
  def predict(net, inputs_val, shapes_val, hypar, device):
59
  net.eval()
60
- if(hypar["model_digit"]=="full"):
 
61
  inputs_val = inputs_val.type(torch.FloatTensor)
62
  else:
63
  inputs_val = inputs_val.type(torch.HalfTensor)
64
-
65
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
66
  ds_val = net(inputs_val_v)[0]
67
  pred_val = ds_val[0][0,:,:,:]
@@ -75,70 +117,73 @@ def predict(net, inputs_val, shapes_val, hypar, device):
75
  if device == 'cuda': torch.cuda.empty_cache()
76
  return (pred_val.detach().cpu().numpy()*255).astype(np.uint8)
77
 
78
- # パラメータ設定
79
- hypar = {
80
- "model_path": "./saved_models",
81
- "restore_model": "isnet.pth",
82
- "interm_sup": False,
83
- "model_digit": "full",
84
- "seed": 0,
85
- "cache_size": [1024, 1024],
86
- "input_size": [1024, 1024],
87
- "crop_size": [1024, 1024],
88
- "model": ISNetDIS()
89
- }
90
-
91
- # モデルをビルド
92
- net = build_model(hypar, device)
93
-
94
- app = Flask(__name__)
95
- app.config['UPLOAD_FOLDER'] = 'uploads'
96
- os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
97
 
98
- @app.route('/api/remove-background', methods=['POST'])
99
- def remove_background():
100
- if 'file' not in request.files:
101
- return jsonify({"error": "No file provided"}), 400
102
 
103
- file = request.files['file']
104
  if file.filename == '':
105
  return jsonify({"error": "No selected file"}), 400
106
 
 
 
 
107
  # ファイルを保存
108
- filename = secure_filename(file.filename)
109
- filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
110
- file.save(filepath)
111
 
112
  try:
113
  # 画像処理
114
- image_tensor, orig_size = load_image(filepath, hypar)
115
  mask = predict(net, image_tensor, orig_size, hypar, device)
116
 
 
 
 
 
 
117
  pil_mask = Image.fromarray(mask).convert('L')
118
- im_rgb = Image.open(filepath).convert("RGB")
119
  im_rgba = im_rgb.copy()
120
  im_rgba.putalpha(pil_mask)
121
 
122
- # 結果をバイトデータとして返す
123
- output_buffer = io.BytesIO()
124
- im_rgba.save(output_buffer, format="PNG")
125
- output_buffer.seek(0)
126
 
127
- # 一時ファイルを削除
128
- os.remove(filepath)
129
-
130
- return send_file(
131
- output_buffer,
132
- mimetype='image/png',
133
- as_attachment=True,
134
- download_name='output.png'
135
- )
136
  except Exception as e:
137
  return jsonify({"error": str(e)}), 500
138
 
139
- @app.route('/api/health', methods=['GET'])
140
- def health_check():
141
- return jsonify({"status": "healthy"}), 200
 
 
 
 
142
 
143
  if __name__ == '__main__':
144
  app.run(host='0.0.0.0', port=5000, debug=True)
 
 
 
1
  import os
2
+ import cv2
3
+ import shutil
4
  from PIL import Image
5
  import numpy as np
6
  import torch
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
+ from flask import Flask, request, jsonify, render_template, send_from_directory
 
11
  import warnings
12
  warnings.filterwarnings("ignore")
 
13
 
14
+ app = Flask(__name__)
15
+
16
+ # 一時ファイル保存用ディレクトリ
17
+ UPLOAD_FOLDER = 'uploads'
18
+ RESULT_FOLDER = 'results'
19
+ EXAMPLES_FOLDER = 'examples'
20
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
21
+ os.makedirs(RESULT_FOLDER, exist_ok=True)
22
+ os.makedirs(EXAMPLES_FOLDER, exist_ok=True)
23
+
24
+ # モデル関連のインポートと初期化
25
+ def initialize_model():
26
+ # Clean up previous installations
27
+ if os.path.exists("DIS"):
28
+ shutil.rmtree("DIS")
29
+ if os.path.exists("saved_models"):
30
+ shutil.rmtree("saved_models")
31
+
32
+ # Clone repository and setup model
33
+ os.system("git clone https://github.com/xuebinqin/DIS")
34
+ os.system("mv DIS/IS-Net/* .")
35
+
36
+ # Import after setup
37
+ from data_loader_cache import normalize, im_reader, im_preprocess
38
+ from models import ISNetDIS
39
 
40
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
41
 
42
+ # Setup model directories
43
+ if not os.path.exists("saved_models"):
44
+ os.mkdir("saved_models")
45
+ os.system("mv isnet.pth saved_models/")
46
+
47
+ # Set Parameters
48
+ hypar = {
49
+ "model_path": "./saved_models",
50
+ "restore_model": "isnet.pth",
51
+ "interm_sup": False,
52
+ "model_digit": "full",
53
+ "seed": 0,
54
+ "cache_size": [1024, 1024],
55
+ "input_size": [1024, 1024],
56
+ "crop_size": [1024, 1024],
57
+ "model": ISNetDIS()
58
+ }
59
+
60
+ # Build Model
61
+ net = build_model(hypar, device)
62
+ return net, hypar, device
63
 
64
  class GOSNormalize(object):
65
  def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
 
81
 
82
  def build_model(hypar, device):
83
  net = hypar["model"]
84
+
85
+ if hypar["model_digit"] == "half":
86
  net.half()
87
  for layer in net.modules():
88
  if isinstance(layer, nn.BatchNorm2d):
89
  layer.float()
90
+
91
  net.to(device)
92
+
93
+ if hypar["restore_model"] != "":
94
  net.load_state_dict(torch.load(hypar["model_path"]+"/"+hypar["restore_model"], map_location=device))
95
  net.to(device)
96
  net.eval()
 
98
 
99
  def predict(net, inputs_val, shapes_val, hypar, device):
100
  net.eval()
101
+
102
+ if hypar["model_digit"] == "full":
103
  inputs_val = inputs_val.type(torch.FloatTensor)
104
  else:
105
  inputs_val = inputs_val.type(torch.HalfTensor)
106
+
107
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
108
  ds_val = net(inputs_val_v)[0]
109
  pred_val = ds_val[0][0,:,:,:]
 
117
  if device == 'cuda': torch.cuda.empty_cache()
118
  return (pred_val.detach().cpu().numpy()*255).astype(np.uint8)
119
 
120
+ @app.route('/')
121
+ def index():
122
+ return render_template('index.html')
123
+
124
+ @app.route('/examples/<filename>')
125
+ def serve_example(filename):
126
+ # サンプル画像がなければダウンロード
127
+ example_path = os.path.join(EXAMPLES_FOLDER, filename)
128
+ if not os.path.exists(example_path):
129
+ if filename == 'robot.png':
130
+ os.system(f"wget https://raw.githubusercontent.com/xuebinqin/DIS/main/IS-Net/robot.png -O {example_path}")
131
+ elif filename == 'ship.png':
132
+ os.system(f"wget https://raw.githubusercontent.com/xuebinqin/DIS/main/IS-Net/ship.png -O {example_path}")
133
+
134
+ return send_from_directory(EXAMPLES_FOLDER, filename)
 
 
 
 
135
 
136
+ @app.route('/api/process', methods=['POST'])
137
+ def process_image():
138
+ if 'image' not in request.files:
139
+ return jsonify({"error": "No image provided"}), 400
140
 
141
+ file = request.files['image']
142
  if file.filename == '':
143
  return jsonify({"error": "No selected file"}), 400
144
 
145
+ # 毎回モデルを初期化
146
+ net, hypar, device = initialize_model()
147
+
148
  # ファイルを保存
149
+ upload_path = os.path.join(UPLOAD_FOLDER, file.filename)
150
+ file.save(upload_path)
 
151
 
152
  try:
153
  # 画像処理
154
+ image_tensor, orig_size = load_image(upload_path, hypar)
155
  mask = predict(net, image_tensor, orig_size, hypar, device)
156
 
157
+ # 結果を保存
158
+ original_filename = os.path.splitext(file.filename)[0]
159
+ result_rgba_path = os.path.join(RESULT_FOLDER, f"{original_filename}_rgba.png")
160
+ result_mask_path = os.path.join(RESULT_FOLDER, f"{original_filename}_mask.png")
161
+
162
  pil_mask = Image.fromarray(mask).convert('L')
163
+ im_rgb = Image.open(upload_path).convert("RGB")
164
  im_rgba = im_rgb.copy()
165
  im_rgba.putalpha(pil_mask)
166
 
167
+ im_rgba.save(result_rgba_path)
168
+ pil_mask.save(result_mask_path)
 
 
169
 
170
+ # 結果のURLを返す
171
+ return jsonify({
172
+ "original": f"/{UPLOAD_FOLDER}/{file.filename}",
173
+ "rgba": f"/{RESULT_FOLDER}/{original_filename}_rgba.png",
174
+ "mask": f"/{RESULT_FOLDER}/{original_filename}_mask.png",
175
+ "filename": file.filename
176
+ })
 
 
177
  except Exception as e:
178
  return jsonify({"error": str(e)}), 500
179
 
180
+ @app.route(f'/{UPLOAD_FOLDER}/<filename>')
181
+ def serve_upload(filename):
182
+ return send_from_directory(UPLOAD_FOLDER, filename)
183
+
184
+ @app.route(f'/{RESULT_FOLDER}/<filename>')
185
+ def serve_result(filename):
186
+ return send_from_directory(RESULT_FOLDER, filename)
187
 
188
  if __name__ == '__main__':
189
  app.run(host='0.0.0.0', port=5000, debug=True)