marnikitta commited on
Commit
4d4c3ca
·
verified ·
1 Parent(s): 5387e2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -1,12 +1,14 @@
1
  import gradio as gr
2
  from torch.nn import functional as F
 
3
 
4
  from model_loader import ModelType, type_to_transforms, type_to_loaded_model
5
 
6
  def get_y(model_type, model, image):
7
- if model_type == ModelType.SYNTHETIC_DETECTOR_V2:
8
- return model.forward(image.unsqueeze(0).to("cpu"))
9
- return model.forward(image[None, ...])
 
10
 
11
  def predict(raw_image, model_name):
12
  if model_name not in ModelType.get_list():
@@ -17,10 +19,11 @@ def predict(raw_image, model_name):
17
  tfm = type_to_transforms[model_type]
18
  image = tfm(raw_image)
19
  y = get_y(model_type, model, image)
20
- y_1 = F.softmax(y, dim=1)[:, 1].cpu().detach().numpy()
21
- y_2 = F.softmax(y, dim=1)[:, 0].cpu().detach().numpy()
22
- return {'created by AI': y_1.tolist(),
23
- 'created by human': y_2.tolist()}
 
24
 
25
  general_examples = [
26
  ["images/general/img_1.jpg"],
 
1
  import gradio as gr
2
  from torch.nn import functional as F
3
+ import torch
4
 
5
  from model_loader import ModelType, type_to_transforms, type_to_loaded_model
6
 
7
  def get_y(model_type, model, image):
8
+ with torch.no_grad():
9
+ if model_type == ModelType.SYNTHETIC_DETECTOR_V2:
10
+ return model.forward(image.unsqueeze(0).to("cpu"))
11
+ return model.forward(image[None, ...])
12
 
13
  def predict(raw_image, model_name):
14
  if model_name not in ModelType.get_list():
 
19
  tfm = type_to_transforms[model_type]
20
  image = tfm(raw_image)
21
  y = get_y(model_type, model, image)
22
+
23
+ y_2, y_1 = F.softmax(y, dim=1).cpu().numpy()[0]
24
+
25
+ return {'created by AI': float(y_2),
26
+ 'created by human': float(y_1)}
27
 
28
  general_examples = [
29
  ["images/general/img_1.jpg"],