diff --git a/README.md b/README.md index 7aa29f2..b1f8961 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Lin ### Launching Demo Locally -Try out our demo [demo.py](demo.py) on your local machine by running +Try out our demo [demo.py](eval_scripts/qualitative_eval.py) on your local machine by running ``` python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0 diff --git a/eval_configs/bindgpt4_eval.yaml b/eval_configs/bindgpt4_eval.yaml index 73d24e5..2b3282b 100644 --- a/eval_configs/bindgpt4_eval.yaml +++ b/eval_configs/bindgpt4_eval.yaml @@ -1,18 +1,18 @@ model: arch: bind_gpt4 model_type: pretrain_vicuna - freeze_vit: True + freeze_imagebind: True freeze_qformer: False max_txt_len: 160 end_sym: "###" low_resource: False prompt_path: "prompts/alignment.txt" prompt_template: '###Human: {} ###Assistant: ' - ckpt: '/path/to/pretrained/ckpt/' + ckpt: 'minigpt4/output/minigpt4_stage1_pretrain/20230524192/checkpoint_0.pth' datasets: - cc_sbu_align: + cc12m: # Double check vis_processor: train: name: "imagebind_vision_eval" diff --git a/minigpt4/conversation/__init__.py b/eval_scripts/__init__.py similarity index 100% rename from minigpt4/conversation/__init__.py rename to eval_scripts/__init__.py diff --git a/eval_scripts/conversation.py b/eval_scripts/conversation.py new file mode 100644 index 0000000..253950a --- /dev/null +++ b/eval_scripts/conversation.py @@ -0,0 +1,170 @@ +import dataclasses +from copy import deepcopy +from types import SimpleNamespace +from typing import List, Union, Dict + +import torch +from PIL import Image +from torch import nn, Tensor +from transformers import StoppingCriteria, StoppingCriteriaList + +from eval_scripts.eval_utils import load_image +from imagebind.models.image_bind import ModalityType +from minigpt4 import BaseProcessor + +Roles = SimpleNamespace( + HUMAN="Human", + ASSISTANT="Assistant" +) + + +class Message: + def __init__(self, role: str, content: Union[str, None]): + self.role = role + self.content = content + + +@dataclasses.dataclass +class Conversation: + """A class that keeps all conversation history.""" + system: str + messages: List[Message] + sep: str = "###" + + def get_prompt(self): + ret = self.system + self.sep + for message in self.messages: + if message.content: + ret += message.role + ": " + message.content + self.sep + else: + ret += message.role + ":" + return ret + + def append_message(self, role, content): + self.messages.append(Message(role, content)) + + def copy(self): + return Conversation( + system=self.system, + messages=deepcopy(self.messages), + sep=self.sep) + + def dict(self): + return { + "system": self.system, + "messages": [(msg.role, msg.content) for msg in self.messages], + "sep": self.sep + } + + +class StoppingCriteriaSub(StoppingCriteria): + def __init__(self, stops=[], encounters=1): + super().__init__() + self.stops = stops + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): + for stop in self.stops: + if torch.all((stop == input_ids[0][-len(stop):])).item(): + return True + + return False + + +CONV_VISION = Conversation( + system="Give the following image: ImageContent. " + "You will be able to see the image once I provide it to you. Please answer my questions.", + messages=[], + sep="###", +) + + +# TODO: If needed and possible, rewrite this file and re-organize the definition of components. 
+ + +class Chat: + def __init__(self, + model: nn.Module, + processors: Dict[str, BaseProcessor], + device: str = 'cuda:0' + ): + self.device = device + self.model = model + self.processors = processors + stop_words_ids = [torch.tensor([835]).to(self.device), + torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. + self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) + self.just_uploaded = False + + def ask(self, text, conversation): + # NOTE: the hard code for postfix is removed. + # end_token = '' + # if len(conversation.messages) > 0 and conversation.messages[-1].role == Roles.HUMAN \ + # and conversation.messages[-1].content[-len(end_token):] == end_token: + if self.just_uploaded: + conversation.messages[-1].content = ' '.join([conversation.messages[-1].content, text]) + self.just_uploaded = False + else: + conversation.append_message(Roles.HUMAN, text) + + def answer(self, conversation, emb_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, + repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): + # Generate an answer written by LLaMA + conversation.append_message(Roles.ASSISTANT, None) + embs = self.get_context_emb(conversation, emb_list) + + current_max_len = embs.shape[1] + max_new_tokens + if current_max_len - max_length > 0: + print('Warning: The number of tokens in current conversation exceeds the max length. ' + 'The model will not see the contexts outside the range.') + begin_idx = max(0, current_max_len - max_length) + + embs = embs[:, begin_idx:] + + outputs = self.model.llama_model.generate( + inputs_embeds=embs, + max_new_tokens=max_new_tokens, + stopping_criteria=self.stopping_criteria, + num_beams=num_beams, + do_sample=True, + min_length=min_length, + top_p=top_p, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + temperature=temperature, + ) + output_token = outputs[0] + if output_token[0] == 0: # the model might output a unknown token at the beginning. remove it + output_token = output_token[1:] + if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it + output_token = output_token[1:] + output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) + output_text = output_text.split('###')[0] # remove the stop sign '###' + output_text = output_text.split('Assistant:')[-1].strip() + conversation.messages[-1].content = output_text + return output_text, output_token.cpu().numpy() + + def upload_img(self, image: Union[str, Image.Image, Tensor], conversation: Conversation, emb_list: List[Tensor]): + # Upload Image, Encode Image and Create a new message from human. + image = load_image(image, self.processors[ModalityType.VISION]).to(self.device) + all_embeddings = self.model.encode_inputs({ModalityType.VISION: image}) + image_emb = all_embeddings[ModalityType.VISION] + emb_list.append(image_emb) + conversation.append_message(Roles.HUMAN, "") + self.just_uploaded = True + + def get_context_emb(self, conversation: Conversation, emb_list: List[Tensor]): + # Insert the embeddings into the prompts and queries. + # NOTE: Assume the placeholders have been aligned to the embeddings! + prompt = conversation.get_prompt() + prompt_segs = prompt.split('') + assert len(prompt_segs) == len(emb_list) + 1, "Unmatched numbers of placeholders and embeddings." 
+ seg_tokens = [ + self.model.llama_tokenizer( + seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids + # only add bos to the first seg + for i, seg in enumerate(prompt_segs) + ] + seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] + mixed_embs = [emb for pair in zip(seg_embs[:-1], emb_list) for emb in pair] + [seg_embs[-1]] + mixed_embs = torch.cat(mixed_embs, dim=1) + return mixed_embs diff --git a/eval_scripts/eval_utils.py b/eval_scripts/eval_utils.py new file mode 100644 index 0000000..45aec39 --- /dev/null +++ b/eval_scripts/eval_utils.py @@ -0,0 +1,15 @@ +import torch +from PIL import Image + + +def load_image(image, image_processor): + if isinstance(image, str): # is a image path + raw_image = Image.open(image).convert('RGB') + image = image_processor(raw_image).unsqueeze(0) + elif isinstance(image, Image.Image): + raw_image = image + image = image_processor(raw_image).unsqueeze(0) + elif isinstance(image, torch.Tensor): + if len(image.shape) == 3: + image = image.unsqueeze(0) + return image diff --git a/demo.py b/eval_scripts/qualitative_eval.py similarity index 72% rename from demo.py rename to eval_scripts/qualitative_eval.py index d441571..ebc5012 100644 --- a/demo.py +++ b/eval_scripts/qualitative_eval.py @@ -1,5 +1,4 @@ import argparse -import os import random import numpy as np @@ -10,18 +9,16 @@ import gradio as gr from minigpt4.common.config import Config from minigpt4.common.dist_utils import get_rank from minigpt4.common.registry import registry -from minigpt4.conversation.conversation import Chat, CONV_VISION +from eval_scripts.conversation import Chat, CONV_VISION +# NOTE&TODO: put this before minigpt4 import will cause circular import +# possibly because `imagebind` imports `minigpt4` and `minigpt4` also imports `imagebind` +from imagebind.models.image_bind import ModalityType # imports modules for registration -from minigpt4.datasets.builders import * -from minigpt4.models import * -from minigpt4.processors import * -from minigpt4.runners import * -from minigpt4.tasks import * def parse_args(): - parser = argparse.ArgumentParser(description="Demo") + parser = argparse.ArgumentParser(description="Qualitative") parser.add_argument("--cfg-path", required=True, help="path to configuration file.") parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") parser.add_argument( @@ -59,9 +56,11 @@ model_config.device_8bit = args.gpu_id model_cls = registry.get_model_class(model_config.arch) model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) -vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train +# TODO: Fix hard-coding `cc12m` +vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) -chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) +processors = {ModalityType.VISION: vis_processor} +chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id)) print('Initialization Finished') @@ -69,24 +68,27 @@ print('Initialization Finished') # Gradio Setting # ======================================== -def gradio_reset(chat_state, img_list): +def gradio_reset(chat_state, emb_list): if chat_state is not None: chat_state.messages = [] - if img_list is not None: - img_list = [] - return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first', - interactive=False), gr.update( 
- value="Upload & Start Chat", interactive=True), chat_state, img_list + if emb_list is not None: + emb_list = [] + return None, gr.update(value=None, interactive=True), \ + gr.update(placeholder='Please upload your image first', interactive=False), \ + gr.update(value="Upload & Start Chat", interactive=True), \ + chat_state, emb_list def upload_img(gr_img, text_input, chat_state): if gr_img is None: return None, None, gr.update(interactive=True), chat_state, None chat_state = CONV_VISION.copy() - img_list = [] - llm_message = chat.upload_img(gr_img, chat_state, img_list) - return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update( - value="Start Chatting", interactive=False), chat_state, img_list + emb_list = [] + chat.upload_img(gr_img, chat_state, emb_list) + return gr.update(interactive=False), \ + gr.update(interactive=True, placeholder='Type and press Enter'), \ + gr.update(value="Start Chatting", interactive=False), \ + chat_state, emb_list def gradio_ask(user_message, chatbot, chat_state): @@ -97,15 +99,15 @@ def gradio_ask(user_message, chatbot, chat_state): return '', chatbot, chat_state -def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): - llm_message = chat.answer(conv=chat_state, - img_list=img_list, +def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature): + llm_message = chat.answer(conversation=chat_state, + emb_list=emb_list, num_beams=num_beams, temperature=temperature, max_new_tokens=300, max_length=2000)[0] chatbot[-1][1] = llm_message - return chatbot, chat_state, img_list + return chatbot, chat_state, emb_list title = """

<h1 align="center">Demo of BindGPT-4</h1>
""" @@ -146,17 +148,17 @@ with gr.Blocks() as demo: with gr.Column(): chat_state = gr.State() - img_list = gr.State() + emb_list = gr.State() chatbot = gr.Chatbot(label='BindGPT-4') text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False) upload_button.click(upload_img, [image, text_input, chat_state], - [image, text_input, upload_button, chat_state, img_list]) + [image, text_input, upload_button, chat_state, emb_list]) text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then( - gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list] + gradio_answer, [chatbot, chat_state, emb_list, num_beams, temperature], [chatbot, chat_state, emb_list] ) - clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], + clear.click(gradio_reset, [chat_state, emb_list], [chatbot, image, text_input, upload_button, chat_state, emb_list], queue=False) demo.launch(share=True, enable_queue=True) diff --git a/eval_scripts/quantitative_eval.py b/eval_scripts/quantitative_eval.py new file mode 100644 index 0000000..8a8285e --- /dev/null +++ b/eval_scripts/quantitative_eval.py @@ -0,0 +1,82 @@ +import argparse +import json +import os + +import shortuuid +from tqdm import tqdm + +from minigpt4.common.config import Config +from minigpt4.common.registry import registry +# TODO: check the circular import problem in `eval_scripts.conversation` +from eval_scripts.conversation import Chat, CONV_VISION +from imagebind.models.image_bind import ModalityType + + +def parse_args(): + parser = argparse.ArgumentParser(description="Quantitative") + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + parser.add_argument("--question-path", required=True, help="path to question file.") + parser.add_argument("--answer-path", required=True, help="path to answer result file.") + parser.add_argument("--image-folder", required=True, help="path to the image queries (COCO 2014 val).") + args = parser.parse_args() + return args + + +# ======================================== +# Model Initialization +# ======================================== +print('Initializing Chat') +args = parse_args() +cfg = Config(args) + +model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) + +# TODO: fix hard-coding `cc12m` +vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) +processors = {ModalityType.VISION: vis_processor} +chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id)) +print('Initialization Finished') + +# ======================================== +# Prompt Setting +# ======================================== +prompt = "Give the following image: ImageContent. " \ + "You will be able to see the image once I provide it to you. Please answer my question." 
+ +# ======================================== +# Question Loading +# ======================================== +import pdb; pdb.set_trace() +questions = [json.loads(q) for q in open(args.question_path, "r")] +answer_file = open(args.answer_path, "w") +for i, line in enumerate(tqdm(questions)): + idx = line["question_id"] + image_file = os.path.join(args.image_folder, "COCO_val2014_" + line["image"]) + question = line["text"] + state = CONV_VISION.copy() + emb_list = [] + chat.upload_img(image_file, state, emb_list) + chat.ask(question, state) + answer, _ = chat.answer(state, emb_list) + ans_id = shortuuid.uuid() + answer_file.write(json.dumps({"question_id": idx, + "prompt": question, + "text": answer, + "answer_id": ans_id, + "model_id": model_config.arch, + "metadata": {}}) + "\n") + answer_file.flush() +answer_file.close() + diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py deleted file mode 100644 index 86d8b80..0000000 --- a/minigpt4/conversation/conversation.py +++ /dev/null @@ -1,211 +0,0 @@ -import dataclasses -from enum import auto, Enum -from typing import List, Any - -import torch -from PIL import Image -from transformers import StoppingCriteria, StoppingCriteriaList - -from imagebind.models.image_bind import ModalityType - - -class SeparatorStyle(Enum): - """Different separator style.""" - SINGLE = auto() - TWO = auto() - - -@dataclasses.dataclass -class Conversation: - """A class that keeps all conversation history.""" - system: str - roles: List[str] - messages: List[List[str]] - offset: int - # system_img: List[Image.Image] = [] - sep_style: SeparatorStyle = SeparatorStyle.SINGLE - sep: str = "###" - sep2: str = None - - skip_next: bool = False - conv_id: Any = None - - def get_prompt(self): - if self.sep_style == SeparatorStyle.SINGLE: - ret = self.system + self.sep - for role, message in self.messages: - if message: - ret += role + ": " + message + self.sep - else: - ret += role + ":" - return ret - elif self.sep_style == SeparatorStyle.TWO: - seps = [self.sep, self.sep2] - ret = self.system + seps[0] - for i, (role, message) in enumerate(self.messages): - if message: - ret += role + ": " + message + seps[i % 2] - else: - ret += role + ":" - return ret - else: - raise ValueError(f"Invalid style: {self.sep_style}") - - def append_message(self, role, message): - self.messages.append([role, message]) - - def to_gradio_chatbot(self): - ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset:]): - if i % 2 == 0: - ret.append([msg, None]) - else: - ret[-1][-1] = msg - return ret - - def copy(self): - return Conversation( - system=self.system, - # system_img=self.system_img, - roles=self.roles, - messages=[[x, y] for x, y in self.messages], - offset=self.offset, - sep_style=self.sep_style, - sep=self.sep, - sep2=self.sep2, - conv_id=self.conv_id) - - def dict(self): - return { - "system": self.system, - # "system_img": self.system_img, - "roles": self.roles, - "messages": self.messages, - "offset": self.offset, - "sep": self.sep, - "sep2": self.sep2, - "conv_id": self.conv_id, - } - - -class StoppingCriteriaSub(StoppingCriteria): - - def __init__(self, stops=[], encounters=1): - super().__init__() - self.stops = stops - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): - for stop in self.stops: - if torch.all((stop == input_ids[0][-len(stop):])).item(): - return True - - return False - - -CONV_VISION = Conversation( - system="Give the following image: ImageContent. 
" - "You will be able to see the image once I provide it to you. Please answer my questions.", - roles=("Human", "Assistant"), - messages=[], - offset=2, - sep_style=SeparatorStyle.SINGLE, - sep="###", -) - -# TODO: If needed and possible, rewrite this file and re-organize the definition of components. - - -class Chat: - def __init__(self, model, vis_processor, device='cuda:0'): - self.device = device - self.model = model - self.vis_processor = vis_processor - stop_words_ids = [torch.tensor([835]).to(self.device), - torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways. - self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) - - def ask(self, text, conv): - # NOTE: the hard code for postfix is removed. - # TODO: Need to be compatible with more modalities. - end_token = '' - if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \ - and conv.messages[-1][1][-len(end_token):] == end_token: # last message is image. - conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text]) - else: - conv.append_message(conv.roles[0], text) - - def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9, - repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000): - # Generate an answer written by LLaMA - conv.append_message(conv.roles[1], None) - embs = self.get_context_emb(conv, img_list) - - current_max_len = embs.shape[1] + max_new_tokens - if current_max_len - max_length > 0: - print('Warning: The number of tokens in current conversation exceeds the max length. ' - 'The model will not see the contexts outside the range.') - begin_idx = max(0, current_max_len - max_length) - - embs = embs[:, begin_idx:] - - outputs = self.model.llama_model.generate( - inputs_embeds=embs, - max_new_tokens=max_new_tokens, - stopping_criteria=self.stopping_criteria, - num_beams=num_beams, - do_sample=True, - min_length=min_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - length_penalty=length_penalty, - temperature=temperature, - ) - output_token = outputs[0] - if output_token[0] == 0: # the model might output a unknown token at the beginning. remove it - output_token = output_token[1:] - if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it - output_token = output_token[1:] - output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False) - output_text = output_text.split('###')[0] # remove the stop sign '###' - output_text = output_text.split('Assistant:')[-1].strip() - conv.messages[-1][1] = output_text - return output_text, output_token.cpu().numpy() - - def upload_img(self, image, conv, img_list): - # Upload Image, Encode Image and Create a new message from human. - if isinstance(image, str): # is a image path - raw_image = Image.open(image).convert('RGB') - image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) - elif isinstance(image, Image.Image): - raw_image = image - image = self.vis_processor(raw_image).unsqueeze(0).to(self.device) - elif isinstance(image, torch.Tensor): - if len(image.shape) == 3: - image = image.unsqueeze(0) - image = image.to(self.device) - - all_embeddings = self.model.encode_inputs({ModalityType.VISION: image}) - image_emb = all_embeddings[ModalityType.VISION] - img_list.append(image_emb) - conv.append_message(conv.roles[0], "") - msg = "Received." 
- # self.conv.append_message(self.conv.roles[1], msg) - return msg - - def get_context_emb(self, conv, img_list): - # Insert the image embeddings into the prompts and queries. Note that the img_list: List[Tensor] - prompt = conv.get_prompt() - prompt_segs = prompt.split('') - assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images." - seg_tokens = [ - self.model.llama_tokenizer( - seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids - # only add bos to the first seg - for i, seg in enumerate(prompt_segs) - ] - seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens] - mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] - mixed_embs = torch.cat(mixed_embs, dim=1) - return mixed_embs - - diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py index b400c7e..59c60f6 100644 --- a/minigpt4/models/bind_gpt4.py +++ b/minigpt4/models/bind_gpt4.py @@ -45,15 +45,6 @@ class BindGPT4(BaseModel): self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind) print('Loading ImageBind Done') - print('Loading Q-Former and Adapter/Projector') - self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token, - vision_qformer_frozen=freeze_qformer, - vision_post_dims=[768, self.llama_model.config.hidden_size] - # vision_qformer_model=q_former_model, - # vision_pre_dims=(1280, 1408) - ) - print('Loading Q-Former and Adapter/Projector Done') - print('Loading LLAMA') self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False) self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token @@ -67,6 +58,15 @@ class BindGPT4(BaseModel): param.requires_grad = False print('Loading LLAMA Done') + print('Loading Q-Former and Adapter/Projector') + self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token, + vision_qformer_frozen=freeze_qformer, + vision_post_dims=[768, self.llama_model.config.hidden_size] + # vision_qformer_model=q_former_model, + # vision_pre_dims=(1280, 1408) + ) + print('Loading Q-Former and Adapter/Projector Done') + self.max_txt_len = max_txt_len self.end_sym = end_sym @@ -93,7 +93,7 @@ class BindGPT4(BaseModel): attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device) if prompt: batch_size = input_embeds.shape[0] - p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title())) + p_before, p_after = prompt.split('') p_before_tokens = self.llama_tokenizer( p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device) p_after_tokens = self.llama_tokenizer( diff --git a/prompts/alignment.txt b/prompts/alignment.txt index 90ae57b..28137af 100644 --- a/prompts/alignment.txt +++ b/prompts/alignment.txt @@ -1,4 +1,4 @@ - Describe this image in detail. - Take a look at this image and describe what you notice. - Please provide a detailed description of the picture. - Could you describe the contents of this image for me? \ No newline at end of file + Describe this image in detail. + Take a look at this image and describe what you notice. + Please provide a detailed description of the picture. + Could you describe the contents of this image for me? 
\ No newline at end of file diff --git a/temp.py b/temp.py deleted file mode 100644 index e69de29..0000000 diff --git a/train_configs/bindgpt4_stage1_pretrain.yaml b/train_configs/bindgpt4_stage1_pretrain.yaml index 1f9aad1..f0a520a 100644 --- a/train_configs/bindgpt4_stage1_pretrain.yaml +++ b/train_configs/bindgpt4_stage1_pretrain.yaml @@ -1,7 +1,7 @@ model: arch: bind_gpt4 model_type: pretrain_vicuna - freeze_vit: True + freeze_imagebind: True freeze_qformer: False
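
The renamed demo (`eval_scripts/qualitative_eval.py`) and the new `eval_scripts/quantitative_eval.py` both drive the model through the reworked `Chat` API: a `{ModalityType: processor}` dict and an `emb_list` replace the old `vis_processor` argument and `img_list`. Below is a minimal sketch of that flow, assuming the `eval_configs/bindgpt4_eval.yaml` config touched by this patch, a `SimpleNamespace` standing in for the scripts' argparse namespace, and a placeholder image path.

```
# Minimal sketch of the reworked Chat flow (mirrors eval_scripts/quantitative_eval.py).
# Assumptions: SimpleNamespace stands in for the scripts' argparse namespace,
# and 'path/to/image.jpg' is a placeholder -- swap in a real image.
from types import SimpleNamespace

from minigpt4.common.config import Config
from minigpt4.common.registry import registry
# Keep this import after the minigpt4 imports (circular-import caveat noted in qualitative_eval.py).
from eval_scripts.conversation import Chat, CONV_VISION
from imagebind.models.image_bind import ModalityType

args = SimpleNamespace(cfg_path='eval_configs/bindgpt4_eval.yaml', options=None, gpu_id=0)
cfg = Config(args)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model = registry.get_model_class(model_config.arch).from_config(model_config).to(f'cuda:{args.gpu_id}')

# The dataset key is still hard-coded to cc12m (see the TODOs in both eval scripts).
vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

chat = Chat(model, {ModalityType.VISION: vis_processor}, device=f'cuda:{args.gpu_id}')

state = CONV_VISION.copy()
emb_list = []  # image embeddings collected by upload_img
chat.upload_img('path/to/image.jpg', state, emb_list)
chat.ask('Describe this image in detail.', state)
answer, _ = chat.answer(state, emb_list)
print(answer)
```

The gradio callbacks in `qualitative_eval.py` wire these same `upload_img`/`ask`/`answer` calls to the UI; the sketch just strips away the argparse and gradio plumbing.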