diff --git a/cli_demo.py b/cli_demo.py new file mode 100644 index 0000000..69aff98 --- /dev/null +++ b/cli_demo.py @@ -0,0 +1,103 @@ +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import gradio as gr + +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank +from minigpt4.common.registry import registry +from minigpt4.conversation.conversation import Chat, CONV_VISION + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") + parser.add_argument("--num-beams", type=int, default=2, help="specify the gpu to load the model.") + parser.add_argument("--temperature", type=int, default=0.9, help="specify the gpu to load the model.") + parser.add_argument("--english", type=bool, default=True, help="chinese or english") + parser.add_argument("--prompt-en", type=str, default="can you describe the current picture?", help="Can you describe the current picture?") + parser.add_argument("--prompt-zh", type=str, default="你能描述一下当前的图片?", help="Can you describe the current picture?") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +# ======================================== +# Model Initialization +# ======================================== + +print('Initializing Chat') +args = parse_args() +cfg = Config(args) + +model_config = cfg.model_cfg +model_config.device_8bit = args.gpu_id +model_cls = registry.get_model_class(model_config.arch) +model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) + +vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train +vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) +chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) +print('Initialization Finished') + +while True: + if not args.english: + image_path = input("请输入图像路径或URL(回车进入纯文本对话): ") + else: + image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ") + + if image_path == 'stop': + break + if len(image_path) > 0: + query = args.prompt_en if args.english else args.prompt_zh + + + while True: + if query == "clear": + break + if query == "stop": + sys.exit(0) + img_list = [] + chat_state = CONV_VISION.copy() + chat.upload_img(image_path, chat_state, img_list) + chat.ask(query, chat_state) + llm_message = chat.answer( + conv=chat_state, + img_list=img_list, + num_beams=args.num_beams, + temperature=args.temperature, + max_new_tokens=300, + max_length=2000 + )[0] + # chatbot[-1][1] = llm_message + print("MiniGPT4:"+llm_message) + query = input("user:")