mirror of https://github.com/Vision-CAIR/MiniGPT-4.git
commit 3e03c8327f

demo.py
@@ -113,8 +113,8 @@ with gr.Blocks() as demo:
             num_beams = gr.Slider(
                 minimum=1,
-                maximum=16,
-                value=5,
+                maximum=10,
+                value=1,
                 step=1,
                 interactive=True,
                 label="beam search numbers)",
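
For context, a minimal sketch of how a slider like this typically feeds a generation callback in a Gradio app. The callback name and wiring below are assumptions for illustration, not the project's actual demo.py plumbing:

import gradio as gr

# Hypothetical stand-in for a call into chat.answer(..., num_beams=num_beams).
def generate_stub(prompt, num_beams):
    return f"(would decode '{prompt}' with num_beams={int(num_beams)})"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    # Mirrors the new slider bounds from the hunk above: 1..10, default 1.
    num_beams = gr.Slider(minimum=1, maximum=10, value=1, step=1,
                          interactive=True, label="beam search numbers)")
    out = gr.Textbox(label="output")
    btn = gr.Button("run")
    btn.click(generate_stub, inputs=[prompt, num_beams], outputs=out)

# demo.launch()  # uncomment to serve locally
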
@@ -25,7 +25,7 @@ dependencies:
   - filelock==3.9.0
   - fonttools==4.38.0
   - frozenlist==1.3.3
-  - huggingface-hub==0.12.1
+  - huggingface-hub==0.13.4
   - importlib-resources==5.12.0
   - kiwisolver==1.4.4
   - matplotlib==3.7.0
@@ -134,8 +134,8 @@ class Chat:
         else:
             conv.append_message(conv.roles[0], text)
 
-    def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
-               repetition_penalty=1.0, length_penalty=1, temperature=1):
+    def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
+               repetition_penalty=1.0, length_penalty=1, temperature=1.0):
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
         outputs = self.model.llama_model.generate(
@@ -143,6 +143,7 @@ class Chat:
             max_new_tokens=max_new_tokens,
             stopping_criteria=self.stopping_criteria,
             num_beams=num_beams,
+            do_sample=True,
             min_length=min_length,
             top_p=top_p,
             repetition_penalty=repetition_penalty,
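
The two hunks above change the decoding defaults: with num_beams=1 and do_sample=True, Hugging Face generate() performs top-p (nucleus) sampling instead of beam search. A minimal, self-contained sketch of that behaviour, using gpt2 as a stand-in model (an assumption; MiniGPT-4 calls generate on its LLaMA model with precomputed embeddings):

from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("A photo of", return_tensors="pt")
out = model.generate(
    **inputs,
    max_new_tokens=20,
    num_beams=1,            # no beam search
    do_sample=True,         # sample from the token distribution
    top_p=0.9,              # nucleus sampling threshold
    repetition_penalty=1.0,
    temperature=1.0,
)
print(tok.decode(out[0], skip_special_tokens=True))
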
@@ -84,7 +84,8 @@ class MiniGPT4(Blip2Base):
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
 
         self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16
+            llama_model, torch_dtype=torch.float16,
+            load_in_8bit=True, device_map="auto"
         )
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
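
The load_in_8bit / device_map change relies on the bitsandbytes and accelerate packages being installed alongside transformers. A hedged sketch of the same loading pattern in isolation (the model path is a placeholder, not taken from the repo's config):

import torch
from transformers import LlamaForCausalLM

llama_model = "path/to/llama-7b-hf"  # placeholder path, assumption
model = LlamaForCausalLM.from_pretrained(
    llama_model,
    torch_dtype=torch.float16,
    load_in_8bit=True,   # quantize linear layers to int8 via bitsandbytes
    device_map="auto",   # let accelerate place layers across GPU/CPU
)

# As in the surrounding code, the LLaMA weights stay frozen.
for name, param in model.named_parameters():
    param.requires_grad = False
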
@@ -107,12 +108,17 @@ class MiniGPT4(Blip2Base):
         self.prompt_list = []
 
     def encode_img(self, image):
-        with self.maybe_autocast():
-            image_embeds = self.ln_vision(self.visual_encoder(image))
-            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
-                image.device
-            )
+        device = image.device
+        self.ln_vision.to("cpu")
+        self.ln_vision.float()
+        self.visual_encoder.to("cpu")
+        self.visual_encoder.float()
+        image = image.to("cpu")
+
+        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
+        with self.maybe_autocast():
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
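
The encode_img rewrite is essentially a CPU-offload pattern: cast the vision tower to float32, run it on the CPU, then move the resulting features back to the device where the rest of the model lives. A small self-contained sketch of the same pattern with a stand-in module (in MiniGPT-4 it is the ViT plus ln_vision stack):

import torch
import torch.nn as nn

vision = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32))  # stand-in encoder

def encode_on_cpu(image: torch.Tensor) -> torch.Tensor:
    device = image.device       # remember where the rest of the model sits
    vision.to("cpu").float()    # fp16 ops are slow/unsupported on CPU, so use fp32
    feats = vision(image.to("cpu"))
    return feats.to(device)     # hand the features back to the original device

x = torch.randn(2, 16)          # stands in for a batch of image tensors
print(encode_on_cpu(x).shape)
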