This commit is contained in:
Zhijie Lin 2023-05-25 15:35:59 +08:00
parent fef5e3640d
commit 5204a512e8

View File

@ -54,6 +54,7 @@ cfg = Config(args)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
print(model_config)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
# TODO: Fix hard-coding `cc12m`