This commit is contained in:
Zhijie Lin 2023-05-25 16:25:06 +08:00
parent 5204a512e8
commit e8f046f414
3 changed files with 17 additions and 14 deletions

25
demo.py
View File

@@ -51,11 +51,12 @@ print('Initializing Chat')
args = parse_args()
cfg = Config(args)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
print(model_config)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
# model_config = cfg.model_cfg
# model_config.device_8bit = args.gpu_id
# model_cls = registry.get_model_class(model_config.arch)
# print(model_config)
# model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
model = None
# TODO: Fix hard-coding `cc12m`
vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
@@ -81,6 +82,7 @@ def gradio_reset(chat_state, emb_list):
def upload_img(gr_img, text_input, chat_state):
return None, None, gr.update(interactive=True), chat_state, None
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None
chat_state = CONV_VISION.copy()
@@ -101,12 +103,13 @@ def gradio_ask(user_message, chatbot, chat_state):
def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature):
llm_message = chat.answer(conversation=chat_state,
emb_list=emb_list,
num_beams=num_beams,
temperature=temperature,
max_new_tokens=300,
max_length=2000)[0]
# llm_message = chat.answer(conversation=chat_state,
# emb_list=emb_list,
# num_beams=num_beams,
# temperature=temperature,
# max_new_tokens=300,
# max_length=2000)[0]
llm_message = "I don't know"
chatbot[-1][1] = llm_message
return chatbot, chat_state, emb_list

View File

@@ -8,7 +8,7 @@ model:
low_resource: True
prompt_path: "prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: '/path/to/pretrained/ckpt/'
ckpt: 'checkpoints/prerained_minigpt4_7b.pth'
datasets:

View File

@@ -10,11 +10,11 @@ model:
freeze_qformer: True
# Q-Former
q_former_model: "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth"
q_former_model: "checkpoints/blip2_pretrained_flant5xxl.pth"
num_query_token: 32
# Vicuna
llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-13b-v0/"
llama_model: "checkpoints/blip2_pretrained_flant5xxl.pth"
# generation configs
prompt: ""