import argparse import os import random import numpy as np import torch import torch.backends.cudnn as cudnn import gradio as gr from minigpt4.common.config import Config from minigpt4.common.dist_utils import get_rank from minigpt4.common.registry import registry from minigpt4.conversation.conversation import Chat, CONV_VISION # imports modules for registration from minigpt4.datasets.builders import * from minigpt4.models import * from minigpt4.processors import * from minigpt4.runners import * from minigpt4.tasks import * def parse_args(): parser = argparse.ArgumentParser(description="Demo") parser.add_argument("--cfg-path", required=True, help="path to configuration file.") parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") parser.add_argument( "--options", nargs="+", help="override some settings in the used config, the key-value pair " "in xxx=yyy format will be merged into config file (deprecate), " "change to --cfg-options instead.", ) args = parser.parse_args() return args def setup_seeds(config): seed = config.run_cfg.seed + get_rank() random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) cudnn.benchmark = False cudnn.deterministic = True # ======================================== # Model Initialization # ======================================== print('Initializing Chat') args = parse_args() cfg = Config(args) model_config = cfg.model_cfg model_config.device_8bit = args.gpu_id model_cls = registry.get_model_class(model_config.arch) model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) print('Initialization Finished') # ======================================== # Gradio Setting # ======================================== def gradio_reset(): # reset chatbot, image, text_input, upload_button, chat_state, img_list, img_emb_list, gallery return None, \ gr.update(value=None, interactive=True), \ gr.update(placeholder='Please upload your image first', interactive=False), \ gr.update(value="Upload & Start Chat", interactive=True), \ CONV_VISION.copy(), \ [], \ [], \ [] def upload_img(gr_img, chat_state, img_list, img_emb_list): if gr_img is None: return None, None, gr.update(interactive=True), chat_state, img_list, img_emb_list img_list.append(gr_img) # upload an image to the chat chat.upload_img(gr_img, chat_state, img_emb_list) # update image, text_input, upload_button, chat_state, gallery, img_emb_list return gr.update(value=None, interactive=False), \ gr.update(interactive=True, placeholder='Type and press Enter'), \ gr.update(value="Send more images after sending a message", interactive=False), \ chat_state, \ img_list, \ img_emb_list def gradio_ask(user_message, chatbot, chat_state): if len(user_message) == 0: return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state chat.ask(user_message, chat_state) chatbot = chatbot + [[user_message, None]] return '', chatbot, chat_state def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): llm_message = chat.answer(conv=chat_state, img_list=img_list, num_beams=num_beams, temperature=temperature, max_new_tokens=300, max_length=2000)[0] chatbot[-1][1] = llm_message # update chatbot, chat_state, image, upload_button return chatbot, \ chat_state, \ gr.update(interactive=True), \ gr.update(value="Send more image", interactive=True) title = """