stzhao commited on
Commit
a43380e
·
verified ·
1 Parent(s): 095edc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -17,7 +17,7 @@ def load_models():
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
  torch_dtype=torch.bfloat16,
20
- # device_map="auto"
21
  )
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
@@ -26,7 +26,7 @@ def load_models():
26
  torch_dtype=torch.bfloat16
27
  )
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
- # pipe.to("cuda")
30
 
31
  return model, tokenizer, pipe
32
 
@@ -43,7 +43,7 @@ def truncate_caption_by_tokens(caption, max_tokens=256):
43
 
44
  @spaces.GPU(duration=50)
45
  def generate_enhanced_caption(image_caption, text_caption):
46
- model.to("cuda")
47
  """Generate enhanced caption using the LeX-Enhancer model"""
48
  combined_caption = f"{image_caption}, with the text on it: {text_caption}."
49
  instruction = """
@@ -77,7 +77,7 @@ Below is the simple caption of an image with text. Please deduce the detailed de
77
 
78
  @spaces.GPU(duration=60)
79
  def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
80
- pipe.to("cuda")
81
  pipe.enable_model_cpu_offload()
82
  """Generate image using LeX-Lumina"""
83
  # Truncate the caption if it's too long
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_name,
19
  torch_dtype=torch.bfloat16,
20
+ device_map="auto"
21
  )
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
 
 
26
  torch_dtype=torch.bfloat16
27
  )
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ pipe.to("cuda")
30
 
31
  return model, tokenizer, pipe
32
 
 
43
 
44
  @spaces.GPU(duration=50)
45
  def generate_enhanced_caption(image_caption, text_caption):
46
+ # model.to("cuda")
47
  """Generate enhanced caption using the LeX-Enhancer model"""
48
  combined_caption = f"{image_caption}, with the text on it: {text_caption}."
49
  instruction = """
 
77
 
78
  @spaces.GPU(duration=60)
79
  def generate_image(enhanced_caption, seed, num_inference_steps, guidance_scale):
80
+ # pipe.to("cuda")
81
  pipe.enable_model_cpu_offload()
82
  """Generate image using LeX-Lumina"""
83
  # Truncate the caption if it's too long