Merge pull request #375 from TsuTikgiau/main

fix some gradio error
This commit is contained in:
ZhuDeyao 2023-10-15 21:44:40 +03:00 committed by GitHub
commit 7c3dff928f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -191,7 +191,6 @@ def visualize_all_bbox_together(image, generation):
return None, ''
generation = html.unescape(generation)
print('gen begin', generation)
image_width, image_height = image.size
image = image.resize([500, int(500 / image_width * image_height)])
@ -372,9 +371,7 @@ def visualize_all_bbox_together(image, generation):
color = next(color_gen)
return f'<span style="color:rgb{color}">{phrase}</span>'
print('gen before', generation)
generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
print('gen after', generation)
generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
else:
generation_colored = ''
@ -395,31 +392,28 @@ def gradio_reset(chat_state, img_list):
def image_upload_trigger(upload_flag, replace_flag, img_list):
# set the upload flag to true when receive a new image.
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
print('flag', upload_flag, replace_flag)
print("SET UPLOAD FLAG!")
upload_flag = 1
if img_list:
print("SET REPLACE FLAG!")
replace_flag = 1
print('flag', upload_flag, replace_flag)
return upload_flag, replace_flag
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
# set the upload flag to true when receive a new image.
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
print('flag', upload_flag, replace_flag)
print("SET UPLOAD FLAG!")
upload_flag = 1
if img_list or replace_flag == 1:
print("SET REPLACE FLAG!")
replace_flag = 1
print('flag', upload_flag, replace_flag)
return upload_flag, replace_flag
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
if len(user_message) == 0:
text_box_show = 'Input should not be empty!'
else:
text_box_show = ''
if isinstance(gr_img, dict):
gr_img, mask = gr_img['image'], gr_img['mask']
else:
@ -432,20 +426,14 @@ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag,
bbox = mask2bbox(mask)
user_message = user_message + bbox
if len(user_message) == 0:
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
if chat_state is None:
chat_state = CONV_VISION.copy()
print('upload flag: {}'.format(upload_flag))
if upload_flag:
if replace_flag:
print('RESET!!!!!!!')
chat_state = CONV_VISION.copy() # new image, reset everything
replace_flag = 0
chatbot = []
print('UPLOAD IMAGE!!')
img_list = []
llm_message = chat.upload_img(gr_img, chat_state, img_list)
upload_flag = 0
@ -457,11 +445,10 @@ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag,
if '[identify]' in user_message:
visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
if visual_img is not None:
print('Visualizing the input')
file_path = save_tmp_img(visual_img)
chatbot = chatbot + [[(file_path,), None]]
return '', chatbot, chat_state, img_list, upload_flag, replace_flag
return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
def gradio_answer(chatbot, chat_state, img_list, temperature):
@ -475,9 +462,9 @@ def gradio_answer(chatbot, chat_state, img_list, temperature):
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
print('chat state', chat_state.get_prompt())
if not isinstance(img_list[0], torch.Tensor):
chat.encode_img(img_list)
if len(img_list) > 0:
if not isinstance(img_list[0], torch.Tensor):
chat.encode_img(img_list)
streamer = chat.stream_answer(conv=chat_state,
img_list=img_list,
temperature=temperature,
@ -489,7 +476,6 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
output += escapped
chatbot[-1][1] = output
yield chatbot, chat_state
# print('message: ', chat_state.messages)
chat_state.messages[-1][1] = '</s>'
return chatbot, chat_state
@ -501,7 +487,6 @@ def gradio_visualize(chatbot, gr_img):
unescaped = reverse_escape(chatbot[-1][1])
visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
if visual_img is not None:
print('Visualizing the output')
if len(generation_color):
chatbot[-1][1] = generation_color
file_path = save_tmp_img(visual_img)