mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-06 11:00:48 +00:00
commit
7c3dff928f
33
demo_v2.py
33
demo_v2.py
@ -191,7 +191,6 @@ def visualize_all_bbox_together(image, generation):
|
|||||||
return None, ''
|
return None, ''
|
||||||
|
|
||||||
generation = html.unescape(generation)
|
generation = html.unescape(generation)
|
||||||
print('gen begin', generation)
|
|
||||||
|
|
||||||
image_width, image_height = image.size
|
image_width, image_height = image.size
|
||||||
image = image.resize([500, int(500 / image_width * image_height)])
|
image = image.resize([500, int(500 / image_width * image_height)])
|
||||||
@ -372,9 +371,7 @@ def visualize_all_bbox_together(image, generation):
|
|||||||
color = next(color_gen)
|
color = next(color_gen)
|
||||||
return f'<span style="color:rgb{color}">{phrase}</span>'
|
return f'<span style="color:rgb{color}">{phrase}</span>'
|
||||||
|
|
||||||
print('gen before', generation)
|
|
||||||
generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
|
generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
|
||||||
print('gen after', generation)
|
|
||||||
generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
|
generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
|
||||||
else:
|
else:
|
||||||
generation_colored = ''
|
generation_colored = ''
|
||||||
@ -395,31 +392,28 @@ def gradio_reset(chat_state, img_list):
|
|||||||
def image_upload_trigger(upload_flag, replace_flag, img_list):
|
def image_upload_trigger(upload_flag, replace_flag, img_list):
|
||||||
# set the upload flag to true when receive a new image.
|
# set the upload flag to true when receive a new image.
|
||||||
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
||||||
print('flag', upload_flag, replace_flag)
|
|
||||||
print("SET UPLOAD FLAG!")
|
|
||||||
upload_flag = 1
|
upload_flag = 1
|
||||||
if img_list:
|
if img_list:
|
||||||
print("SET REPLACE FLAG!")
|
|
||||||
replace_flag = 1
|
replace_flag = 1
|
||||||
print('flag', upload_flag, replace_flag)
|
|
||||||
return upload_flag, replace_flag
|
return upload_flag, replace_flag
|
||||||
|
|
||||||
|
|
||||||
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
|
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
|
||||||
# set the upload flag to true when receive a new image.
|
# set the upload flag to true when receive a new image.
|
||||||
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
||||||
print('flag', upload_flag, replace_flag)
|
|
||||||
print("SET UPLOAD FLAG!")
|
|
||||||
upload_flag = 1
|
upload_flag = 1
|
||||||
if img_list or replace_flag == 1:
|
if img_list or replace_flag == 1:
|
||||||
print("SET REPLACE FLAG!")
|
|
||||||
replace_flag = 1
|
replace_flag = 1
|
||||||
|
|
||||||
print('flag', upload_flag, replace_flag)
|
|
||||||
return upload_flag, replace_flag
|
return upload_flag, replace_flag
|
||||||
|
|
||||||
|
|
||||||
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
|
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
|
||||||
|
if len(user_message) == 0:
|
||||||
|
text_box_show = 'Input should not be empty!'
|
||||||
|
else:
|
||||||
|
text_box_show = ''
|
||||||
|
|
||||||
if isinstance(gr_img, dict):
|
if isinstance(gr_img, dict):
|
||||||
gr_img, mask = gr_img['image'], gr_img['mask']
|
gr_img, mask = gr_img['image'], gr_img['mask']
|
||||||
else:
|
else:
|
||||||
@ -432,20 +426,14 @@ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag,
|
|||||||
bbox = mask2bbox(mask)
|
bbox = mask2bbox(mask)
|
||||||
user_message = user_message + bbox
|
user_message = user_message + bbox
|
||||||
|
|
||||||
if len(user_message) == 0:
|
|
||||||
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
|
|
||||||
|
|
||||||
if chat_state is None:
|
if chat_state is None:
|
||||||
chat_state = CONV_VISION.copy()
|
chat_state = CONV_VISION.copy()
|
||||||
|
|
||||||
print('upload flag: {}'.format(upload_flag))
|
|
||||||
if upload_flag:
|
if upload_flag:
|
||||||
if replace_flag:
|
if replace_flag:
|
||||||
print('RESET!!!!!!!')
|
|
||||||
chat_state = CONV_VISION.copy() # new image, reset everything
|
chat_state = CONV_VISION.copy() # new image, reset everything
|
||||||
replace_flag = 0
|
replace_flag = 0
|
||||||
chatbot = []
|
chatbot = []
|
||||||
print('UPLOAD IMAGE!!')
|
|
||||||
img_list = []
|
img_list = []
|
||||||
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
||||||
upload_flag = 0
|
upload_flag = 0
|
||||||
@ -457,11 +445,10 @@ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag,
|
|||||||
if '[identify]' in user_message:
|
if '[identify]' in user_message:
|
||||||
visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
|
visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
|
||||||
if visual_img is not None:
|
if visual_img is not None:
|
||||||
print('Visualizing the input')
|
|
||||||
file_path = save_tmp_img(visual_img)
|
file_path = save_tmp_img(visual_img)
|
||||||
chatbot = chatbot + [[(file_path,), None]]
|
chatbot = chatbot + [[(file_path,), None]]
|
||||||
|
|
||||||
return '', chatbot, chat_state, img_list, upload_flag, replace_flag
|
return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
|
||||||
|
|
||||||
|
|
||||||
def gradio_answer(chatbot, chat_state, img_list, temperature):
|
def gradio_answer(chatbot, chat_state, img_list, temperature):
|
||||||
@ -475,9 +462,9 @@ def gradio_answer(chatbot, chat_state, img_list, temperature):
|
|||||||
|
|
||||||
|
|
||||||
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
|
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
|
||||||
print('chat state', chat_state.get_prompt())
|
if len(img_list) > 0:
|
||||||
if not isinstance(img_list[0], torch.Tensor):
|
if not isinstance(img_list[0], torch.Tensor):
|
||||||
chat.encode_img(img_list)
|
chat.encode_img(img_list)
|
||||||
streamer = chat.stream_answer(conv=chat_state,
|
streamer = chat.stream_answer(conv=chat_state,
|
||||||
img_list=img_list,
|
img_list=img_list,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@ -489,7 +476,6 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
|
|||||||
output += escapped
|
output += escapped
|
||||||
chatbot[-1][1] = output
|
chatbot[-1][1] = output
|
||||||
yield chatbot, chat_state
|
yield chatbot, chat_state
|
||||||
# print('message: ', chat_state.messages)
|
|
||||||
chat_state.messages[-1][1] = '</s>'
|
chat_state.messages[-1][1] = '</s>'
|
||||||
return chatbot, chat_state
|
return chatbot, chat_state
|
||||||
|
|
||||||
@ -501,7 +487,6 @@ def gradio_visualize(chatbot, gr_img):
|
|||||||
unescaped = reverse_escape(chatbot[-1][1])
|
unescaped = reverse_escape(chatbot[-1][1])
|
||||||
visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
|
visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
|
||||||
if visual_img is not None:
|
if visual_img is not None:
|
||||||
print('Visualizing the output')
|
|
||||||
if len(generation_color):
|
if len(generation_color):
|
||||||
chatbot[-1][1] = generation_color
|
chatbot[-1][1] = generation_color
|
||||||
file_path = save_tmp_img(visual_img)
|
file_path = save_tmp_img(visual_img)
|
||||||
|
Loading…
Reference in New Issue
Block a user