diff --git a/README.md b/README.md
index 7aa29f2..b1f8961 100644
--- a/README.md
+++ b/README.md
@@ -87,7 +87,7 @@ in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Lin
### Launching Demo Locally
-Try out our demo [demo.py](demo.py) on your local machine by running
+Try out our demo [eval_scripts/qualitative_eval.py](eval_scripts/qualitative_eval.py) on your local machine by running
```
python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0
diff --git a/eval_configs/bindgpt4_eval.yaml b/eval_configs/bindgpt4_eval.yaml
index 73d24e5..2b3282b 100644
--- a/eval_configs/bindgpt4_eval.yaml
+++ b/eval_configs/bindgpt4_eval.yaml
@@ -1,18 +1,18 @@
model:
arch: bind_gpt4
model_type: pretrain_vicuna
- freeze_vit: True
+ freeze_imagebind: True
freeze_qformer: False
max_txt_len: 160
end_sym: "###"
low_resource: False
prompt_path: "prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
- ckpt: '/path/to/pretrained/ckpt/'
+ ckpt: 'minigpt4/output/minigpt4_stage1_pretrain/20230524192/checkpoint_0.pth'
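+ # stage-1 pretraining checkpoint; adjust to your own output path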
datasets:
- cc_sbu_align:
+ cc12m: # Double check
vis_processor:
train:
name: "imagebind_vision_eval"
diff --git a/minigpt4/conversation/__init__.py b/eval_scripts/__init__.py
similarity index 100%
rename from minigpt4/conversation/__init__.py
rename to eval_scripts/__init__.py
diff --git a/eval_scripts/conversation.py b/eval_scripts/conversation.py
new file mode 100644
index 0000000..253950a
--- /dev/null
+++ b/eval_scripts/conversation.py
@@ -0,0 +1,170 @@
+import dataclasses
+from copy import deepcopy
+from types import SimpleNamespace
+from typing import List, Union, Dict
+
+import torch
+from PIL import Image
+from torch import nn, Tensor
+from transformers import StoppingCriteria, StoppingCriteriaList
+
+from eval_scripts.eval_utils import load_image
+from imagebind.models.image_bind import ModalityType
+from minigpt4 import BaseProcessor
+
+Roles = SimpleNamespace(
+ HUMAN="Human",
+ ASSISTANT="Assistant"
+)
+
+
+class Message:
+ def __init__(self, role: str, content: Union[str, None]):
+ self.role = role
+ self.content = content
+
+
+@dataclasses.dataclass
+class Conversation:
+ """A class that keeps all conversation history."""
+ system: str
+ messages: List[Message]
+ sep: str = "###"
+
+ def get_prompt(self):
+ ret = self.system + self.sep
+ for message in self.messages:
+ if message.content:
+ ret += message.role + ": " + message.content + self.sep
+ else:
+ ret += message.role + ":"
+ return ret
+
+ def append_message(self, role, content):
+ self.messages.append(Message(role, content))
+
+ def copy(self):
+ return Conversation(
+ system=self.system,
+ messages=deepcopy(self.messages),
+ sep=self.sep)
+
+ def dict(self):
+ return {
+ "system": self.system,
+ "messages": [(msg.role, msg.content) for msg in self.messages],
+ "sep": self.sep
+ }
+
+
+class StoppingCriteriaSub(StoppingCriteria):
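+ """Stop generation once the output ends with any of the given stop-token sequences."""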
+ def __init__(self, stops=[], encounters=1):
+ super().__init__()
+ self.stops = stops
+
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+ for stop in self.stops:
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
+ return True
+
+ return False
+
+
+CONV_VISION = Conversation(
+ system="Give the following image: <Img>ImageContent</Img>. "
+ "You will be able to see the image once I provide it to you. Please answer my questions.",
+ messages=[],
+ sep="###",
+)
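+# Default conversation state shared by the evaluation scripts; call .copy() to start a fresh session.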
+
+
+# TODO: If needed and possible, rewrite this file and re-organize the definition of components.
+
+
+class Chat:
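+ """Chat wrapper around a BindGPT-4 model: keeps the text conversation and a parallel
+ list of modality embeddings that are spliced into the prompt at generation time."""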
+ def __init__(self,
+ model: nn.Module,
+ processors: Dict[str, BaseProcessor],
+ device: str = 'cuda:0'
+ ):
+ self.device = device
+ self.model = model
+ self.processors = processors
+ stop_words_ids = [torch.tensor([835]).to(self.device),
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
+ self.just_uploaded = False
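+ # Set by upload_img() so that the next ask() merges the user's text into the
+ # placeholder message created for the upload instead of appending a new message.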
+
+ def ask(self, text, conversation):
+ # NOTE: the hard code for postfix is removed.
+ # end_token = ''
+ # if len(conversation.messages) > 0 and conversation.messages[-1].role == Roles.HUMAN \
+ # and conversation.messages[-1].content[-len(end_token):] == end_token:
+ if self.just_uploaded:
+ conversation.messages[-1].content = ' '.join([conversation.messages[-1].content, text])
+ self.just_uploaded = False
+ else:
+ conversation.append_message(Roles.HUMAN, text)
+
+ def answer(self, conversation, emb_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
+ # Generate an answer written by LLaMA
+ conversation.append_message(Roles.ASSISTANT, None)
+ embs = self.get_context_emb(conversation, emb_list)
+
+ current_max_len = embs.shape[1] + max_new_tokens
+ if current_max_len - max_length > 0:
+ print('Warning: The number of tokens in current conversation exceeds the max length. '
+ 'The model will not see the contexts outside the range.')
+ begin_idx = max(0, current_max_len - max_length)
+
+ embs = embs[:, begin_idx:]
+
+ outputs = self.model.llama_model.generate(
+ inputs_embeds=embs,
+ max_new_tokens=max_new_tokens,
+ stopping_criteria=self.stopping_criteria,
+ num_beams=num_beams,
+ do_sample=True,
+ min_length=min_length,
+ top_p=top_p,
+ repetition_penalty=repetition_penalty,
+ length_penalty=length_penalty,
+ temperature=temperature,
+ )
+ output_token = outputs[0]
+ if output_token[0] == 0: # the model might output an unknown token at the beginning. remove it
+ output_token = output_token[1:]
+ if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
+ output_token = output_token[1:]
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
+ output_text = output_text.split('Assistant:')[-1].strip()
+ conversation.messages[-1].content = output_text
+ return output_text, output_token.cpu().numpy()
+
+ def upload_img(self, image: Union[str, Image.Image, Tensor], conversation: Conversation, emb_list: List[Tensor]):
+ # Upload Image, Encode Image and Create a new message from human.
+ image = load_image(image, self.processors[ModalityType.VISION]).to(self.device)
+ all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
+ image_emb = all_embeddings[ModalityType.VISION]
+ emb_list.append(image_emb)
+ conversation.append_message(Roles.HUMAN, "<ImageHere>") # placeholder replaced by the image embedding in get_context_emb()
+ self.just_uploaded = True
+
+ def get_context_emb(self, conversation: Conversation, emb_list: List[Tensor]):
+ # Insert the embeddings into the prompts and queries.
+ # NOTE: Assume the placeholders have been aligned to the embeddings!
+ prompt = conversation.get_prompt()
+ prompt_segs = prompt.split('<ImageHere>')
+ assert len(prompt_segs) == len(emb_list) + 1, "Unmatched numbers of placeholders and embeddings."
+ seg_tokens = [
+ self.model.llama_tokenizer(
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
+ # only add bos to the first seg
+ for i, seg in enumerate(prompt_segs)
+ ]
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
+ mixed_embs = [emb for pair in zip(seg_embs[:-1], emb_list) for emb in pair] + [seg_embs[-1]]
+ mixed_embs = torch.cat(mixed_embs, dim=1)
+ return mixed_embs
diff --git a/eval_scripts/eval_utils.py b/eval_scripts/eval_utils.py
new file mode 100644
index 0000000..45aec39
--- /dev/null
+++ b/eval_scripts/eval_utils.py
@@ -0,0 +1,15 @@
+import torch
+from PIL import Image
+
+
+def load_image(image, image_processor):
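+ """Normalize an image given as a file path, a PIL image, or an already-processed tensor
+ into a batched tensor; paths and PIL images are run through `image_processor` first."""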
+ if isinstance(image, str): # is a image path
+ raw_image = Image.open(image).convert('RGB')
+ image = image_processor(raw_image).unsqueeze(0)
+ elif isinstance(image, Image.Image):
+ raw_image = image
+ image = image_processor(raw_image).unsqueeze(0)
+ elif isinstance(image, torch.Tensor):
+ if len(image.shape) == 3:
+ image = image.unsqueeze(0)
+ return image
diff --git a/demo.py b/eval_scripts/qualitative_eval.py
similarity index 72%
rename from demo.py
rename to eval_scripts/qualitative_eval.py
index d441571..ebc5012 100644
--- a/demo.py
+++ b/eval_scripts/qualitative_eval.py
@@ -1,5 +1,4 @@
import argparse
-import os
import random
import numpy as np
@@ -10,18 +9,16 @@ import gradio as gr
from minigpt4.common.config import Config
from minigpt4.common.dist_utils import get_rank
from minigpt4.common.registry import registry
-from minigpt4.conversation.conversation import Chat, CONV_VISION
+from eval_scripts.conversation import Chat, CONV_VISION
+# NOTE&TODO: putting this import before the `minigpt4` imports causes a circular import,
+# possibly because `imagebind` imports `minigpt4` and `minigpt4` also imports `imagebind`
+from imagebind.models.image_bind import ModalityType
# imports modules for registration
-from minigpt4.datasets.builders import *
-from minigpt4.models import *
-from minigpt4.processors import *
-from minigpt4.runners import *
-from minigpt4.tasks import *
def parse_args():
- parser = argparse.ArgumentParser(description="Demo")
+ parser = argparse.ArgumentParser(description="Qualitative")
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
parser.add_argument(
@@ -59,9 +56,11 @@ model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
-vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
+# TODO: Fix hard-coding `cc12m`
+vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+processors = {ModalityType.VISION: vis_processor}
+chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')
@@ -69,24 +68,27 @@ print('Initialization Finished')
# Gradio Setting
# ========================================
-def gradio_reset(chat_state, img_list):
+def gradio_reset(chat_state, emb_list):
if chat_state is not None:
chat_state.messages = []
- if img_list is not None:
- img_list = []
- return None, gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your image first',
- interactive=False), gr.update(
- value="Upload & Start Chat", interactive=True), chat_state, img_list
+ if emb_list is not None:
+ emb_list = []
+ return None, gr.update(value=None, interactive=True), \
+ gr.update(placeholder='Please upload your image first', interactive=False), \
+ gr.update(value="Upload & Start Chat", interactive=True), \
+ chat_state, emb_list
def upload_img(gr_img, text_input, chat_state):
if gr_img is None:
return None, None, gr.update(interactive=True), chat_state, None
chat_state = CONV_VISION.copy()
- img_list = []
- llm_message = chat.upload_img(gr_img, chat_state, img_list)
- return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(
- value="Start Chatting", interactive=False), chat_state, img_list
+ emb_list = []
+ chat.upload_img(gr_img, chat_state, emb_list)
+ return gr.update(interactive=False), \
+ gr.update(interactive=True, placeholder='Type and press Enter'), \
+ gr.update(value="Start Chatting", interactive=False), \
+ chat_state, emb_list
def gradio_ask(user_message, chatbot, chat_state):
@@ -97,15 +99,15 @@ def gradio_ask(user_message, chatbot, chat_state):
return '', chatbot, chat_state
-def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
- llm_message = chat.answer(conv=chat_state,
- img_list=img_list,
+def gradio_answer(chatbot, chat_state, emb_list, num_beams, temperature):
+ llm_message = chat.answer(conversation=chat_state,
+ emb_list=emb_list,
num_beams=num_beams,
temperature=temperature,
max_new_tokens=300,
max_length=2000)[0]
chatbot[-1][1] = llm_message
- return chatbot, chat_state, img_list
+ return chatbot, chat_state, emb_list
title = """Demo of BindGPT-4
"""
@@ -146,17 +148,17 @@ with gr.Blocks() as demo:
with gr.Column():
chat_state = gr.State()
- img_list = gr.State()
+ emb_list = gr.State()
chatbot = gr.Chatbot(label='BindGPT-4')
text_input = gr.Textbox(label='User', placeholder='Please upload your image first', interactive=False)
upload_button.click(upload_img, [image, text_input, chat_state],
- [image, text_input, upload_button, chat_state, img_list])
+ [image, text_input, upload_button, chat_state, emb_list])
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
- gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
+ gradio_answer, [chatbot, chat_state, emb_list, num_beams, temperature], [chatbot, chat_state, emb_list]
)
- clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list],
+ clear.click(gradio_reset, [chat_state, emb_list], [chatbot, image, text_input, upload_button, chat_state, emb_list],
queue=False)
demo.launch(share=True, enable_queue=True)
diff --git a/eval_scripts/quantitative_eval.py b/eval_scripts/quantitative_eval.py
new file mode 100644
index 0000000..8a8285e
--- /dev/null
+++ b/eval_scripts/quantitative_eval.py
@@ -0,0 +1,82 @@
+import argparse
+import json
+import os
+
+import shortuuid
+from tqdm import tqdm
+
+from minigpt4.common.config import Config
+from minigpt4.common.registry import registry
+# TODO: check the circular import problem in `eval_scripts.conversation`
+from eval_scripts.conversation import Chat, CONV_VISION
+from imagebind.models.image_bind import ModalityType
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Quantitative")
+ parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
+ parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
+ parser.add_argument(
+ "--options",
+ nargs="+",
+ help="override some settings in the used config, the key-value pair "
+ "in xxx=yyy format will be merged into config file (deprecate), "
+ "change to --cfg-options instead.",
+ )
+ parser.add_argument("--question-path", required=True, help="path to question file.")
+ parser.add_argument("--answer-path", required=True, help="path to answer result file.")
+ parser.add_argument("--image-folder", required=True, help="path to the image queries (COCO 2014 val).")
+ args = parser.parse_args()
+ return args
+
+
+# ========================================
+# Model Initialization
+# ========================================
+print('Initializing Chat')
+args = parse_args()
+cfg = Config(args)
+
+model_config = cfg.model_cfg
+model_config.device_8bit = args.gpu_id
+model_cls = registry.get_model_class(model_config.arch)
+model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
+
+# TODO: fix hard-coding `cc12m`
+vis_processor_cfg = cfg.datasets_cfg.cc12m.vis_processor.train
+vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
+processors = {ModalityType.VISION: vis_processor}
+chat = Chat(model, processors, device='cuda:{}'.format(args.gpu_id))
+print('Initialization Finished')
+
+# ========================================
+# Prompt Setting
+# ========================================
+prompt = "Give the following image: <Img>ImageContent</Img>. " \
+ "You will be able to see the image once I provide it to you. Please answer my question."
+
+# ========================================
+# Question Loading
+# ========================================
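+# Each line of the question file is a JSON object with "question_id", "image", and "text" fields;
+# image names are resolved against --image-folder with the "COCO_val2014_" prefix.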
+questions = [json.loads(q) for q in open(args.question_path, "r")]
+answer_file = open(args.answer_path, "w")
+for i, line in enumerate(tqdm(questions)):
+ idx = line["question_id"]
+ image_file = os.path.join(args.image_folder, "COCO_val2014_" + line["image"])
+ question = line["text"]
+ state = CONV_VISION.copy()
+ emb_list = []
+ chat.upload_img(image_file, state, emb_list)
+ chat.ask(question, state)
+ answer, _ = chat.answer(state, emb_list)
+ ans_id = shortuuid.uuid()
+ answer_file.write(json.dumps({"question_id": idx,
+ "prompt": question,
+ "text": answer,
+ "answer_id": ans_id,
+ "model_id": model_config.arch,
+ "metadata": {}}) + "\n")
+ answer_file.flush()
+answer_file.close()
+
diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
deleted file mode 100644
index 86d8b80..0000000
--- a/minigpt4/conversation/conversation.py
+++ /dev/null
@@ -1,211 +0,0 @@
-import dataclasses
-from enum import auto, Enum
-from typing import List, Any
-
-import torch
-from PIL import Image
-from transformers import StoppingCriteria, StoppingCriteriaList
-
-from imagebind.models.image_bind import ModalityType
-
-
-class SeparatorStyle(Enum):
- """Different separator style."""
- SINGLE = auto()
- TWO = auto()
-
-
-@dataclasses.dataclass
-class Conversation:
- """A class that keeps all conversation history."""
- system: str
- roles: List[str]
- messages: List[List[str]]
- offset: int
- # system_img: List[Image.Image] = []
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
- sep: str = "###"
- sep2: str = None
-
- skip_next: bool = False
- conv_id: Any = None
-
- def get_prompt(self):
- if self.sep_style == SeparatorStyle.SINGLE:
- ret = self.system + self.sep
- for role, message in self.messages:
- if message:
- ret += role + ": " + message + self.sep
- else:
- ret += role + ":"
- return ret
- elif self.sep_style == SeparatorStyle.TWO:
- seps = [self.sep, self.sep2]
- ret = self.system + seps[0]
- for i, (role, message) in enumerate(self.messages):
- if message:
- ret += role + ": " + message + seps[i % 2]
- else:
- ret += role + ":"
- return ret
- else:
- raise ValueError(f"Invalid style: {self.sep_style}")
-
- def append_message(self, role, message):
- self.messages.append([role, message])
-
- def to_gradio_chatbot(self):
- ret = []
- for i, (role, msg) in enumerate(self.messages[self.offset:]):
- if i % 2 == 0:
- ret.append([msg, None])
- else:
- ret[-1][-1] = msg
- return ret
-
- def copy(self):
- return Conversation(
- system=self.system,
- # system_img=self.system_img,
- roles=self.roles,
- messages=[[x, y] for x, y in self.messages],
- offset=self.offset,
- sep_style=self.sep_style,
- sep=self.sep,
- sep2=self.sep2,
- conv_id=self.conv_id)
-
- def dict(self):
- return {
- "system": self.system,
- # "system_img": self.system_img,
- "roles": self.roles,
- "messages": self.messages,
- "offset": self.offset,
- "sep": self.sep,
- "sep2": self.sep2,
- "conv_id": self.conv_id,
- }
-
-
-class StoppingCriteriaSub(StoppingCriteria):
-
- def __init__(self, stops=[], encounters=1):
- super().__init__()
- self.stops = stops
-
- def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
- for stop in self.stops:
- if torch.all((stop == input_ids[0][-len(stop):])).item():
- return True
-
- return False
-
-
-CONV_VISION = Conversation(
- system="Give the following image: <Img>ImageContent</Img>. "
- "You will be able to see the image once I provide it to you. Please answer my questions.",
- roles=("Human", "Assistant"),
- messages=[],
- offset=2,
- sep_style=SeparatorStyle.SINGLE,
- sep="###",
-)
-
-# TODO: If needed and possible, rewrite this file and re-organize the definition of components.
-
-
-class Chat:
- def __init__(self, model, vis_processor, device='cuda:0'):
- self.device = device
- self.model = model
- self.vis_processor = vis_processor
- stop_words_ids = [torch.tensor([835]).to(self.device),
- torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
- self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
-
- def ask(self, text, conv):
- # NOTE: the hard code for postfix is removed.
- # TODO: Need to be compatible with more modalities.
- end_token = '</Img>'
- if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
- and conv.messages[-1][1][-len(end_token):] == end_token: # last message is image.
- conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
- else:
- conv.append_message(conv.roles[0], text)
-
- def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
- repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
- # Generate an answer written by LLaMA
- conv.append_message(conv.roles[1], None)
- embs = self.get_context_emb(conv, img_list)
-
- current_max_len = embs.shape[1] + max_new_tokens
- if current_max_len - max_length > 0:
- print('Warning: The number of tokens in current conversation exceeds the max length. '
- 'The model will not see the contexts outside the range.')
- begin_idx = max(0, current_max_len - max_length)
-
- embs = embs[:, begin_idx:]
-
- outputs = self.model.llama_model.generate(
- inputs_embeds=embs,
- max_new_tokens=max_new_tokens,
- stopping_criteria=self.stopping_criteria,
- num_beams=num_beams,
- do_sample=True,
- min_length=min_length,
- top_p=top_p,
- repetition_penalty=repetition_penalty,
- length_penalty=length_penalty,
- temperature=temperature,
- )
- output_token = outputs[0]
- if output_token[0] == 0: # the model might output a unknown token at the beginning. remove it
- output_token = output_token[1:]
- if output_token[0] == 1: # some users find that there is a start token at the beginning. remove it
- output_token = output_token[1:]
- output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
- output_text = output_text.split('###')[0] # remove the stop sign '###'
- output_text = output_text.split('Assistant:')[-1].strip()
- conv.messages[-1][1] = output_text
- return output_text, output_token.cpu().numpy()
-
- def upload_img(self, image, conv, img_list):
- # Upload Image, Encode Image and Create a new message from human.
- if isinstance(image, str): # is a image path
- raw_image = Image.open(image).convert('RGB')
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
- elif isinstance(image, Image.Image):
- raw_image = image
- image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
- elif isinstance(image, torch.Tensor):
- if len(image.shape) == 3:
- image = image.unsqueeze(0)
- image = image.to(self.device)
-
- all_embeddings = self.model.encode_inputs({ModalityType.VISION: image})
- image_emb = all_embeddings[ModalityType.VISION]
- img_list.append(image_emb)
- conv.append_message(conv.roles[0], "<Img><ImageHere></Img>")
- msg = "Received."
- # self.conv.append_message(self.conv.roles[1], msg)
- return msg
-
- def get_context_emb(self, conv, img_list):
- # Insert the image embeddings into the prompts and queries. Note that the img_list: List[Tensor]
- prompt = conv.get_prompt()
- prompt_segs = prompt.split('<ImageHere>')
- assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
- seg_tokens = [
- self.model.llama_tokenizer(
- seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
- # only add bos to the first seg
- for i, seg in enumerate(prompt_segs)
- ]
- seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
- mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
- mixed_embs = torch.cat(mixed_embs, dim=1)
- return mixed_embs
-
-
diff --git a/minigpt4/models/bind_gpt4.py b/minigpt4/models/bind_gpt4.py
index b400c7e..59c60f6 100644
--- a/minigpt4/models/bind_gpt4.py
+++ b/minigpt4/models/bind_gpt4.py
@@ -45,15 +45,6 @@ class BindGPT4(BaseModel):
self.multimodal_encoder = imagebind_huge(pretrained=True, freeze_imagebind=freeze_imagebind)
print('Loading ImageBind Done')
- print('Loading Q-Former and Adapter/Projector')
- self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
- vision_qformer_frozen=freeze_qformer,
- vision_post_dims=[768, self.llama_model.config.hidden_size]
- # vision_qformer_model=q_former_model,
- # vision_pre_dims=(1280, 1408)
- )
- print('Loading Q-Former and Adapter/Projector Done')
-
print('Loading LLAMA')
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
@@ -67,6 +58,15 @@ class BindGPT4(BaseModel):
param.requires_grad = False
print('Loading LLAMA Done')
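+ # NOTE: the joiner is built after LLAMA because it uses self.llama_model.config.hidden_size below.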
+ print('Loading Q-Former and Adapter/Projector')
+ self.multimodal_joiner = ImageBindJoiner(vision_query_token_num=num_query_token,
+ vision_qformer_frozen=freeze_qformer,
+ vision_post_dims=[768, self.llama_model.config.hidden_size]
+ # vision_qformer_model=q_former_model,
+ # vision_pre_dims=(1280, 1408)
+ )
+ print('Loading Q-Former and Adapter/Projector Done')
+
self.max_txt_len = max_txt_len
self.end_sym = end_sym
@@ -93,7 +93,7 @@ class BindGPT4(BaseModel):
attns_input = torch.ones(input_embeds.size()[:-1], dtype=torch.long).to(input_embeds.device)
if prompt:
batch_size = input_embeds.shape[0]
- p_before, p_after = prompt.split('<{}Here>'.format(modality_name.title()))
+ p_before, p_after = prompt.split('<ImageHere>')
p_before_tokens = self.llama_tokenizer(
p_before, return_tensors="pt", add_special_tokens=False).to(input_embeds.device)
p_after_tokens = self.llama_tokenizer(
diff --git a/prompts/alignment.txt b/prompts/alignment.txt
index 90ae57b..28137af 100644
--- a/prompts/alignment.txt
+++ b/prompts/alignment.txt
@@ -1,4 +1,4 @@
-<Img><ImageHere></Img> Describe this image in detail.
-<Img><ImageHere></Img> Take a look at this image and describe what you notice.
-<Img><ImageHere></Img> Please provide a detailed description of the picture.
-<Img><ImageHere></Img> Could you describe the contents of this image for me?
\ No newline at end of file
+<ImageHere> Describe this image in detail.
+<ImageHere> Take a look at this image and describe what you notice.
+<ImageHere> Please provide a detailed description of the picture.
+<ImageHere> Could you describe the contents of this image for me?
\ No newline at end of file
diff --git a/temp.py b/temp.py
deleted file mode 100644
index e69de29..0000000
diff --git a/train_configs/bindgpt4_stage1_pretrain.yaml b/train_configs/bindgpt4_stage1_pretrain.yaml
index 1f9aad1..f0a520a 100644
--- a/train_configs/bindgpt4_stage1_pretrain.yaml
+++ b/train_configs/bindgpt4_stage1_pretrain.yaml
@@ -1,7 +1,7 @@
model:
arch: bind_gpt4
model_type: pretrain_vicuna
- freeze_vit: True
+ freeze_imagebind: True
freeze_qformer: False