Mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
add argument to switch 8bit

commit b76d5c5a50
parent 3e03c8327f
@@ -5,6 +5,7 @@ model:
   freeze_qformer: True
   max_txt_len: 160
   end_sym: "###"
+  low_resource: True
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
   ckpt: '/path/to/pretrained/ckpt/'
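The new low_resource key can be flipped per config without touching the code. A minimal sketch of reading and overriding it, assuming the repo's OmegaConf-based config stack; the file names are placeholders, not files touched by this commit:

from omegaconf import OmegaConf

# Load an eval-style config (placeholder path) and switch off low-resource
# mode for machines with enough VRAM to run LLaMA in plain fp16.
cfg = OmegaConf.load("eval_config.yaml")
cfg.model.low_resource = False
OmegaConf.save(cfg, "eval_config_fp16.yaml")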
@@ -36,11 +36,13 @@ class MiniGPT4(Blip2Base):
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
+        low_resource=False,  # use 8 bit and put vit in cpu
         end_sym='\n',
     ):
         super().__init__()

         self.tokenizer = self.init_tokenizer()
+        self.low_resource = low_resource

         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
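Since the new keyword defaults to False, existing call sites and training configs keep the previous fp16/GPU behaviour unchanged. A hypothetical direct construction (all other arguments left at their defaults; the import path and the LLaMA weights path are placeholders to adjust for your checkout):

from minigpt4.models.mini_gpt4 import MiniGPT4

# Opt in to the memory-saving path: 8-bit LLaMA plus ViT kept on CPU.
model = MiniGPT4(
    llama_model="/path/to/llama-7b-weights",  # placeholder path
    low_resource=True,
)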
@@ -83,10 +85,19 @@ class MiniGPT4(Blip2Base):
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

-        self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16,
-            load_in_8bit=True, device_map="auto"
-        )
+        if self.low_resource:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto"
+            )
+        else:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+            )
+
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
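A standalone sketch of the same branch, handy for checking the trade-off outside the class. The helper name load_llama is made up here; load_in_8bit=True requires bitsandbytes and shards the model automatically via device_map="auto", while the plain fp16 branch leaves device placement to the caller (newer transformers versions express the same thing through a quantization_config):

import torch
from transformers import LlamaForCausalLM

def load_llama(llama_model: str, low_resource: bool = False):
    """Hypothetical helper mirroring the branch added in this commit."""
    if low_resource:
        # 8-bit weights (bitsandbytes) with automatic device placement.
        return LlamaForCausalLM.from_pretrained(
            llama_model,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            device_map="auto",
        )
    # Full fp16 weights; the caller moves the model to its device later.
    return LlamaForCausalLM.from_pretrained(
        llama_model,
        torch_dtype=torch.float16,
    )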
@@ -107,18 +118,22 @@ class MiniGPT4(Blip2Base):
         else:
             self.prompt_list = []

-    def encode_img(self, image):
-        device = image.device
-        image = image.to("cpu")
-
-        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+    def vit_to_cpu(self):
+        self.ln_vision.to("cpu")
+        self.ln_vision.float()
+        self.visual_encoder.to("cpu")
+        self.visual_encoder.float()
+
+    def encode_img(self, image):
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")
+
+        with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
         query_output = self.Qformer.bert(
             query_embeds=query_tokens,
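maybe_autocast() comes from the Blip2Base parent class; roughly, it enables fp16 autocast only when the module lives on a GPU and degrades to a no-op context on CPU, which is what lets the same encode_img body serve both modes. A simplified standalone sketch, not the repo's exact implementation:

import contextlib
import torch

def maybe_autocast(module: torch.nn.Module, dtype=torch.float16):
    # Simplified sketch of the Blip2Base helper: use CUDA autocast on GPU,
    # fall back to a do-nothing context manager on CPU.
    on_gpu = next(module.parameters()).device.type == "cuda"
    if on_gpu:
        return torch.cuda.amp.autocast(dtype=dtype)
    return contextlib.nullcontext()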
@@ -216,6 +231,7 @@ class MiniGPT4(Blip2Base):
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
+        low_resource = cfg.get("low_resource", False)

         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
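Because cfg.get("low_resource", False) supplies a default, configs written before this commit still load and simply fall back to the full fp16 path. A tiny sketch of that fallback, assuming cfg is an OmegaConf node as in the repo's config stack:

from omegaconf import OmegaConf

old_cfg = OmegaConf.create({"freeze_qformer": True})  # pre-existing config, no new key
new_cfg = OmegaConf.create({"low_resource": True})    # config that opts in

print(old_cfg.get("low_resource", False))  # False -> previous fp16 behaviour
print(new_cfg.get("low_resource", False))  # True  -> 8-bit LLaMA, ViT on CPU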
@@ -236,6 +252,7 @@ class MiniGPT4(Blip2Base):
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
+            low_resource=low_resource,
             end_sym=end_sym
         )