lambertxiao committed · verified
Commit ddebdd0 · 1 Parent(s): c50c7e8

Update app.py

Files changed (1)
  1. app.py +21 -33
app.py CHANGED
@@ -1,44 +1,32 @@
- # If this really is a HF Space, keep the next import;
- # otherwise comment it out and delete the decorator line below.
- import spaces  # <─ ONLY needed in a Space
  import gradio as gr
- from transformers import AutoModel
  from PIL import Image
- import torch, numpy as np

  model_name_or_path = "lyttt/VLV_captioner"
- device = "cuda" if torch.cuda.is_available() else "cpu"

- model = AutoModel.from_pretrained(
-     model_name_or_path,
-     revision="master",
-     trust_remote_code=True,
-     low_cpu_mem_usage=False
- ).to(device)
-
- def drop_incomplete_tail(text: str) -> str:
-     """Remove any unfinished sentence fragment at the end of `text`."""
-     sentences = [s.strip() for s in text.split('.') if s.strip()]
      if not text.strip().endswith('.'):
-         sentences = sentences[:-1]
-     return '. '.join(sentences) + ('.' if sentences else '')
-
- # ───────────────────────────────────────────────────────────────
- @spaces.GPU(duration=120)  # ← delete this line if **not** in a Space
- def greet(image):
-     if image.dtype != np.uint8:  # Gradio gives a float array in [0, 1]
-         image = (np.clip(image, 0, 1) * 255).astype(np.uint8)

-     image_pil = Image.fromarray(image, mode="RGB")
-
-     # The VLV-captioner accepts a list of PIL images directly.
      with torch.no_grad():
-         # Second arg is max-new-tokens (kept from the original code).
-         raw = model([image_pil], 300)
-         text = raw.generated_text[0] if hasattr(raw, "generated_text") else raw[0]

-     return drop_incomplete_tail(text)
- # ───────────────────────────────────────────────────────────────

  demo = gr.Interface(fn=greet, inputs="image", outputs="text")
- demo.launch()
 
+ import spaces
  import gradio as gr
+ from transformers import AutoModel, AutoProcessor
  from PIL import Image
+ import torch
+ import numpy as np

  model_name_or_path = "lyttt/VLV_captioner"
+ model = AutoModel.from_pretrained(model_name_or_path, revision="master", trust_remote_code=True, low_cpu_mem_usage=False)

+ def drop_incomplete_tail(text):
+     sentences = text.split('.')
+     complete_sentences = [s.strip() for s in sentences if s.strip()]
      if not text.strip().endswith('.'):
+         complete_sentences = complete_sentences[:-1]
+     return '. '.join(complete_sentences) + ('.' if complete_sentences else '')

+ @spaces.GPU(duration=120)
+ def caption_image(image):
      with torch.no_grad():
+         outputs = model([image], 300).generated_text[0]
+         return outputs

+ def greet(image):
+     if image.dtype != np.uint8:
+         image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
+     image = Image.fromarray(image, mode='RGB')
+     raw_text = caption_image(image)
+     return drop_incomplete_tail(raw_text)

  demo = gr.Interface(fn=greet, inputs="image", outputs="text")
+ demo.launch()
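The comments removed by this commit note that the spaces package only exists inside a HF Space. If the app also needs to run locally, the decorator could be made optional; a minimal sketch, using a hypothetical alias gpu that is not part of this commit:

    # Fall back to a no-op decorator when `spaces` is unavailable
    # (e.g. running app.py outside a Hugging Face Space).
    try:
        import spaces
        gpu = spaces.GPU(duration=120)  # same duration budget as the commit
    except ImportError:
        gpu = lambda f: f  # hypothetical no-op stand-in for local runs

    @gpu
    def caption_image(image):
        ...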
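drop_incomplete_tail keeps only sentences that end with a period, dropping the trailing fragment that the 300-token generation cap can leave behind. A standalone sanity check of that behavior (the function is copied verbatim from app.py; the expected strings follow from its logic):

    def drop_incomplete_tail(text):
        sentences = text.split('.')
        complete_sentences = [s.strip() for s in sentences if s.strip()]
        if not text.strip().endswith('.'):
            complete_sentences = complete_sentences[:-1]
        return '. '.join(complete_sentences) + ('.' if complete_sentences else '')

    assert drop_incomplete_tail("A dog runs. A cat sle") == "A dog runs."  # fragment dropped
    assert drop_incomplete_tail("A dog runs. A cat sleeps.") == "A dog runs. A cat sleeps."
    assert drop_incomplete_tail("") == ""  # empty input stays empty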
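greet turns whatever array Gradio hands it into the uint8 RGB PIL image that caption_image expects: per the old code's comment, Gradio may deliver a float array in [0, 1], which the dtype guard rescales before conversion. A self-contained sketch of that preprocessing path, with a random array standing in for a real upload:

    import numpy as np
    from PIL import Image

    arr = np.random.rand(224, 224, 3)  # float64 in [0, 1], mimicking a float input
    if arr.dtype != np.uint8:
        arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
    img = Image.fromarray(arr, mode="RGB")  # the PIL image passed to the model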