mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 10:30:45 +00:00
commit
7c3dff928f
33
demo_v2.py
33
demo_v2.py
@ -191,7 +191,6 @@ def visualize_all_bbox_together(image, generation):
|
||||
return None, ''
|
||||
|
||||
generation = html.unescape(generation)
|
||||
print('gen begin', generation)
|
||||
|
||||
image_width, image_height = image.size
|
||||
image = image.resize([500, int(500 / image_width * image_height)])
|
||||
@ -372,9 +371,7 @@ def visualize_all_bbox_together(image, generation):
|
||||
color = next(color_gen)
|
||||
return f'<span style="color:rgb{color}">{phrase}</span>'
|
||||
|
||||
print('gen before', generation)
|
||||
generation = re.sub(r'{<\d+><\d+><\d+><\d+>}|<delim>', '', generation)
|
||||
print('gen after', generation)
|
||||
generation_colored = re.sub(r'<p>(.*?)</p>', colored_phrases, generation)
|
||||
else:
|
||||
generation_colored = ''
|
||||
@ -395,31 +392,28 @@ def gradio_reset(chat_state, img_list):
|
||||
def image_upload_trigger(upload_flag, replace_flag, img_list):
|
||||
# set the upload flag to true when receive a new image.
|
||||
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
||||
print('flag', upload_flag, replace_flag)
|
||||
print("SET UPLOAD FLAG!")
|
||||
upload_flag = 1
|
||||
if img_list:
|
||||
print("SET REPLACE FLAG!")
|
||||
replace_flag = 1
|
||||
print('flag', upload_flag, replace_flag)
|
||||
return upload_flag, replace_flag
|
||||
|
||||
|
||||
def example_trigger(text_input, image, upload_flag, replace_flag, img_list):
|
||||
# set the upload flag to true when receive a new image.
|
||||
# if there is an old image (and old conversation), set the replace flag to true to reset the conv later.
|
||||
print('flag', upload_flag, replace_flag)
|
||||
print("SET UPLOAD FLAG!")
|
||||
upload_flag = 1
|
||||
if img_list or replace_flag == 1:
|
||||
print("SET REPLACE FLAG!")
|
||||
replace_flag = 1
|
||||
|
||||
print('flag', upload_flag, replace_flag)
|
||||
return upload_flag, replace_flag
|
||||
|
||||
|
||||
def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag, replace_flag):
|
||||
if len(user_message) == 0:
|
||||
text_box_show = 'Input should not be empty!'
|
||||
else:
|
||||
text_box_show = ''
|
||||
|
||||
if isinstance(gr_img, dict):
|
||||
gr_img, mask = gr_img['image'], gr_img['mask']
|
||||
else:
|
||||
@ -432,20 +426,14 @@ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag,
|
||||
bbox = mask2bbox(mask)
|
||||
user_message = user_message + bbox
|
||||
|
||||
if len(user_message) == 0:
|
||||
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
|
||||
|
||||
if chat_state is None:
|
||||
chat_state = CONV_VISION.copy()
|
||||
|
||||
print('upload flag: {}'.format(upload_flag))
|
||||
if upload_flag:
|
||||
if replace_flag:
|
||||
print('RESET!!!!!!!')
|
||||
chat_state = CONV_VISION.copy() # new image, reset everything
|
||||
replace_flag = 0
|
||||
chatbot = []
|
||||
print('UPLOAD IMAGE!!')
|
||||
img_list = []
|
||||
llm_message = chat.upload_img(gr_img, chat_state, img_list)
|
||||
upload_flag = 0
|
||||
@ -457,11 +445,10 @@ def gradio_ask(user_message, chatbot, chat_state, gr_img, img_list, upload_flag,
|
||||
if '[identify]' in user_message:
|
||||
visual_img, _ = visualize_all_bbox_together(gr_img, user_message)
|
||||
if visual_img is not None:
|
||||
print('Visualizing the input')
|
||||
file_path = save_tmp_img(visual_img)
|
||||
chatbot = chatbot + [[(file_path,), None]]
|
||||
|
||||
return '', chatbot, chat_state, img_list, upload_flag, replace_flag
|
||||
return text_box_show, chatbot, chat_state, img_list, upload_flag, replace_flag
|
||||
|
||||
|
||||
def gradio_answer(chatbot, chat_state, img_list, temperature):
|
||||
@ -475,9 +462,9 @@ def gradio_answer(chatbot, chat_state, img_list, temperature):
|
||||
|
||||
|
||||
def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
|
||||
print('chat state', chat_state.get_prompt())
|
||||
if not isinstance(img_list[0], torch.Tensor):
|
||||
chat.encode_img(img_list)
|
||||
if len(img_list) > 0:
|
||||
if not isinstance(img_list[0], torch.Tensor):
|
||||
chat.encode_img(img_list)
|
||||
streamer = chat.stream_answer(conv=chat_state,
|
||||
img_list=img_list,
|
||||
temperature=temperature,
|
||||
@ -489,7 +476,6 @@ def gradio_stream_answer(chatbot, chat_state, img_list, temperature):
|
||||
output += escapped
|
||||
chatbot[-1][1] = output
|
||||
yield chatbot, chat_state
|
||||
# print('message: ', chat_state.messages)
|
||||
chat_state.messages[-1][1] = '</s>'
|
||||
return chatbot, chat_state
|
||||
|
||||
@ -501,7 +487,6 @@ def gradio_visualize(chatbot, gr_img):
|
||||
unescaped = reverse_escape(chatbot[-1][1])
|
||||
visual_img, generation_color = visualize_all_bbox_together(gr_img, unescaped)
|
||||
if visual_img is not None:
|
||||
print('Visualizing the output')
|
||||
if len(generation_color):
|
||||
chatbot[-1][1] = generation_color
|
||||
file_path = save_tmp_img(visual_img)
|
||||
|
Loading…
Reference in New Issue
Block a user