This commit is contained in:
Zhijie Lin 2023-05-25 16:25:06 +08:00
parent 5204a512e8
commit e8f046f414
3 changed files with 17 additions and 14 deletions

25
demo.py
View File

@ -51,11 +51,12 @@ print('Initializing Chat')
args = parse_args() args = parse_args()
cfg = Config(args) cfg = Config(args)
model_config = cfg.model_cfg # model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id # model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch) # model_cls = registry.get_model_class(model_config.arch)
print(model_config) # print(model_config)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) # model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
model = None
# TODO: Fix hard-coding `cc12m` # TODO: Fix hard-coding `cc12m`
vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
@ -81,6 +82,7 @@ def gradio_reset(chat_state, emb_list):
def upload_img(gr_img, text_input, chat_state): def upload_img(gr_img, text_input, chat_state):
return None, None, gr.update(interactive=True), chat_state, None
if gr_img is None: if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None return None, None, gr.update(interactive=True), chat_state, None
chat_state = CONV_VISION.copy() chat_state = CONV_VISION.copy()
@ -101,12 +103,13 @@ def gradio_ask(user_message, chatbot, chat_state):
def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature): def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature):
llm_message = chat.answer(conversation=chat_state, # llm_message = chat.answer(conversation=chat_state,
emb_list=emb_list, # emb_list=emb_list,
num_beams=num_beams, # num_beams=num_beams,
temperature=temperature, # temperature=temperature,
max_new_tokens=300, # max_new_tokens=300,
max_length=2000)[0] # max_length=2000)[0]
llm_message = "I don't know"
chatbot[-1][1] = llm_message chatbot[-1][1] = llm_message
return chatbot, chat_state, emb_list return chatbot, chat_state, emb_list

View File

@ -8,7 +8,7 @@ model:
low_resource: True low_resource: True
prompt_path: "prompts/alignment.txt" prompt_path: "prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: ' prompt_template: '###Human: {} ###Assistant: '
ckpt: '/path/to/pretrained/ckpt/' ckpt: 'checkpoints/prerained_minigpt4_7b.pth'
datasets: datasets:

View File

@ -10,11 +10,11 @@ model:
freeze_qformer: True freeze_qformer: True
# Q-Former # Q-Former
q_former_model: "/mnt/bn/bykang/chixma/data/pretrained_models/blip2_pretrained_flant5xxl.pth" q_former_model: "checkpoints/blip2_pretrained_flant5xxl.pth"
num_query_token: 32 num_query_token: 32
# Vicuna # Vicuna
llama_model: "/mnt/bn/bykang/chixma/data/pretrained_models/vicuna-13b-v0/" llama_model: "checkpoints/blip2_pretrained_flant5xxl.pth"
# generation configs # generation configs
prompt: "" prompt: ""