diff --git a/demo.py b/demo.py index b3659f1..7ee1f73 100644 --- a/demo.py +++ b/demo.py @@ -57,11 +57,11 @@ cfg = Config(args) model_config = cfg.model_cfg model_config.device_8bit = args.gpu_id model_cls = registry.get_model_class(model_config.arch) -model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id)) +model = model_cls.from_config(model_config).to('cpu') vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) -chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id)) +chat = Chat(model, vis_processor, device='cpu') print('Initialization Finished') # ========================================