diff --git a/demo.py b/demo.py
index cf93edf..3239021 100644
--- a/demo.py
+++ b/demo.py
@@ -51,11 +51,12 @@ print('Initializing Chat')
 args = parse_args()
 cfg = Config(args)
 
-model_config = cfg.model_cfg
-model_config.device_8bit = args.gpu_id
-model_cls = registry.get_model_class(model_config.arch)
-print(model_config)
-model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+# model_config = cfg.model_cfg
+# model_config.device_8bit = args.gpu_id
+# model_cls = registry.get_model_class(model_config.arch)
+# print(model_config)
+# model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+model = None
 
 # TODO: Fix hard-coding `cc12m`
 vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
@@ -81,6 +82,7 @@ def gradio_reset(chat_state, emb_list):
 
 
 def upload_img(gr_img, text_input, chat_state):
+    return None, None, gr.update(interactive=True), chat_state, None
     if gr_img is None:
         return None, None, gr.update(interactive=True), chat_state, None
     chat_state = CONV_VISION.copy()
@@ -101,12 +103,13 @@ def gradio_ask(user_message, chatbot, chat_state):
 
 
 def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature):
-    llm_message = chat.answer(conversation=chat_state,
-                              emb_list=emb_list,
-                              num_beams=num_beams,
-                              temperature=temperature,
-                              max_new_tokens=300,
-                              max_length=2000)[0]
+    # llm_message = chat.answer(conversation=chat_state,
+    #                           emb_list=emb_list,
+    #                           num_beams=num_beams,
+    #                           temperature=temperature,
+    #                           max_new_tokens=300,
+    #                           max_length=2000)[0]
+    llm_message = "I don't know"
     chatbot[-1][1] = llm_message
     return chatbot, chat_state, emb_list
 
diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml
index f9e55a3..b9bfd1d 100644
--- a/eval_configs/minigpt4_eval.yaml
+++ b/eval_configs/minigpt4_eval.yaml
@@ -8,7 +8,7 @@ model:
   low_resource: True
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
-  ckpt: '/path/to/pretrained/ckpt/'
+  ckpt: 'checkpoints/prerained_minigpt4_7b.pth'
 
 
 datasets:
diff --git a/minigpt4/configs/models/minigpt4.yaml b/minigpt4/configs/models/minigpt4.yaml
index 9e7f8e9..2d5f13a 100644
--- a/minigpt4/configs/models/minigpt4.yaml
+++ b/minigpt4/configs/models/minigpt4.yaml
@@ -10,11 +10,11 @@ model:
   freeze_qformer: True
 
   # Q-Former
-  q_former_model: "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth"
+  q_former_model: "checkpoints/blip2_pretrained_flant5xxl.pth"
   num_query_token: 32
 
   # Vicuna
-  llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-13b-v0/"
+  llama_model: "checkpoints/vicuna-13b-v0/"
 
   # generation configs
   prompt: ""
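
Side note on the stubbing approach, not part of the patch above: instead of commenting the model wiring out and hard-coding the canned reply, the same "run the UI without weights" mode could hang off a flag, which keeps the real code path intact and makes the stub easy to flip off. Below is a minimal sketch against the demo.py names in this diff (cfg, args, registry, chat); the MINIGPT4_NO_MODEL environment variable is hypothetical, not something the MiniGPT-4 codebase defines.

import os

# Hypothetical flag; set MINIGPT4_NO_MODEL=1 to run the Gradio UI with no checkpoint loaded.
NO_MODEL = os.environ.get('MINIGPT4_NO_MODEL') == '1'

if NO_MODEL:
    model = None
else:
    model_config = cfg.model_cfg
    model_config.device_8bit = args.gpu_id
    model_cls = registry.get_model_class(model_config.arch)
    model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))


def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature):
    if NO_MODEL:
        llm_message = "I don't know"  # canned reply while the model is stubbed
    else:
        llm_message = chat.answer(conversation=chat_state,
                                  emb_list=emb_list,
                                  num_beams=num_beams,
                                  temperature=temperature,
                                  max_new_tokens=300,
                                  max_length=2000)[0]
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, emb_list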