Rinawang commited on
Commit
2c40e39
·
verified ·
1 Parent(s): 3849d55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -1,31 +1,37 @@
1
  import gradio as gr
2
- from transformers import AutoProcessor, VisionEncoderDecoderModel
 
 
 
 
3
  from PIL import Image
4
  import torch
5
 
6
- # 加载模型和处理器
7
  model_id = "starvector/starvector-8b-im2svg"
8
- processor = AutoProcessor.from_pretrained(model_id)
 
 
 
9
  model = VisionEncoderDecoderModel.from_pretrained(model_id)
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model.to(device)
13
 
14
- # 定义推理函数
15
  def im2svg(image):
16
- inputs = processor(images=image, return_tensors="pt").to(device)
17
  outputs = model.generate(**inputs, max_new_tokens=1024)
18
- generated_svg = processor.batch_decode(outputs, skip_special_tokens=True)[0]
19
- return generated_svg
20
 
21
- # 创建 Gradio 界面
22
  demo = gr.Interface(
23
  fn=im2svg,
24
  inputs=gr.Image(type="pil"),
25
  outputs="text",
26
- title="StarVector 8B - Image to SVG",
27
- description="上传图像,将其转化为 SVG 矢量代码。",
28
- allow_flagging="never"
29
  )
30
 
31
  demo.launch()
 
1
  import gradio as gr
2
+ from transformers import (
3
+ SiglipImageProcessor,
4
+ RobertaTokenizerFast,
5
+ VisionEncoderDecoderModel
6
+ )
7
  from PIL import Image
8
  import torch
9
 
10
+ # 模型 ID
11
  model_id = "starvector/starvector-8b-im2svg"
12
+
13
+ # 分别加载 image processor 和 tokenizer
14
+ image_processor = SiglipImageProcessor.from_pretrained(model_id)
15
+ tokenizer = RobertaTokenizerFast.from_pretrained(model_id)
16
  model = VisionEncoderDecoderModel.from_pretrained(model_id)
17
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
  model.to(device)
20
 
21
+ # 推理函数
22
  def im2svg(image):
23
+ inputs = image_processor(images=image, return_tensors="pt").to(device)
24
  outputs = model.generate(**inputs, max_new_tokens=1024)
25
+ svg_code = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
26
+ return svg_code
27
 
28
+ # Gradio UI
29
  demo = gr.Interface(
30
  fn=im2svg,
31
  inputs=gr.Image(type="pil"),
32
  outputs="text",
33
+ title="🖼️ StarVector: Image SVG",
34
+ description="上传图像,我将它转化为矢量图(SVG 代码)。适用于简笔画、图标、草图。",
 
35
  )
36
 
37
  demo.launch()