Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -187,7 +187,7 @@ class LLMService:
|
|
187 |
|
188 |
model_id_or_path = "stablediffusionapi/realistic-vision-v51"
|
189 |
self.vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, safety_checker=None, torch_dtype=torch.float16)
|
190 |
-
self.vae_pipe = self.vae_pipe.
|
191 |
|
192 |
self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
193 |
self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
@@ -207,7 +207,7 @@ class LLMService:
|
|
207 |
service = LLMService(args)
|
208 |
|
209 |
@spaces.GPU
|
210 |
-
def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
|
211 |
with torch.no_grad():
|
212 |
text_list = text_list.split(IMG_FLAG)
|
213 |
top_p = 0.5
|
@@ -360,14 +360,21 @@ def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
|
|
360 |
img_feat = img_gen_feat[img_idx:img_idx + 1]
|
361 |
generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
|
362 |
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
# print('loading visual encoder and llm to GPU, and sd to CPU')
|
373 |
# a = time.time()
|
@@ -387,7 +394,7 @@ def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox):
|
|
387 |
print(input_text + generated_text)
|
388 |
return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
|
389 |
|
390 |
-
def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox,
|
391 |
request: gr.Request):
|
392 |
print('input_state:', input_state)
|
393 |
|
@@ -409,7 +416,7 @@ def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_g
|
|
409 |
force_boi = force_image_gen
|
410 |
force_bbox = force_bbox
|
411 |
|
412 |
-
results = generate(text, images, max_new_tokens, force_boi, force_bbox)
|
413 |
print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})
|
414 |
|
415 |
output_state = init_input_state()
|
@@ -652,6 +659,8 @@ If you want to experience the normal model inference speed, you can run [[Infere
|
|
652 |
* You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
|
653 |
|
654 |
* You can click "Force Bounding Box" to compel the model to produce bounding box for object detection.
|
|
|
|
|
655 |
|
656 |
* SEED-X was trained with English-only data. It may process with other languages due to the inherent capabilities from LLaMA, but might not stable.
|
657 |
|
@@ -755,6 +764,7 @@ if __name__ == '__main__':
|
|
755 |
label="Max History Rounds")
|
756 |
force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
|
757 |
force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
|
|
|
758 |
|
759 |
with gr.Column(scale=7):
|
760 |
chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I", height=700)
|
@@ -776,7 +786,7 @@ if __name__ == '__main__':
|
|
776 |
downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
777 |
|
778 |
regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
|
779 |
-
http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
|
780 |
[dialog_state, input_state, chatbot] + btn_list)
|
781 |
add_image_btn.click(add_image, [dialog_state, input_state, image],
|
782 |
[dialog_state, input_state, image, chatbot] + btn_list)
|
@@ -789,7 +799,7 @@ if __name__ == '__main__':
|
|
789 |
add_text, [dialog_state, input_state, text],
|
790 |
[dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
|
791 |
http_bot,
|
792 |
-
[dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox],
|
793 |
[dialog_state, input_state, chatbot] + btn_list)
|
794 |
clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
|
795 |
|
|
|
187 |
|
188 |
model_id_or_path = "stablediffusionapi/realistic-vision-v51"
|
189 |
self.vae_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, safety_checker=None, torch_dtype=torch.float16)
|
190 |
+
self.vae_pipe = self.vae_pipe.cpu()
|
191 |
|
192 |
self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
|
193 |
self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
|
|
|
207 |
service = LLMService(args)
|
208 |
|
209 |
@spaces.GPU
|
210 |
+
def generate(text_list, image_list, max_new_tokens, force_boi, force_bbox, force_polish):
|
211 |
with torch.no_grad():
|
212 |
text_list = text_list.split(IMG_FLAG)
|
213 |
top_p = 0.5
|
|
|
360 |
img_feat = img_gen_feat[img_idx:img_idx + 1]
|
361 |
generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
|
362 |
|
363 |
+
if force_polish:
|
364 |
+
service.sd_adapter = service.sd_adapter.cpu()
|
365 |
+
service.vae_pipe = service.vae_pipe.to(service.vit_sd_device, dtype=service.dtype)
|
366 |
+
|
367 |
+
init_image = generated_image.resize((1024, 1024))
|
368 |
+
prompt = ""
|
369 |
+
images = service.vae_pipe(prompt=prompt, image=init_image,
|
370 |
+
num_inference_steps=50, guidance_scale=8.0, strength=0.38).images
|
371 |
+
generated_image = images[0]
|
372 |
+
|
373 |
+
image_base64 = encode_image(generated_image)
|
374 |
+
gen_imgs_base64_list.append(image_base64)
|
375 |
+
|
376 |
+
service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
|
377 |
+
service.vae_pipe = service.vae_pipe.cpu()
|
378 |
|
379 |
# print('loading visual encoder and llm to GPU, and sd to CPU')
|
380 |
# a = time.time()
|
|
|
394 |
print(input_text + generated_text)
|
395 |
return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
|
396 |
|
397 |
+
def http_bot(dialog_state, input_state, max_new_tokens, max_turns, force_image_gen, force_bbox, force_polish,
|
398 |
request: gr.Request):
|
399 |
print('input_state:', input_state)
|
400 |
|
|
|
416 |
force_boi = force_image_gen
|
417 |
force_bbox = force_bbox
|
418 |
|
419 |
+
results = generate(text, images, max_new_tokens, force_boi, force_bbox, force_polish)
|
420 |
print('response: ', {'text': results['text'], 'error_msg': results['error_msg']})
|
421 |
|
422 |
output_state = init_input_state()
|
|
|
659 |
* You can click "Force Image Generation" to compel the model to produce images when necessary. For example, our model might struggle to generate images when there is an excessive amount of text-only context.
|
660 |
|
661 |
* You can click "Force Bounding Box" to compel the model to produce bounding box for object detection.
|
662 |
+
|
663 |
+
* You can click "Force Polishing Generated Image" to compel the model to polish the generated image with image post-processing.
|
664 |
|
665 |
* SEED-X was trained with English-only data. It may process with other languages due to the inherent capabilities from LLaMA, but might not stable.
|
666 |
|
|
|
764 |
label="Max History Rounds")
|
765 |
force_img_gen = gr.Radio(choices=[True, False], value=False, label='Force Image Generation')
|
766 |
force_bbox = gr.Radio(choices=[True, False], value=False, label='Force Bounding Box')
|
767 |
+
force_polish = gr.Radio(choices=[True, False], value=True, label='Force Polishing Generated Image')
|
768 |
|
769 |
with gr.Column(scale=7):
|
770 |
chatbot = gr.Chatbot(elem_id='chatbot', label="SEED-X-I", height=700)
|
|
|
786 |
downvote_btn.click(downvote_last_response, [dialog_state], [upvote_btn, downvote_btn])
|
787 |
|
788 |
regenerate_btn.click(regenerate, [dialog_state], [dialog_state, chatbot] + btn_list).then(
|
789 |
+
http_bot, [dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
|
790 |
[dialog_state, input_state, chatbot] + btn_list)
|
791 |
add_image_btn.click(add_image, [dialog_state, input_state, image],
|
792 |
[dialog_state, input_state, image, chatbot] + btn_list)
|
|
|
799 |
add_text, [dialog_state, input_state, text],
|
800 |
[dialog_state, input_state, text, chatbot, upvote_btn, downvote_btn, regenerate_btn, clear_btn]).then(
|
801 |
http_bot,
|
802 |
+
[dialog_state, input_state, max_new_tokens, max_turns, force_img_gen, force_bbox, force_polish],
|
803 |
[dialog_state, input_state, chatbot] + btn_list)
|
804 |
clear_btn.click(clear_history, None, [dialog_state, input_state, chatbot] + btn_list)
|
805 |
|