mirror of
https://github.com/Vision-CAIR/MiniGPT-4.git
synced 2025-04-05 02:20:47 +00:00
Update modeling_llama.py for transformers package compatibility
This commit is contained in:
parent
10f61a4dd8
commit
41c050de76
@ -75,7 +75,7 @@ class LlamaForCausalLM(LlamaForCausalLMOrig):
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
if self.config.pretraining_tp > 1:
|
||||
if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1:
|
||||
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
|
||||
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
|
||||
logits = torch.cat(logits, dim=-1)
|
||||
|
Loading…
Reference in New Issue
Block a user