dhanushreddy29 committed on
Commit 5727d71 · 1 Parent(s): f75b8b4

Update app.py

Files changed (1)
  app.py +117 -117
app.py CHANGED
@@ -9,130 +9,126 @@ from torch.autograd import Variable
 from PIL import Image
 
 
-def build_model(hypar, device):
-    net = hypar["model"]  # GOSNETINC(3,1)
-
-    # convert to half precision
-    if hypar["model_digit"] == "half":
-        net.half()
-        for layer in net.modules():
-            if isinstance(layer, nn.BatchNorm2d):
-                layer.float()
-
-    net.to(device)
-
-    if hypar["restore_model"] != "":
-        net.load_state_dict(
-            torch.load(
-                hypar["model_path"] + "/" + hypar["restore_model"],
-                map_location=device,
-            )
-        )
-        net.to(device)
-    net.eval()
-    return net
-
-
-if not os.path.exists("saved_models"):
-    os.mkdir("saved_models")
-    os.mkdir("git")
-    os.system("git clone https://github.com/xuebinqin/DIS git/xuebinqin/DIS")
-    hf_hub_download(
-        repo_id="NimaBoscarino/IS-Net_DIS-general-use",
-        filename="isnet-general-use.pth",
-        local_dir="saved_models",
-    )
-    os.system("rm -r git/xuebinqin/DIS/IS-Net/__pycache__")
-    os.system("mv git/xuebinqin/DIS/IS-Net/* .")
-
-import data_loader_cache
-import models
-
-device = "cpu"
-ISNetDIS = models.ISNetDIS
-normalize = data_loader_cache.normalize
-im_preprocess = data_loader_cache.im_preprocess
-
-# Set Parameters
-hypar = {}  # parameters for inferencing
-
-# load trained weights from this path
-hypar["model_path"] = "./saved_models"
-# name of the to-be-loaded weights
-hypar["restore_model"] = "isnet-general-use.pth"
-# indicate whether to activate intermediate feature supervision
-hypar["interm_sup"] = False
-
-# choose floating point accuracy --
-# indicates "half" or "full" accuracy of float number
-hypar["model_digit"] = "full"
-hypar["seed"] = 0
-
-# cached input spatial resolution, can be configured to a different size
-hypar["cache_size"] = [1024, 1024]
-
-# data augmentation parameters ---
-# model input spatial size, usually the same value as hypar["cache_size"], which means we don't further resize the images
-hypar["input_size"] = [1024, 1024]
-# random crop size from the input, usually set smaller than hypar["cache_size"], e.g. [920, 920], for data augmentation
-hypar["crop_size"] = [1024, 1024]
-
-hypar["model"] = ISNetDIS()
-
-# Build Model
-net = build_model(hypar, device)
-
-
-def predict(net, inputs_val, shapes_val, hypar, device):
-    """
-    Given an Image, predict the mask
-    """
-    net.eval()
-
-    if hypar["model_digit"] == "full":
-        inputs_val = inputs_val.type(torch.FloatTensor)
-    else:
-        inputs_val = inputs_val.type(torch.HalfTensor)
-
-    inputs_val_v = Variable(inputs_val, requires_grad=False).to(
-        device
-    )  # wrap inputs in Variable
-
-    ds_val = net(inputs_val_v)[0]  # list of 6 results
-
-    # B x 1 x H x W # we want the first one, which is the most accurate prediction
-    pred_val = ds_val[0][0, :, :, :]
-
-    # recover the prediction spatial size to the original image size
-    pred_val = torch.squeeze(
-        F.upsample(
-            torch.unsqueeze(pred_val, 0),
-            (shapes_val[0][0], shapes_val[0][1]),
-            mode="bilinear",
-        )
-    )
-
-    ma = torch.max(pred_val)
-    mi = torch.min(pred_val)
-    pred_val = (pred_val - mi) / (ma - mi)  # max = 1
-
-    if device == "cpu":
-        torch.cpu.empty_cache()
-    # it is the mask we need
-    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
-
-
-def load_image(im_pil, hypar):
-    im = np.array(im_pil)
-    im, im_shp = im_preprocess(im, hypar["cache_size"])
-    im = torch.divide(im, 255.0)
-    shape = torch.from_numpy(np.array(im_shp))
-    # make a batch of image, shape
-    aa = normalize(im, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
-    return aa.unsqueeze(0), shape.unsqueeze(0)
-
-
-def remove_background(image):
+def removeBackground(image):
+    if not os.path.exists("saved_models"):
+        os.mkdir("saved_models")
+        os.mkdir("git")
+        os.system("git clone https://github.com/xuebinqin/DIS git/xuebinqin/DIS")
+        hf_hub_download(
+            repo_id="NimaBoscarino/IS-Net_DIS-general-use",
+            filename="isnet-general-use.pth",
+            local_dir="saved_models",
+        )
+        os.system("rm -r git/xuebinqin/DIS/IS-Net/__pycache__")
+        os.system("mv git/xuebinqin/DIS/IS-Net/* .")
+
+    def build_model(hypar, device):
+        net = hypar["model"]  # GOSNETINC(3,1)
+
+        # convert to half precision
+        if hypar["model_digit"] == "half":
+            net.half()
+            for layer in net.modules():
+                if isinstance(layer, nn.BatchNorm2d):
+                    layer.float()
+
+        net.to(device)
+
+        if hypar["restore_model"] != "":
+            net.load_state_dict(
+                torch.load(
+                    hypar["model_path"] + "/" + hypar["restore_model"],
+                    map_location=device,
+                )
+            )
+            net.to(device)
+        net.eval()
+        return net
+
+    import data_loader_cache
+    import models
+
+    device = "cpu"
+    ISNetDIS = models.ISNetDIS
+    normalize = data_loader_cache.normalize
+    im_preprocess = data_loader_cache.im_preprocess
+
+    # Set Parameters
+    hypar = {}  # parameters for inferencing
+
+    # load trained weights from this path
+    hypar["model_path"] = "./saved_models"
+    # name of the to-be-loaded weights
+    hypar["restore_model"] = "isnet-general-use.pth"
+    # indicate whether to activate intermediate feature supervision
+    hypar["interm_sup"] = False
+
+    # choose floating point accuracy --
+    # indicates "half" or "full" accuracy of float number
+    hypar["model_digit"] = "full"
+    hypar["seed"] = 0
+
+    # cached input spatial resolution, can be configured to a different size
+    hypar["cache_size"] = [1024, 1024]
+
+    # data augmentation parameters ---
+    # model input spatial size, usually the same value as hypar["cache_size"], which means we don't further resize the images
+    hypar["input_size"] = [1024, 1024]
+    # random crop size from the input, usually set smaller than hypar["cache_size"], e.g. [920, 920], for data augmentation
+    hypar["crop_size"] = [1024, 1024]
+
+    hypar["model"] = ISNetDIS()
+
+    # Build Model
+    net = build_model(hypar, device)
+
+    def predict(net, inputs_val, shapes_val, hypar, device):
+        """
+        Given an Image, predict the mask
+        """
+        net.eval()
+
+        if hypar["model_digit"] == "full":
+            inputs_val = inputs_val.type(torch.FloatTensor)
+        else:
+            inputs_val = inputs_val.type(torch.HalfTensor)
+
+        inputs_val_v = Variable(inputs_val, requires_grad=False).to(
+            device
+        )  # wrap inputs in Variable
+
+        ds_val = net(inputs_val_v)[0]  # list of 6 results
+
+        # B x 1 x H x W # we want the first one, which is the most accurate prediction
+        pred_val = ds_val[0][0, :, :, :]
+
+        # recover the prediction spatial size to the original image size
+        pred_val = torch.squeeze(
+            F.upsample(
+                torch.unsqueeze(pred_val, 0),
+                (shapes_val[0][0], shapes_val[0][1]),
+                mode="bilinear",
+            )
+        )
+
+        ma = torch.max(pred_val)
+        mi = torch.min(pred_val)
+        pred_val = (pred_val - mi) / (ma - mi)  # max = 1
+
+        if device == "cuda":
+            torch.cuda.empty_cache()
+        # it is the mask we need
+        return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
+
+    def load_image(im_pil, hypar):
+        im = np.array(im_pil)
+        im, im_shp = im_preprocess(im, hypar["cache_size"])
+        im = torch.divide(im, 255.0)
+        shape = torch.from_numpy(np.array(im_shp))
+        # make a batch of image, shape
+        aa = normalize(im, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+        return aa.unsqueeze(0), shape.unsqueeze(0)
+
     image_tensor, orig_size = load_image(image, hypar)
     mask = predict(net, image_tensor, orig_size, hypar, "cpu")
 
@@ -141,8 +137,12 @@ def remove_background(image):
 
     cropped = im_rgb.copy()
    cropped.putalpha(mask)
+
     return cropped
 
+def remove_background(image):
+    return removeBackground(image)
+
 
 inputs = gr.inputs.Image()
 outputs = gr.outputs.Image(type="pil")
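The diff ends at the inputs/outputs component definitions, so the Gradio wiring that presumably follows in app.py is not shown here. A minimal sketch of how the refactored entry point would typically be hooked up, assuming the usual gr.Interface(...).launch() pattern and the same legacy Gradio component API the diff already uses (the title string and the __main__ guard are illustrative assumptions, not part of the commit):

# Assumed tail of app.py: remove_background() is the module-level wrapper
# added in this commit; it delegates to removeBackground(), which downloads
# the weights, builds IS-Net, and runs inference on each call.
import gradio as gr

inputs = gr.inputs.Image()              # as defined in the diff (legacy input API)
outputs = gr.outputs.Image(type="pil")  # as defined in the diff (legacy output API)

demo = gr.Interface(
    fn=remove_background,               # wrapper defined earlier in app.py
    inputs=inputs,
    outputs=outputs,
    title="DIS Background Removal",     # illustrative title, not from the commit
)

if __name__ == "__main__":
    demo.launch()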