mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 18:40:46 +00:00
71 lines
2.2 KiB
Python
71 lines
2.2 KiB
Python
import argparse
|
|
import os
|
|
import random
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.backends.cudnn as cudnn
|
|
import gradio as gr
|
|
import json
|
|
|
|
from minigpt4.common.config import Config
|
|
from minigpt4.common.dist_utils import get_rank
|
|
from minigpt4.common.registry import registry
|
|
from minigpt4.conversation.response import Chat
|
|
|
|
# imports modules for registration
|
|
from minigpt4.datasets.builders import *
|
|
from minigpt4.models import *
|
|
from minigpt4.processors import *
|
|
from minigpt4.runners import *
|
|
from minigpt4.tasks import *
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the demo script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the original zero-argument call did.

    Returns:
        argparse.Namespace: carries ``cfg_path`` (required), ``gpu_id``
        (int, default 0) and ``options`` (list of strings, or ``None`` when
        the flag is not given).
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecate), "
        "change to --cfg-options instead.",
    )
    # Forwarding argv lets callers (and tests) supply arguments explicitly
    # instead of mutating sys.argv.
    return parser.parse_args(argv)
|
|
|
|
|
|
def setup_seeds(config):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN behaviour.

    The seed used is ``config.run_cfg.seed + get_rank()``.
    NOTE(review): ``get_rank`` presumably returns the distributed process
    rank, so each process would get a distinct seed - confirm against
    ``minigpt4.common.dist_utils``.

    Args:
        config: object exposing ``run_cfg.seed``.
    """
    rank_seed = config.run_cfg.seed + get_rank()

    # Apply the same seed to every random number generator in use.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(rank_seed)

    # Turn off cuDNN benchmarking and request deterministic algorithms.
    cudnn.deterministic = True
    cudnn.benchmark = False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Script body: build the model and chat wrapper, then answer every record of
# a JSON-lines file and print each reply.
# ---------------------------------------------------------------------------
args = parse_args()
cfg = Config(args)

# Single place where the target device string is built; it is used for both
# the model and the Chat wrapper below.
device = 'cuda:{}'.format(args.gpu_id)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to(device)

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor, device=device)

# NOTE(review): despite its name, ``model_config.output_path`` is opened for
# READING here. Each non-empty line is parsed as a JSON object with "image"
# and "text" keys - confirm the expected schema with whoever writes the file.
with open(model_config.output_path, 'r', encoding='utf-8') as json_file:
    for line in json_file:
        line = line.strip()
        if not line:
            # Blank lines (e.g. trailing newlines) are not valid JSON; skip
            # them instead of crashing in json.loads.
            continue
        item = json.loads(line)

        image_emb = chat.upload_img(item["image"])
        # Original author's note on image_emb's shape: [1, 32, 4096]
        # (not verifiable from this file).
        embedding = chat.get_context_emb(item["text"], image_emb)

        # Only the first element of the value returned by chat.answer is printed.
        llm_message = chat.answer(embs=embedding, max_new_tokens=300, max_length=2000)[0]
        print(llm_message)