tttoaster committed
Commit 281a32c · verified · 1 Parent(s): 8c7cb04

Update app.py

Files changed (1):
  app.py (+24 -14)
app.py CHANGED
@@ -187,7 +187,7 @@ class LLMService:
 
         model_id_or_path = "stablediffusionapi/realistic-vision-v51"
         self.vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, safety_checker=None, torch_dtype=torch.float16)
-        self.vae_pipe = self.vae_pipe.to(self.vit_sd_device)
+        self.vae_pipe = self.vae_pipe.cpu()
 
         self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
         self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
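The hunk above stops eagerly moving the img2img "polish" pipeline to the GPU at startup: it is loaded once in fp16 and parked on CPU until a polish pass is actually requested. A minimal sketch of that load-to-CPU pattern (the model id comes from the hunk; the standalone variable name is illustrative):

```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline

# Load the img2img "polish" pipeline once in half precision, but keep it on
# CPU so it does not occupy GPU memory until a polish pass is requested.
vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "stablediffusionapi/realistic-vision-v51",
    safety_checker=None,
    torch_dtype=torch.float16,
)
vae_pipe = vae_pipe.cpu()  # moved to the GPU only inside the polish branch
```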
@@ -207,7 +207,7 @@ class LLMService:
 service = LLMService(args)
 
 @spaces.GPU
-def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
+def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox, force_polish):
     with torch.no_grad():
         text_list = text_list.split(IMG_FLAG)
         top_p = 0.5
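The `generate` entry point gains a `force_polish` flag here. The `@spaces.GPU` decorator (from the Hugging Face `spaces` package) attaches a GPU only for the duration of the decorated call on ZeroGPU Spaces; a hedged, stubbed sketch of that shape, where only the signature mirrors the diff:

```python
# Sketch only: the body is a stub; the return shape matches the dict this
# app's generate() actually returns further down.
import spaces
import torch

@spaces.GPU
def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox, force_polish):
    with torch.no_grad():
        # ... tokenization, LLM decoding, and image generation run here,
        # while the borrowed GPU is attached ...
        return {'text': '', 'images': [], 'error_msg': ''}
```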
@@ -360,14 +360,21 @@ def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
             img_feat = img_gen_feat[img_idx:img_idx + 1]
             generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
 
-            init_image = generated_image.resize((1024, 1024))
-            prompt = ""
-            images = service.vae_pipe(prompt=prompt, image=init_image,
-                                      num_inference_steps=50, guidance_scale=8.0, strength=0.38).images
-            generated_image = images[0]
-
-            image_base64 = encode_image(generated_image)
-            gen_imgs_base64_list.append(image_base64)
+            if force_polish:
+                service.sd_adapter = service.sd_adapter.cpu()
+                service.vae_pipe = service.vae_pipe.to(service.vit_sd_device, dtype=service.dtype)
+
+                init_image = generated_image.resize((1024, 1024))
+                prompt = ""
+                images = service.vae_pipe(prompt=prompt, image=init_image,
+                                          num_inference_steps=50, guidance_scale=8.0, strength=0.38).images
+                generated_image = images[0]
+
+                image_base64 = encode_image(generated_image)
+                gen_imgs_base64_list.append(image_base64)
+
+                service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
+                service.vae_pipe = service.vae_pipe.cpu()
 
             # print('loading visual encoder and llm to GPU, and sd to CPU')
             # a = time.time()
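This hunk is the core of the commit: the img2img pass now runs only when `force_polish` is set, and the two heavy models trade places on the GPU around it (note that the base64 encode/append also moves inside the branch in this revision). A hedged sketch of the swap-polish-swap pattern, assuming a `service` object with the `sd_adapter`, `vae_pipe`, `vit_sd_device`, and `dtype` attributes shown above; the `polish_image` helper name is hypothetical:

```python
from PIL import Image

def polish_image(service, generated_image: Image.Image) -> Image.Image:
    """Hypothetical helper mirroring the force_polish branch above."""
    # Only one heavy model holds the GPU at a time: park the SD adapter on
    # CPU, then bring the img2img pipeline up.
    service.sd_adapter = service.sd_adapter.cpu()
    service.vae_pipe = service.vae_pipe.to(service.vit_sd_device, dtype=service.dtype)

    # Low strength (0.38) keeps the composition and mainly refines detail.
    init_image = generated_image.resize((1024, 1024))
    polished = service.vae_pipe(prompt="", image=init_image,
                                num_inference_steps=50, guidance_scale=8.0,
                                strength=0.38).images[0]

    # Restore the original placement for the next generation step.
    service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
    service.vae_pipe = service.vae_pipe.cpu()
    return polished
```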
@@ -387,7 +394,7 @@ def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
     print(input_text + generated_text)
     return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
 
-def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox,
+def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox, force_polish,
              request: gr.Request):
     print('input_state:', input_state)
 
@@ -409,7 +416,7 @@ def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_g
     force_boi = force_image_gen
     force_bbox = force_bbox
 
-    results = generate(text, images, max_new_tokens, force_boi, force_bbox)
+    results = generate(text, images, max_new_tokens, force_boi, force_bbox, force_polish)
     print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})
 
     output_state = init_input_state()
@@ -652,6 +659,8 @@ If you want to experience the normal model inference speed, you can run [[Infere
 * You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
 
 * You can click "Force Bounding Box" to compel the model to produce bounding box for object detection.
+
+* You can click "Force Polishing Generated Image" to compel the model to polish the generated image with image post-processing.
 
 * SEED-X was trained with English-only data. It may process with other languages due to the inherent capabilities from LLaMA, but might not stable.
 
@@ -755,6 +764,7 @@ if __name__ == '__main__':
                                           label="Max History Rounds")
                 force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
                 force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
+                force_polish = gr.Radio(choices=[True, False], value=True, label='Force Polishing Generated Image')
 
             with gr.Column(scale=7):
                 chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I", height=700)
@@ -776,7 +786,7 @@ if __name__ == '__main__':
     downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
 
     regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
-        http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
+        http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
         [dialog_state, input_state, chatbot] + btn_list)
     add_image_btn.click(add_image, [dialog_state, input_state, image],
                         [dialog_state, input_state, image, chatbot] + btn_list)
@@ -789,7 +799,7 @@ if __name__ == '__main__':
         add_text, [dialog_state, input_state, text],
         [dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
         http_bot,
-        [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
+        [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
         [dialog_state, input_state, chatbot] + btn_list)
     clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
 
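Finally, because Gradio passes component values positionally through each event's inputs list, the new `force_polish` radio has to be appended both to `http_bot`'s signature and to every inputs list that invokes it, as the last two hunks do. A hedged, self-contained sketch of that wiring pattern with a stub handler (component names mirror the diff; the handler body is illustrative):

```python
import gradio as gr

def http_bot_stub(chat_history, max_new_tokens, force_img_gen, force_bbox, force_polish):
    # Stub: a real handler would call generate(...) with these flags.
    reply = f"force_polish={force_polish}, force_bbox={force_bbox}"
    return chat_history + [["(user turn)", reply]]

with gr.Blocks() as demo:
    max_new_tokens = gr.Slider(minimum=64, maximum=1024, value=256, label="Max New Tokens")
    force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
    force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
    force_polish = gr.Radio(choices=[True, False], value=True, label='Force Polishing Generated Image')
    chatbot = gr.Chatbot(label="SEED-X-I", height=700)
    send_btn = gr.Button("Send")
    # Component values reach the handler in the order they appear in the
    # inputs list, so adding a flag means extending this list as well.
    send_btn.click(http_bot_stub,
                   [chatbot, max_new_tokens, force_img_gen, force_bbox, force_polish],
                   [chatbot])

demo.launch()
```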