diff --git a/MiniGPTv2.pdf b/MiniGPTv2.pdf deleted file mode 100644 index 04de5e8..0000000 Binary files a/MiniGPTv2.pdf and /dev/null differ diff --git a/Post_Route_Universal_PromptMoE_RawProb_backward_graph b/Post_Route_Universal_PromptMoE_RawProb_backward_graph new file mode 100644 index 0000000..25ea47e --- /dev/null +++ b/Post_Route_Universal_PromptMoE_RawProb_backward_graph @@ -0,0 +1,6290 @@ +digraph { + graph [size="885.0,885.0"] + node [align=left fontname=monospace fontsize=10 height=0.2 ranksep=0.1 shape=box style=filled] + 140193037219536 [label=" + (4, 45, 768)" fillcolor=darkolivegreen1] + 140193039136752 [label=CatBackward0] + 140193570151248 -> 140193039136752 + 140193570151248 [label=AddBackward0] + 140193039092752 -> 140193570151248 + 140193039092752 [label=IndexBackward0] + 140193578428064 -> 140193039092752 + 140193578428064 [label=NativeLayerNormBackward0] + 140193578427440 -> 140193578428064 + 140193578427440 [label=AddBackward0] + 140193578427632 -> 140193578427440 + 140193578427632 [label=CatBackward0] + 140193578427584 -> 140193578427632 + 140193578427584 [label=CatBackward0] + 140193578428016 -> 140193578427584 + 140193578428016 [label=SliceBackward0] + 140193578428304 -> 140193578428016 + 140193578428304 [label=SliceBackward0] + 140193578427248 -> 140193578428304 + 140193578427248 [label=SliceBackward0] + 140193578427152 -> 140193578427248 + 140193578427152 [label=SumBackward1] + 140193578427056 -> 140193578427152 + 140193578427056 [label=MulBackward0] + 140193578426960 -> 140193578427056 + 140193578426960 [label=IndexBackward0] + 140193578428160 -> 140193578426960 + 140193578428160 [label=PermuteBackward0] + 140193578428256 -> 140193578428160 + 140193578428256 [label=CatBackward0] + 140193578428352 -> 140193578428256 + 140193578428352 [label=UnsqueezeBackward0] + 140193578428688 -> 140193578428352 + 140193578428688 [label=NativeDropoutBackward0] + 140193578428784 -> 140193578428688 + 140193578428784 [label=ViewBackward0] + 140193578428880 -> 140193578428784 + 140193578428880 [label=AddmmBackward0] + 140193578428976 -> 140193578428880 + 140193578428976 [label=ToCopyBackward0] + 140193578429168 -> 140193578428976 + 140193039388000 [label="encoder.layer.11.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140193039388000 -> 140193578429168 + 140193578429168 [label=AccumulateGrad] + 140193578428736 -> 140193578428880 + 140193578428736 [label=ViewBackward0] + 140193578429024 -> 140193578428736 + 140193578429024 [label=GeluBackward0] + 140193578429120 -> 140193578429024 + 140193578429120 [label=ViewBackward0] + 140193578429216 -> 140193578429120 + 140193578429216 [label=AddmmBackward0] + 140193578429312 -> 140193578429216 + 140193578429312 [label=ToCopyBackward0] + 140193578429504 -> 140193578429312 + 140193039388320 [label="encoder.layer.11.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140193039388320 -> 140193578429504 + 140193578429504 [label=AccumulateGrad] + 140193578429456 -> 140193578429216 + 140193578429456 [label=ViewBackward0] + 140193578429744 -> 140193578429456 + 140193578429744 [label=ToCopyBackward0] + 140193578429840 -> 140193578429744 + 140193578429840 [label=IndexBackward0] + 140193578427392 -> 140193578429840 + 140193578427392 [label=SliceBackward0] + 140193578429792 -> 140193578427392 + 140193578429792 [label=SliceBackward0] + 140193578429888 -> 140193578429792 + 140193578429888 [label=SliceBackward0] + 140193578429984 -> 140193578429888 + 140193578429984 [label=SliceBackward0] + 140193578430080 -> 140193578429984 + 140193578430080 [label=SliceBackward0] + 140193578430176 -> 140193578430080 + 140193578430176 [label=NativeLayerNormBackward0] + 140193578430272 -> 140193578430176 + 140193578430272 [label=AddBackward0] + 140193036800064 -> 140193578430272 + 140193036800064 [label=NativeDropoutBackward0] + 140193036800400 -> 140193036800064 + 140193036800400 [label=ViewBackward0] + 140193036800496 -> 140193036800400 + 140193036800496 [label=AddmmBackward0] + 140193036800592 -> 140193036800496 + 140193036800592 [label=ToCopyBackward0] + 140193036800784 -> 140193036800592 + 140193039417648 [label="encoder.layer.11.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039417648 -> 140193036800784 + 140193036800784 [label=AccumulateGrad] + 140193036800352 -> 140193036800496 + 140193036800352 [label=ViewBackward0] + 140193036800640 -> 140193036800352 + 140193036800640 [label=ViewBackward0] + 140193036800736 -> 140193036800640 + 140193036800736 [label=CloneBackward0] + 140193036800832 -> 140193036800736 + 140193036800832 [label=PermuteBackward0] + 140193036800928 -> 140193036800832 + 140193036800928 [label=UnsafeViewBackward0] + 140193036801024 -> 140193036800928 + 140193036801024 [label=BmmBackward0] + 140193036801120 -> 140193036801024 + 140193036801120 [label=ReshapeAliasBackward0] + 140193036801456 -> 140193036801120 + 140193036801456 [label=ExpandBackward0] + 140193036801552 -> 140193036801456 + 140193036801552 [label=ToCopyBackward0] + 140193036801648 -> 140193036801552 + 140193036801648 [label=NativeDropoutBackward0] + 140193036801744 -> 140193036801648 + 140193036801744 [label=SoftmaxBackward0] + 140193036801840 -> 140193036801744 + 140193036801840 [label=AddBackward0] + 140193036801936 -> 140193036801840 + 140193036801936 [label=DivBackward0] + 140193036802032 -> 140193036801936 + 140193036802032 [label=UnsafeViewBackward0] + 140193036802128 -> 140193036802032 + 140193036802128 [label=BmmBackward0] + 140193036802224 -> 140193036802128 + 140193036802224 [label=UnsafeViewBackward0] + 140193036802176 -> 140193036802224 + 140193036802176 [label=CloneBackward0] + 140193036802272 -> 140193036802176 + 140193036802272 [label=ExpandBackward0] + 140193036802368 -> 140193036802272 + 140193036802368 [label=PermuteBackward0] + 140193036802464 -> 140193036802368 + 140193036802464 [label=ViewBackward0] + 140193036802560 -> 140193036802464 + 140193036802560 [label=ViewBackward0] + 140193036802656 -> 140193036802560 + 140193036802656 [label=AddmmBackward0] + 140193036802752 -> 140193036802656 + 140193036802752 [label=ToCopyBackward0] + 140193036802944 -> 140193036802752 + 140193039418608 [label="encoder.layer.11.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039418608 -> 140193036802944 + 140193036802944 [label=AccumulateGrad] + 140193036802896 -> 140193036802656 + 140193036802896 [label=ViewBackward0] + 140193036803184 -> 140193036802896 + 140193036803184 [label=ToCopyBackward0] + 140193036800208 -> 140193036803184 + 140193036800208 [label=CatBackward0] + 140193036803136 -> 140193036800208 + 140193036803136 [label=NativeLayerNormBackward0] + 140193036803472 -> 140193036803136 + 140193036803472 [label=AddBackward0] + 140193036803664 -> 140193036803472 + 140193036803664 [label=CatBackward0] + 140193036803616 -> 140193036803664 + 140193036803616 [label=CatBackward0] + 140193036804048 -> 140193036803616 + 140193036804048 [label=SliceBackward0] + 140193036803904 -> 140193036804048 + 140193036803904 [label=SliceBackward0] + 140193036849360 -> 140193036803904 + 140193036849360 [label=SliceBackward0] + 140193036849456 -> 140193036849360 + 140193036849456 [label=SumBackward1] + 140193036849552 -> 140193036849456 + 140193036849552 [label=MulBackward0] + 140193036849648 -> 140193036849552 + 140193036849648 [label=IndexBackward0] + 140193036849792 -> 140193036849648 + 140193036849792 [label=PermuteBackward0] + 140193036849888 -> 140193036849792 + 140193036849888 [label=CatBackward0] + 140193036849984 -> 140193036849888 + 140193036849984 [label=UnsqueezeBackward0] + 140193036850128 -> 140193036849984 + 140193036850128 [label=NativeDropoutBackward0] + 140193036850224 -> 140193036850128 + 140193036850224 [label=ViewBackward0] + 140193036850320 -> 140193036850224 + 140193036850320 [label=AddmmBackward0] + 140193036850416 -> 140193036850320 + 140193036850416 [label=ToCopyBackward0] + 140193036850608 -> 140193036850416 + 140193039404464 [label="encoder.layer.10.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140193039404464 -> 140193036850608 + 140193036850608 [label=AccumulateGrad] + 140193036850368 -> 140193036850320 + 140193036850368 [label=ViewBackward0] + 140193036850656 -> 140193036850368 + 140193036850656 [label=GeluBackward0] + 140193036850752 -> 140193036850656 + 140193036850752 [label=ViewBackward0] + 140193036850848 -> 140193036850752 + 140193036850848 [label=AddmmBackward0] + 140193036850944 -> 140193036850848 + 140193036850944 [label=ToCopyBackward0] + 140193036851136 -> 140193036850944 + 140193039404704 [label="encoder.layer.10.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140193039404704 -> 140193036851136 + 140193036851136 [label=AccumulateGrad] + 140193036850896 -> 140193036850848 + 140193036850896 [label=ViewBackward0] + 140193036851184 -> 140193036850896 + 140193036851184 [label=ToCopyBackward0] + 140193036851280 -> 140193036851184 + 140193036851280 [label=IndexBackward0] + 140193036803424 -> 140193036851280 + 140193036803424 [label=SliceBackward0] + 140193036851424 -> 140193036803424 + 140193036851424 [label=SliceBackward0] + 140193036851520 -> 140193036851424 + 140193036851520 [label=NativeLayerNormBackward0] + 140193036851616 -> 140193036851520 + 140193036851616 [label=AddBackward0] + 140193036851808 -> 140193036851616 + 140193036851808 [label=NativeDropoutBackward0] + 140193036851952 -> 140193036851808 + 140193036851952 [label=ViewBackward0] + 140193036852048 -> 140193036851952 + 140193036852048 [label=AddmmBackward0] + 140193036852144 -> 140193036852048 + 140193036852144 [label=ToCopyBackward0] + 140193036852336 -> 140193036852144 + 140193039420448 [label="encoder.layer.10.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039420448 -> 140193036852336 + 140193036852336 [label=AccumulateGrad] + 140193036852096 -> 140193036852048 + 140193036852096 [label=ViewBackward0] + 140193036852528 -> 140193036852096 + 140193036852528 [label=ViewBackward0] + 140193036852624 -> 140193036852528 + 140193036852624 [label=CloneBackward0] + 140193036852816 -> 140193036852624 + 140193036852816 [label=PermuteBackward0] + 140193036853008 -> 140193036852816 + 140193036853008 [label=UnsafeViewBackward0] + 140193036853104 -> 140193036853008 + 140193036853104 [label=BmmBackward0] + 140193036852960 -> 140193036853104 + 140193036852960 [label=ReshapeAliasBackward0] + 140193036890320 -> 140193036852960 + 140193036890320 [label=ExpandBackward0] + 140193036890368 -> 140193036890320 + 140193036890368 [label=ToCopyBackward0] + 140193036890608 -> 140193036890368 + 140193036890608 [label=NativeDropoutBackward0] + 140193036890800 -> 140193036890608 + 140193036890800 [label=SoftmaxBackward0] + 140193036890848 -> 140193036890800 + 140193036890848 [label=AddBackward0] + 140193036891088 -> 140193036890848 + 140193036891088 [label=DivBackward0] + 140193036891280 -> 140193036891088 + 140193036891280 [label=UnsafeViewBackward0] + 140193036891328 -> 140193036891280 + 140193036891328 [label=BmmBackward0] + 140193036891568 -> 140193036891328 + 140193036891568 [label=UnsafeViewBackward0] + 140193036891952 -> 140193036891568 + 140193036891952 [label=CloneBackward0] + 140193036892144 -> 140193036891952 + 140193036892144 [label=ExpandBackward0] + 140193036892336 -> 140193036892144 + 140193036892336 [label=PermuteBackward0] + 140193036892432 -> 140193036892336 + 140193036892432 [label=ViewBackward0] + 140193036892624 -> 140193036892432 + 140193036892624 [label=ViewBackward0] + 140193036892816 -> 140193036892624 + 140193036892816 [label=AddmmBackward0] + 140193036892912 -> 140193036892816 + 140193036892912 [label=ToCopyBackward0] + 140193036893296 -> 140193036892912 + 140193039421168 [label="encoder.layer.10.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140193039421168 -> 140193036893296 + 140193036893296 [label=AccumulateGrad] + 140193036892720 -> 140193036892816 + 140193036892720 [label=ViewBackward0] + 140193036893200 -> 140193036892720 + 140193036893200 [label=ToCopyBackward0] + 140193036851760 -> 140193036893200 + 140193036851760 [label=SliceBackward0] + 140193036893584 -> 140193036851760 + 140193036893584 [label=SliceBackward0] + 140193036893776 -> 140193036893584 + 140193036893776 [label=SliceBackward0] + 140193036893872 -> 140193036893776 + 140193036893872 [label=NativeLayerNormBackward0] + 140193036894064 -> 140193036893872 + 140193036894064 [label=AddBackward0] + 140193036914848 -> 140193036894064 + 140193036914848 [label=NativeDropoutBackward0] + 140193036914992 -> 140193036914848 + 140193036914992 [label=ViewBackward0] + 140193036915184 -> 140193036914992 + 140193036915184 [label=AddmmBackward0] + 140193036915232 -> 140193036915184 + 140193036915232 [label=ToCopyBackward0] + 140193036915664 -> 140193036915232 + 140193039429776 [label="encoder.layer.10.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039429776 -> 140193036915664 + 140193036915664 [label=AccumulateGrad] + 140193036915376 -> 140193036915184 + 140193036915376 [label=ViewBackward0] + 140193036915856 -> 140193036915376 + 140193036915856 [label=ViewBackward0] + 140193036916048 -> 140193036915856 + 140193036916048 [label=CloneBackward0] + 140193036916240 -> 140193036916048 + 140193036916240 [label=PermuteBackward0] + 140193036916336 -> 140193036916240 + 140193036916336 [label=UnsafeViewBackward0] + 140193036916528 -> 140193036916336 + 140193036916528 [label=BmmBackward0] + 140193036916720 -> 140193036916528 + 140193036916720 [label=ReshapeAliasBackward0] + 140193036916672 -> 140193036916720 + 140193036916672 [label=ExpandBackward0] + 140193036916912 -> 140193036916672 + 140193036916912 [label=ToCopyBackward0] + 140193036917104 -> 140193036916912 + 140193036917104 [label=NativeDropoutBackward0] + 140193036917152 -> 140193036917104 + 140193036917152 [label=SoftmaxBackward0] + 140193036917392 -> 140193036917152 + 140193036917392 [label=AddBackward0] + 140193036917584 -> 140193036917392 + 140193036917584 [label=DivBackward0] + 140193036917632 -> 140193036917584 + 140193036917632 [label=UnsafeViewBackward0] + 140193036917872 -> 140193036917632 + 140193036917872 [label=BmmBackward0] + 140193036918064 -> 140193036917872 + 140193036918064 [label=UnsafeViewBackward0] + 140193036918448 -> 140193036918064 + 140193036918448 [label=CloneBackward0] + 140193036918640 -> 140193036918448 + 140193036918640 [label=ExpandBackward0] + 140193036918736 -> 140193036918640 + 140193036918736 [label=PermuteBackward0] + 140193036918592 -> 140193036918736 + 140193036918592 [label=ViewBackward0] + 140193036947856 -> 140193036918592 + 140193036947856 [label=ViewBackward0] + 140193036947952 -> 140193036947856 + 140193036947952 [label=AddmmBackward0] + 140193036948144 -> 140193036947952 + 140193036948144 [label=ToCopyBackward0] + 140193036948432 -> 140193036948144 + 140193039432336 [label="encoder.layer.10.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039432336 -> 140193036948432 + 140193036948432 [label=AccumulateGrad] + 140193036947808 -> 140193036947952 + 140193036947808 [label=ViewBackward0] + 140193036948288 -> 140193036947808 + 140193036948288 [label=ToCopyBackward0] + 140193036914800 -> 140193036948288 + 140193036914800 [label=CatBackward0] + 140193036948816 -> 140193036914800 + 140193036948816 [label=NativeLayerNormBackward0] + 140193036948768 -> 140193036948816 + 140193036948768 [label=AddBackward0] + 140193036949200 -> 140193036948768 + 140193036949200 [label=CatBackward0] + 140193036949584 -> 140193036949200 + 140193036949584 [label=CatBackward0] + 140193036949728 -> 140193036949584 + 140193036949728 [label=SliceBackward0] + 140193036950256 -> 140193036949728 + 140193036950256 [label=SliceBackward0] + 140193036950352 -> 140193036950256 + 140193036950352 [label=SliceBackward0] + 140193036950544 -> 140193036950352 + 140193036950544 [label=SumBackward1] + 140193036950736 -> 140193036950544 + 140193036950736 [label=MulBackward0] + 140193036950832 -> 140193036950736 + 140193036950832 [label=IndexBackward0] + 140193036950928 -> 140193036950832 + 140193036950928 [label=PermuteBackward0] + 140193036951120 -> 140193036950928 + 140193036951120 [label=CatBackward0] + 140193036951168 -> 140193036951120 + 140193036951168 [label=UnsqueezeBackward0] + 140193036951408 -> 140193036951168 + 140193036951408 [label=NativeDropoutBackward0] + 140193036972336 -> 140193036951408 + 140193036972336 [label=ViewBackward0] + 140193036972528 -> 140193036972336 + 140193036972528 [label=AddmmBackward0] + 140193036972720 -> 140193036972528 + 140193036972720 [label=ToCopyBackward0] + 140193036973008 -> 140193036972720 + 140193039431136 [label="encoder.layer.9.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140193039431136 -> 140193036973008 + 140193036973008 [label=AccumulateGrad] + 140193036972432 -> 140193036972528 + 140193036972432 [label=ViewBackward0] + 140193036972912 -> 140193036972432 + 140193036972912 [label=GeluBackward0] + 140193036973104 -> 140193036972912 + 140193036973104 [label=ViewBackward0] + 140193036973152 -> 140193036973104 + 140193036973152 [label=AddmmBackward0] + 140193036973392 -> 140193036973152 + 140193036973392 [label=ToCopyBackward0] + 140193036973632 -> 140193036973392 + 140193039431456 [label="encoder.layer.9.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140193039431456 -> 140193036973632 + 140193036973632 [label=AccumulateGrad] + 140193036973488 -> 140193036973152 + 140193036973488 [label=ViewBackward0] + 140193036973968 -> 140193036973488 + 140193036973968 [label=ToCopyBackward0] + 140193036974160 -> 140193036973968 + 140193036974160 [label=IndexBackward0] + 140193036949296 -> 140193036974160 + 140193036949296 [label=SliceBackward0] + 140193036974112 -> 140193036949296 + 140193036974112 [label=SliceBackward0] + 140193036974352 -> 140193036974112 + 140193036974352 [label=SliceBackward0] + 140193036974544 -> 140193036974352 + 140193036974544 [label=SliceBackward0] + 140193036974592 -> 140193036974544 + 140193036974592 [label=SliceBackward0] + 140193036974832 -> 140193036974592 + 140193036974832 [label=NativeLayerNormBackward0] + 140193036975024 -> 140193036974832 + 140193036975024 [label=AddBackward0] + 140193036975312 -> 140193036975024 + 140193036975312 [label=NativeDropoutBackward0] + 140193036975696 -> 140193036975312 + 140193036975696 [label=ViewBackward0] + 140193036975888 -> 140193036975696 + 140193036975888 [label=AddmmBackward0] + 140193036976080 -> 140193036975888 + 140193036976080 [label=ToCopyBackward0] + 140193036480816 -> 140193036976080 + 140193039442304 [label="encoder.layer.9.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039442304 -> 140193036480816 + 140193036480816 [label=AccumulateGrad] + 140193036975792 -> 140193036975888 + 140193036975792 [label=ViewBackward0] + 140193036480720 -> 140193036975792 + 140193036480720 [label=ViewBackward0] + 140193036480912 -> 140193036480720 + 140193036480912 [label=CloneBackward0] + 140193036480960 -> 140193036480912 + 140193036480960 [label=PermuteBackward0] + 140193036481200 -> 140193036480960 + 140193036481200 [label=UnsafeViewBackward0] + 140193036481392 -> 140193036481200 + 140193036481392 [label=BmmBackward0] + 140193036481440 -> 140193036481392 + 140193036481440 [label=ReshapeAliasBackward0] + 140193036481968 -> 140193036481440 + 140193036481968 [label=ExpandBackward0] + 140193036482064 -> 140193036481968 + 140193036482064 [label=ToCopyBackward0] + 140193036482256 -> 140193036482064 + 140193036482256 [label=NativeDropoutBackward0] + 140193036482448 -> 140193036482256 + 140193036482448 [label=SoftmaxBackward0] + 140193036482544 -> 140193036482448 + 140193036482544 [label=AddBackward0] + 140193036482736 -> 140193036482544 + 140193036482736 [label=DivBackward0] + 140193036482928 -> 140193036482736 + 140193036482928 [label=UnsafeViewBackward0] + 140193036483024 -> 140193036482928 + 140193036483024 [label=BmmBackward0] + 140193036483216 -> 140193036483024 + 140193036483216 [label=UnsafeViewBackward0] + 140193036483312 -> 140193036483216 + 140193036483312 [label=CloneBackward0] + 140193036483360 -> 140193036483312 + 140193036483360 [label=ExpandBackward0] + 140193036483600 -> 140193036483360 + 140193036483600 [label=PermuteBackward0] + 140193036483792 -> 140193036483600 + 140193036483792 [label=ViewBackward0] + 140193036483840 -> 140193036483792 + 140193036483840 [label=ViewBackward0] + 140193036484080 -> 140193036483840 + 140193036484080 [label=AddmmBackward0] + 140193036484272 -> 140193036484080 + 140193036484272 [label=ToCopyBackward0] + 140193036484320 -> 140193036484272 + 140193039445264 [label="encoder.layer.9.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039445264 -> 140193036484320 + 140193036484320 [label=AccumulateGrad] + 140193036484368 -> 140193036484080 + 140193036484368 [label=ViewBackward0] + 140193036513584 -> 140193036484368 + 140193036513584 [label=ToCopyBackward0] + 140193036975408 -> 140193036513584 + 140193036975408 [label=CatBackward0] + 140193036513536 -> 140193036975408 + 140193036513536 [label=NativeLayerNormBackward0] + 140193036513968 -> 140193036513536 + 140193036513968 [label=AddBackward0] + 140193036514256 -> 140193036513968 + 140193036514256 [label=CatBackward0] + 140193036514640 -> 140193036514256 + 140193036514640 [label=CatBackward0] + 140193036514928 -> 140193036514640 + 140193036514928 [label=SliceBackward0] + 140193036515312 -> 140193036514928 + 140193036515312 [label=SliceBackward0] + 140193036515504 -> 140193036515312 + 140193036515504 [label=SliceBackward0] + 140193036515600 -> 140193036515504 + 140193036515600 [label=SumBackward1] + 140193036515792 -> 140193036515600 + 140193036515792 [label=MulBackward0] + 140193036515984 -> 140193036515792 + 140193036515984 [label=IndexBackward0] + 140193036515936 -> 140193036515984 + 140193036515936 [label=PermuteBackward0] + 140193036516176 -> 140193036515936 + 140193036516176 [label=CatBackward0] + 140193036516368 -> 140193036516176 + 140193036516368 [label=UnsqueezeBackward0] + 140193036516752 -> 140193036516368 + 140193036516752 [label=NativeDropoutBackward0] + 140193036516944 -> 140193036516752 + 140193036516944 [label=ViewBackward0] + 140193036517040 -> 140193036516944 + 140193036517040 [label=AddmmBackward0] + 140193036517232 -> 140193036517040 + 140193036517232 [label=ToCopyBackward0] + 140193036533920 -> 140193036517232 + 140193039444064 [label="encoder.layer.8.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140193039444064 -> 140193036533920 + 140193036533920 [label=AccumulateGrad] + 140193036516896 -> 140193036517040 + 140193036516896 [label=ViewBackward0] + 140193036533968 -> 140193036516896 + 140193036533968 [label=GeluBackward0] + 140193036534064 -> 140193036533968 + 140193036534064 [label=ViewBackward0] + 140193036534256 -> 140193036534064 + 140193036534256 [label=AddmmBackward0] + 140193036534304 -> 140193036534256 + 140193036534304 [label=ToCopyBackward0] + 140193036534736 -> 140193036534304 + 140193039443984 [label="encoder.layer.8.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140193039443984 -> 140193036534736 + 140193036534736 [label=AccumulateGrad] + 140193036534448 -> 140193036534256 + 140193036534448 [label=ViewBackward0] + 140193036534928 -> 140193036534448 + 140193036534928 [label=ToCopyBackward0] + 140193036535120 -> 140193036534928 + 140193036535120 [label=IndexBackward0] + 140193036514352 -> 140193036535120 + 140193036514352 [label=SliceBackward0] + 140193036535216 -> 140193036514352 + 140193036535216 [label=SliceBackward0] + 140193036535264 -> 140193036535216 + 140193036535264 [label=NativeLayerNormBackward0] + 140193036535504 -> 140193036535264 + 140193036535504 [label=AddBackward0] + 140193036535744 -> 140193036535504 + 140193036535744 [label=NativeDropoutBackward0] + 140193036536272 -> 140193036535744 + 140193036536272 [label=ViewBackward0] + 140193036536368 -> 140193036536272 + 140193036536368 [label=AddmmBackward0] + 140193036536560 -> 140193036536368 + 140193036536560 [label=ToCopyBackward0] + 140193036536848 -> 140193036536560 + 140193039459488 [label="encoder.layer.8.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039459488 -> 140193036536848 + 140193036536848 [label=AccumulateGrad] + 140193036536224 -> 140193036536368 + 140193036536224 [label=ViewBackward0] + 140193036536704 -> 140193036536224 + 140193036536704 [label=ViewBackward0] + 140193036536944 -> 140193036536704 + 140193036536944 [label=CloneBackward0] + 140193036537136 -> 140193036536944 + 140193036537136 [label=PermuteBackward0] + 140193036537184 -> 140193036537136 + 140193036537184 [label=UnsafeViewBackward0] + 140193036537424 -> 140193036537184 + 140193036537424 [label=BmmBackward0] + 140193036537616 -> 140193036537424 + 140193036537616 [label=ReshapeAliasBackward0] + 140193036537664 -> 140193036537616 + 140193036537664 [label=ExpandBackward0] + 140193036566928 -> 140193036537664 + 140193036566928 [label=ToCopyBackward0] + 140193036567024 -> 140193036566928 + 140193036567024 [label=NativeDropoutBackward0] + 140193036567216 -> 140193036567024 + 140193036567216 [label=SoftmaxBackward0] + 140193036567408 -> 140193036567216 + 140193036567408 [label=AddBackward0] + 140193036567504 -> 140193036567408 + 140193036567504 [label=DivBackward0] + 140193036567696 -> 140193036567504 + 140193036567696 [label=UnsafeViewBackward0] + 140193036567888 -> 140193036567696 + 140193036567888 [label=BmmBackward0] + 140193036567984 -> 140193036567888 + 140193036567984 [label=UnsafeViewBackward0] + 140193036568080 -> 140193036567984 + 140193036568080 [label=CloneBackward0] + 140193036568272 -> 140193036568080 + 140193036568272 [label=ExpandBackward0] + 140193036568320 -> 140193036568272 + 140193036568320 [label=PermuteBackward0] + 140193036568560 -> 140193036568320 + 140193036568560 [label=ViewBackward0] + 140193036568752 -> 140193036568560 + 140193036568752 [label=ViewBackward0] + 140193036568800 -> 140193036568752 + 140193036568800 [label=AddmmBackward0] + 140193036569040 -> 140193036568800 + 140193036569040 [label=ToCopyBackward0] + 140193036569280 -> 140193036569040 + 140193039460208 [label="encoder.layer.8.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140193039460208 -> 140193036569280 + 140193036569280 [label=AccumulateGrad] + 140193036569136 -> 140193036568800 + 140193036569136 [label=ViewBackward0] + 140193036569616 -> 140193036569136 + 140193036569616 [label=ToCopyBackward0] + 140193036535888 -> 140193036569616 + 140193036535888 [label=SliceBackward0] + 140193036569712 -> 140193036535888 + 140193036569712 [label=SliceBackward0] + 140193036569760 -> 140193036569712 + 140193036569760 [label=SliceBackward0] + 140193036570000 -> 140193036569760 + 140193036570000 [label=NativeLayerNormBackward0] + 140193036570192 -> 140193036570000 + 140193036570192 [label=AddBackward0] + 140193036570480 -> 140193036570192 + 140193036570480 [label=NativeDropoutBackward0] + 140193036591408 -> 140193036570480 + 140193036591408 [label=ViewBackward0] + 140193036591600 -> 140193036591408 + 140193036591600 [label=AddmmBackward0] + 140193036591792 -> 140193036591600 + 140193036591792 [label=ToCopyBackward0] + 140193036592080 -> 140193036591792 + 140193039460528 [label="encoder.layer.8.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039460528 -> 140193036592080 + 140193036592080 [label=AccumulateGrad] + 140193036591504 -> 140193036591600 + 140193036591504 [label=ViewBackward0] + 140193036591984 -> 140193036591504 + 140193036591984 [label=ViewBackward0] + 140193036592176 -> 140193036591984 + 140193036592176 [label=CloneBackward0] + 140193036592224 -> 140193036592176 + 140193036592224 [label=PermuteBackward0] + 140193036592464 -> 140193036592224 + 140193036592464 [label=UnsafeViewBackward0] + 140193036592656 -> 140193036592464 + 140193036592656 [label=BmmBackward0] + 140193036592704 -> 140193036592656 + 140193036592704 [label=ReshapeAliasBackward0] + 140193036593232 -> 140193036592704 + 140193036593232 [label=ExpandBackward0] + 140193036593328 -> 140193036593232 + 140193036593328 [label=ToCopyBackward0] + 140193036593520 -> 140193036593328 + 140193036593520 [label=NativeDropoutBackward0] + 140193036593712 -> 140193036593520 + 140193036593712 [label=SoftmaxBackward0] + 140193036593808 -> 140193036593712 + 140193036593808 [label=AddBackward0] + 140193036594000 -> 140193036593808 + 140193036594000 [label=DivBackward0] + 140193036594192 -> 140193036594000 + 140193036594192 [label=UnsafeViewBackward0] + 140193036594288 -> 140193036594192 + 140193036594288 [label=BmmBackward0] + 140193036594480 -> 140193036594288 + 140193036594480 [label=UnsafeViewBackward0] + 140193036594576 -> 140193036594480 + 140193036594576 [label=CloneBackward0] + 140193036594624 -> 140193036594576 + 140193036594624 [label=ExpandBackward0] + 140193036594864 -> 140193036594624 + 140193036594864 [label=PermuteBackward0] + 140193036595056 -> 140193036594864 + 140193036595056 [label=ViewBackward0] + 140193036594384 -> 140193036595056 + 140193036594384 [label=ViewBackward0] + 140193036628176 -> 140193036594384 + 140193036628176 [label=AddmmBackward0] + 140193036628368 -> 140193036628176 + 140193036628368 [label=ToCopyBackward0] + 140193036628656 -> 140193036628368 + 140193039467280 [label="encoder.layer.8.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039467280 -> 140193036628656 + 140193036628656 [label=AccumulateGrad] + 140193036628464 -> 140193036628176 + 140193036628464 [label=ViewBackward0] + 140193036628944 -> 140193036628464 + 140193036628944 [label=ToCopyBackward0] + 140193036570576 -> 140193036628944 + 140193036570576 [label=CatBackward0] + 140193036628896 -> 140193036570576 + 140193036628896 [label=NativeLayerNormBackward0] + 140193036629424 -> 140193036628896 + 140193036629424 [label=AddBackward0] + 140193036629712 -> 140193036629424 + 140193036629712 [label=CatBackward0] + 140193036629808 -> 140193036629712 + 140193036629808 [label=CatBackward0] + 140193036630384 -> 140193036629808 + 140193036630384 [label=SliceBackward0] + 140193036630336 -> 140193036630384 + 140193036630336 [label=SliceBackward0] + 140193036630576 -> 140193036630336 + 140193036630576 [label=SliceBackward0] + 140193036630768 -> 140193036630576 + 140193036630768 [label=SumBackward1] + 140193036630816 -> 140193036630768 + 140193036630816 [label=MulBackward0] + 140193036631056 -> 140193036630816 + 140193036631056 [label=IndexBackward0] + 140193036631440 -> 140193036631056 + 140193036631440 [label=PermuteBackward0] + 140193036631632 -> 140193036631440 + 140193036631632 [label=CatBackward0] + 140193036631824 -> 140193036631632 + 140193036631824 [label=UnsqueezeBackward0] + 140193036631776 -> 140193036631824 + 140193036631776 [label=NativeDropoutBackward0] + 140193036631728 -> 140193036631776 + 140193036631728 [label=ViewBackward0] + 140193036656848 -> 140193036631728 + 140193036656848 [label=AddmmBackward0] + 140193036656896 -> 140193036656848 + 140193036656896 [label=ToCopyBackward0] + 140193036657328 -> 140193036656896 + 140193039461888 [label="encoder.layer.7.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140193039461888 -> 140193036657328 + 140193036657328 [label=AccumulateGrad] + 140193036657040 -> 140193036656848 + 140193036657040 [label=ViewBackward0] + 140193036657520 -> 140193036657040 + 140193036657520 [label=GeluBackward0] + 140193036657712 -> 140193036657520 + 140193036657712 [label=ViewBackward0] + 140193036657904 -> 140193036657712 + 140193036657904 [label=AddmmBackward0] + 140193036658000 -> 140193036657904 + 140193036658000 [label=ToCopyBackward0] + 140193036658384 -> 140193036658000 + 140193039462208 [label="encoder.layer.7.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140193039462208 -> 140193036658384 + 140193036658384 [label=AccumulateGrad] + 140193036657808 -> 140193036657904 + 140193036657808 [label=ViewBackward0] + 140193036658288 -> 140193036657808 + 140193036658288 [label=ToCopyBackward0] + 140193036658336 -> 140193036658288 + 140193036658336 [label=IndexBackward0] + 140193036629376 -> 140193036658336 + 140193036629376 [label=SliceBackward0] + 140193036658864 -> 140193036629376 + 140193036658864 [label=SliceBackward0] + 140193036658960 -> 140193036658864 + 140193036658960 [label=SliceBackward0] + 140193036659152 -> 140193036658960 + 140193036659152 [label=SliceBackward0] + 140193036659344 -> 140193036659152 + 140193036659344 [label=SliceBackward0] + 140193036659440 -> 140193036659344 + 140193036659440 [label=NativeLayerNormBackward0] + 140193036659632 -> 140193036659440 + 140193036659632 [label=AddBackward0] + 140193036659920 -> 140193036659632 + 140193036659920 [label=NativeDropoutBackward0] + 140193036660016 -> 140193036659920 + 140193036660016 [label=ViewBackward0] + 140193036660208 -> 140193036660016 + 140193036660208 [label=AddmmBackward0] + 140193036660256 -> 140193036660208 + 140193036660256 [label=ToCopyBackward0] + 140193036660496 -> 140193036660256 + 140193039487200 [label="encoder.layer.7.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039487200 -> 140193036660496 + 140193036660496 [label=AccumulateGrad] + 140193036660400 -> 140193036660208 + 140193036660400 [label=ViewBackward0] + 140193036677280 -> 140193036660400 + 140193036677280 [label=ViewBackward0] + 140193036677520 -> 140193036677280 + 140193036677520 [label=CloneBackward0] + 140193036677712 -> 140193036677520 + 140193036677712 [label=PermuteBackward0] + 140193036677808 -> 140193036677712 + 140193036677808 [label=UnsafeViewBackward0] + 140193036678000 -> 140193036677808 + 140193036678000 [label=BmmBackward0] + 140193036678192 -> 140193036678000 + 140193036678192 [label=ReshapeAliasBackward0] + 140193036678144 -> 140193036678192 + 140193036678144 [label=ExpandBackward0] + 140193036678384 -> 140193036678144 + 140193036678384 [label=ToCopyBackward0] + 140193036678576 -> 140193036678384 + 140193036678576 [label=NativeDropoutBackward0] + 140193036678624 -> 140193036678576 + 140193036678624 [label=SoftmaxBackward0] + 140193036678864 -> 140193036678624 + 140193036678864 [label=AddBackward0] + 140193036679056 -> 140193036678864 + 140193036679056 [label=DivBackward0] + 140193036679104 -> 140193036679056 + 140193036679104 [label=UnsafeViewBackward0] + 140193036679344 -> 140193036679104 + 140193036679344 [label=BmmBackward0] + 140193036679536 -> 140193036679344 + 140193036679536 [label=UnsafeViewBackward0] + 140193036679920 -> 140193036679536 + 140193036679920 [label=CloneBackward0] + 140193036680112 -> 140193036679920 + 140193036680112 [label=ExpandBackward0] + 140193036680208 -> 140193036680112 + 140193036680208 [label=PermuteBackward0] + 140193036680400 -> 140193036680208 + 140193036680400 [label=ViewBackward0] + 140193036680592 -> 140193036680400 + 140193036680592 [label=ViewBackward0] + 140193036680688 -> 140193036680592 + 140193036680688 [label=AddmmBackward0] + 140193036680880 -> 140193036680688 + 140193036680880 [label=ToCopyBackward0] + 140193036681168 -> 140193036680880 + 140193039488080 [label="encoder.layer.7.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039488080 -> 140193036681168 + 140193036681168 [label=AccumulateGrad] + 140193036680544 -> 140193036680688 + 140193036680544 [label=ViewBackward0] + 140193036681072 -> 140193036680544 + 140193036681072 [label=ToCopyBackward0] + 140193036659728 -> 140193036681072 + 140193036659728 [label=CatBackward0] + 140193036710288 -> 140193036659728 + 140193036710288 [label=NativeLayerNormBackward0] + 140193036710240 -> 140193036710288 + 140193036710240 [label=AddBackward0] + 140193036710672 -> 140193036710240 + 140193036710672 [label=CatBackward0] + 140193036711056 -> 140193036710672 + 140193036711056 [label=CatBackward0] + 140193036711200 -> 140193036711056 + 140193036711200 [label=SliceBackward0] + 140193036711728 -> 140193036711200 + 140193036711728 [label=SliceBackward0] + 140193036711824 -> 140193036711728 + 140193036711824 [label=SliceBackward0] + 140193036712016 -> 140193036711824 + 140193036712016 [label=SumBackward1] + 140193036712208 -> 140193036712016 + 140193036712208 [label=MulBackward0] + 140193036712304 -> 140193036712208 + 140193036712304 [label=IndexBackward0] + 140193036712400 -> 140193036712304 + 140193036712400 [label=PermuteBackward0] + 140193036712592 -> 140193036712400 + 140193036712592 [label=ViewBackward0] + 140193036712640 -> 140193036712592 + 140193036712640 [label=CloneBackward0] + 140193036712880 -> 140193036712640 + 140193036712880 [label=ExpandBackward0] + 140193036713072 -> 140193036712880 + 140193036713072 [label=UnsqueezeBackward0] + 140193036713120 -> 140193036713072 + 140193036713120 [label=CatBackward0] + 140193036713360 -> 140193036713120 + 140193036713360 [label=UnsqueezeBackward0] + 140193036713744 -> 140193036713360 + 140193036713744 [label=NativeDropoutBackward0] + 140193036713936 -> 140193036713744 + 140193036713936 [label=ViewBackward0] + 140193036713840 -> 140193036713936 + 140193036713840 [label=AddmmBackward0] + 140193036210480 -> 140193036713840 + 140193036210480 [label=ToCopyBackward0] + 140193036210864 -> 140193036210480 + 140193039470480 [label="encoder.layer.6.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140193039470480 -> 140193036210864 + 140193036210864 [label=AccumulateGrad] + 140193036210288 -> 140193036713840 + 140193036210288 [label=ViewBackward0] + 140193036210768 -> 140193036210288 + 140193036210768 [label=GeluBackward0] + 140193036210816 -> 140193036210768 + 140193036210816 [label=ViewBackward0] + 140193036211056 -> 140193036210816 + 140193036211056 [label=AddmmBackward0] + 140193036211248 -> 140193036211056 + 140193036211248 [label=ToCopyBackward0] + 140193036211536 -> 140193036211248 + 140193039470400 [label="encoder.layer.6.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140193039470400 -> 140193036211536 + 140193036211536 [label=AccumulateGrad] + 140193036211344 -> 140193036211056 + 140193036211344 [label=ViewBackward0] + 140193036211824 -> 140193036211344 + 140193036211824 [label=ToCopyBackward0] + 140193036211920 -> 140193036211824 + 140193036211920 [label=SliceBackward0] + 140193036212112 -> 140193036211920 + 140193036212112 [label=SliceBackward0] + 140193036212304 -> 140193036212112 + 140193036212304 [label=NativeLayerNormBackward0] + 140193036212400 -> 140193036212304 + 140193036212400 [label=AddBackward0] + 140193036212784 -> 140193036212400 + 140193036212784 [label=NativeDropoutBackward0] + 140193036212736 -> 140193036212784 + 140193036212736 [label=ViewBackward0] + 140193036212976 -> 140193036212736 + 140193036212976 [label=AddmmBackward0] + 140193036213168 -> 140193036212976 + 140193036213168 [label=ToCopyBackward0] + 140193036213456 -> 140193036213168 + 140193039490320 [label="encoder.layer.6.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039490320 -> 140193036213456 + 140193036213456 [label=AccumulateGrad] + 140193036213264 -> 140193036212976 + 140193036213264 [label=ViewBackward0] + 140193036213744 -> 140193036213264 + 140193036213744 [label=ViewBackward0] + 140193036213840 -> 140193036213744 + 140193036213840 [label=CloneBackward0] + 140193036214032 -> 140193036213840 + 140193036214032 [label=PermuteBackward0] + 140193036214224 -> 140193036214032 + 140193036214224 [label=UnsafeViewBackward0] + 140193036214128 -> 140193036214224 + 140193036214128 [label=BmmBackward0] + 140193036247344 -> 140193036214128 + 140193036247344 [label=ReshapeAliasBackward0] + 140193036247440 -> 140193036247344 + 140193036247440 [label=ExpandBackward0] + 140193036247488 -> 140193036247440 + 140193036247488 [label=ToCopyBackward0] + 140193036247728 -> 140193036247488 + 140193036247728 [label=NativeDropoutBackward0] + 140193036247920 -> 140193036247728 + 140193036247920 [label=SoftmaxBackward0] + 140193036247968 -> 140193036247920 + 140193036247968 [label=AddBackward0] + 140193036248208 -> 140193036247968 + 140193036248208 [label=DivBackward0] + 140193036248400 -> 140193036248208 + 140193036248400 [label=UnsafeViewBackward0] + 140193036248448 -> 140193036248400 + 140193036248448 [label=BmmBackward0] + 140193036248688 -> 140193036248448 + 140193036248688 [label=UnsafeViewBackward0] + 140193036249072 -> 140193036248688 + 140193036249072 [label=CloneBackward0] + 140193036249264 -> 140193036249072 + 140193036249264 [label=ExpandBackward0] + 140193036249456 -> 140193036249264 + 140193036249456 [label=PermuteBackward0] + 140193036249552 -> 140193036249456 + 140193036249552 [label=ViewBackward0] + 140193036249744 -> 140193036249552 + 140193036249744 [label=ViewBackward0] + 140193036249936 -> 140193036249744 + 140193036249936 [label=AddmmBackward0] + 140193036250032 -> 140193036249936 + 140193036250032 [label=ToCopyBackward0] + 140193036250416 -> 140193036250032 + 140193039490960 [label="encoder.layer.6.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140193039490960 -> 140193036250416 + 140193036250416 [label=AccumulateGrad] + 140193036249840 -> 140193036249936 + 140193036249840 [label=ViewBackward0] + 140193036250320 -> 140193036249840 + 140193036250320 [label=ToCopyBackward0] + 140193036212496 -> 140193036250320 + 140193036212496 [label=SliceBackward0] + 140193036250704 -> 140193036212496 + 140193036250704 [label=SliceBackward0] + 140193036250896 -> 140193036250704 + 140193036250896 [label=SliceBackward0] + 140193036250992 -> 140193036250896 + 140193036250992 [label=NativeLayerNormBackward0] + 140193036250848 -> 140193036250992 + 140193036250848 [label=AddBackward0] + 140193036272016 -> 140193036250848 + 140193036272016 [label=NativeDropoutBackward0] + 140193036272112 -> 140193036272016 + 140193036272112 [label=ViewBackward0] + 140193036272304 -> 140193036272112 + 140193036272304 [label=AddmmBackward0] + 140193036272352 -> 140193036272304 + 140193036272352 [label=ToCopyBackward0] + 140193036272784 -> 140193036272352 + 140193039495712 [label="encoder.layer.6.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039495712 -> 140193036272784 + 140193036272784 [label=AccumulateGrad] + 140193036272496 -> 140193036272304 + 140193036272496 [label=ViewBackward0] + 140193036272976 -> 140193036272496 + 140193036272976 [label=ViewBackward0] + 140193036273168 -> 140193036272976 + 140193036273168 [label=CloneBackward0] + 140193036273360 -> 140193036273168 + 140193036273360 [label=PermuteBackward0] + 140193036273456 -> 140193036273360 + 140193036273456 [label=UnsafeViewBackward0] + 140193036273648 -> 140193036273456 + 140193036273648 [label=BmmBackward0] + 140193036273840 -> 140193036273648 + 140193036273840 [label=ReshapeAliasBackward0] + 140193036273792 -> 140193036273840 + 140193036273792 [label=ExpandBackward0] + 140193036274032 -> 140193036273792 + 140193036274032 [label=ToCopyBackward0] + 140193036274224 -> 140193036274032 + 140193036274224 [label=NativeDropoutBackward0] + 140193036274608 -> 140193036274224 + 140193036274608 [label=SoftmaxBackward0] + 140193036274800 -> 140193036274608 + 140193036274800 [label=AddBackward0] + 140193036274896 -> 140193036274800 + 140193036274896 [label=DivBackward0] + 140193036275088 -> 140193036274896 + 140193036275088 [label=UnsafeViewBackward0] + 140193036275280 -> 140193036275088 + 140193036275280 [label=BmmBackward0] + 140193036275376 -> 140193036275280 + 140193036275376 [label=UnsafeViewBackward0] + 140193036275472 -> 140193036275376 + 140193036275472 [label=CloneBackward0] + 140193036275232 -> 140193036275472 + 140193036275232 [label=ExpandBackward0] + 140193036302992 -> 140193036275232 + 140193036302992 [label=PermuteBackward0] + 140193036303760 -> 140193036302992 + 140193036303760 [label=ViewBackward0] + 140193036303712 -> 140193036303760 + 140193036303712 [label=ViewBackward0] + 140193036300400 -> 140193036303712 + 140193036300400 [label=AddmmBackward0] + 140193036300496 -> 140193036300400 + 140193036300496 [label=ToCopyBackward0] + 140193036300784 -> 140193036300496 + 140193039496432 [label="encoder.layer.6.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039496432 -> 140193036300784 + 140193036300784 [label=AccumulateGrad] + 140193036300448 -> 140193036300400 + 140193036300448 [label=ViewBackward0] + 140193036300976 -> 140193036300448 + 140193036300976 [label=ToCopyBackward0] + 140193036271824 -> 140193036300976 + 140193036271824 [label=CatBackward0] + 140193036301072 -> 140193036271824 + 140193036301072 [label=NativeLayerNormBackward0] + 140193036301456 -> 140193036301072 + 140193036301456 [label=AddBackward0] + 140193036301648 -> 140193036301456 + 140193036301648 [label=NativeDropoutBackward0] + 140193036301792 -> 140193036301648 + 140193036301792 [label=ViewBackward0] + 140193036302032 -> 140193036301792 + 140193036302032 [label=AddmmBackward0] + 140193036302224 -> 140193036302032 + 140193036302224 [label=ToCopyBackward0] + 140193036303568 -> 140193036302224 + 140193039496912 [label="encoder.layer.5.experts.dense2.bias + (768)" fillcolor=lightblue] + 140193039496912 -> 140193036303568 + 140193036303568 [label=AccumulateGrad] + 140193036302128 -> 140193036302032 + 140193036302128 [label=ViewBackward0] + 140193036302608 -> 140193036302128 + 140193036302608 [label=GeluBackward0] + 140193036303280 -> 140193036302608 + 140193036303280 [label=ViewBackward0] + 140193036302896 -> 140193036303280 + 140193036302896 [label=AddmmBackward0] + 140193036304240 -> 140193036302896 + 140193036304240 [label=ToCopyBackward0] + 140193036302272 -> 140193036304240 + 140193039497152 [label="encoder.layer.5.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140193039497152 -> 140193036302272 + 140193036302272 [label=AccumulateGrad] + 140193036303664 -> 140193036302896 + 140193036303664 [label=ViewBackward0] + 140193036303184 -> 140193036303664 + 140193036303184 [label=ToCopyBackward0] + 140193036301552 -> 140193036303184 + 140193036301552 [label=SliceBackward0] + 140193037273120 -> 140193036301552 + 140193037273120 [label=SliceBackward0] + 140193037273024 -> 140193037273120 + 140193037273024 [label=SliceBackward0] + 140193037272928 -> 140193037273024 + 140193037272928 [label=SliceBackward0] + 140193037272832 -> 140193037272928 + 140193037272832 [label=SliceBackward0] + 140193037272736 -> 140193037272832 + 140193037272736 [label=NativeLayerNormBackward0] + 140193037272640 -> 140193037272736 + 140193037272640 [label=AddBackward0] + 140193037272448 -> 140193037272640 + 140193037272448 [label=NativeDropoutBackward0] + 140193037272400 -> 140193037272448 + 140193037272400 [label=ViewBackward0] + 140193037272304 -> 140193037272400 + 140193037272304 [label=AddmmBackward0] + 140193037272208 -> 140193037272304 + 140193037272208 [label=ToCopyBackward0] + 140193037272016 -> 140193037272208 + 140193039499072 [label="encoder.layer.5.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039499072 -> 140193037272016 + 140193037272016 [label=AccumulateGrad] + 140193037272160 -> 140193037272304 + 140193037272160 [label=ViewBackward0] + 140193037271872 -> 140193037272160 + 140193037271872 [label=ViewBackward0] + 140193037271776 -> 140193037271872 + 140193037271776 [label=CloneBackward0] + 140193037271680 -> 140193037271776 + 140193037271680 [label=PermuteBackward0] + 140193037271584 -> 140193037271680 + 140193037271584 [label=UnsafeViewBackward0] + 140193037271488 -> 140193037271584 + 140193037271488 [label=BmmBackward0] + 140193037271392 -> 140193037271488 + 140193037271392 [label=ReshapeAliasBackward0] + 140193037271344 -> 140193037271392 + 140193037271344 [label=ExpandBackward0] + 140193037271248 -> 140193037271344 + 140193037271248 [label=ToCopyBackward0] + 140193037271152 -> 140193037271248 + 140193037271152 [label=NativeDropoutBackward0] + 140193037273408 -> 140193037271152 + 140193037273408 [label=SoftmaxBackward0] + 140193037273504 -> 140193037273408 + 140193037273504 [label=AddBackward0] + 140193037273600 -> 140193037273504 + 140193037273600 [label=DivBackward0] + 140193037273696 -> 140193037273600 + 140193037273696 [label=UnsafeViewBackward0] + 140193037273792 -> 140193037273696 + 140193037273792 [label=BmmBackward0] + 140193037273888 -> 140193037273792 + 140193037273888 [label=UnsafeViewBackward0] + 140193037274032 -> 140193037273888 + 140193037274032 [label=CloneBackward0] + 140193037274128 -> 140193037274032 + 140193037274128 [label=ExpandBackward0] + 140193037274224 -> 140193037274128 + 140193037274224 [label=PermuteBackward0] + 140193037274320 -> 140193037274224 + 140193037274320 [label=ViewBackward0] + 140193037274416 -> 140193037274320 + 140193037274416 [label=ViewBackward0] + 140193037274512 -> 140193037274416 + 140193037274512 [label=AddmmBackward0] + 140193037274608 -> 140193037274512 + 140193037274608 [label=ToCopyBackward0] + 140193037274800 -> 140193037274608 + 140193039516272 [label="encoder.layer.5.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039516272 -> 140193037274800 + 140193037274800 [label=AccumulateGrad] + 140193037274560 -> 140193037274512 + 140193037274560 [label=ViewBackward0] + 140193037274848 -> 140193037274560 + 140193037274848 [label=ToCopyBackward0] + 140193037272592 -> 140193037274848 + 140193037272592 [label=CatBackward0] + 140193037274992 -> 140193037272592 + 140193037274992 [label=NativeLayerNormBackward0] + 140193037275088 -> 140193037274992 + 140193037275088 [label=AddBackward0] + 140193037361408 -> 140193037275088 + 140193037361408 [label=NativeDropoutBackward0] + 140193037361552 -> 140193037361408 + 140193037361552 [label=ViewBackward0] + 140193037361648 -> 140193037361552 + 140193037361648 [label=AddmmBackward0] + 140193037361744 -> 140193037361648 + 140193037361744 [label=ToCopyBackward0] + 140193037361936 -> 140193037361744 + 140193039516752 [label="encoder.layer.4.experts.dense2.bias + (768)" fillcolor=lightblue] + 140193039516752 -> 140193037361936 + 140193037361936 [label=AccumulateGrad] + 140193037361696 -> 140193037361648 + 140193037361696 [label=ViewBackward0] + 140193037361984 -> 140193037361696 + 140193037361984 [label=GeluBackward0] + 140193037362080 -> 140193037361984 + 140193037362080 [label=ViewBackward0] + 140193037362176 -> 140193037362080 + 140193037362176 [label=AddmmBackward0] + 140193037362272 -> 140193037362176 + 140193037362272 [label=ToCopyBackward0] + 140193037362464 -> 140193037362272 + 140193039516992 [label="encoder.layer.4.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140193039516992 -> 140193037362464 + 140193037362464 [label=AccumulateGrad] + 140193037362224 -> 140193037362176 + 140193037362224 [label=ViewBackward0] + 140193037362512 -> 140193037362224 + 140193037362512 [label=ToCopyBackward0] + 140193037361360 -> 140193037362512 + 140193037361360 [label=SliceBackward0] + 140193037362656 -> 140193037361360 + 140193037362656 [label=SliceBackward0] + 140193037362752 -> 140193037362656 + 140193037362752 [label=NativeLayerNormBackward0] + 140193037362848 -> 140193037362752 + 140193037362848 [label=AddBackward0] + 140193037363040 -> 140193037362848 + 140193037363040 [label=NativeDropoutBackward0] + 140193037363184 -> 140193037363040 + 140193037363184 [label=ViewBackward0] + 140193037363280 -> 140193037363184 + 140193037363280 [label=AddmmBackward0] + 140193037363376 -> 140193037363280 + 140193037363376 [label=ToCopyBackward0] + 140193037363568 -> 140193037363376 + 140193039518592 [label="encoder.layer.4.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039518592 -> 140193037363568 + 140193037363568 [label=AccumulateGrad] + 140193037363328 -> 140193037363280 + 140193037363328 [label=ViewBackward0] + 140193037363616 -> 140193037363328 + 140193037363616 [label=ViewBackward0] + 140193037363712 -> 140193037363616 + 140193037363712 [label=CloneBackward0] + 140193037363808 -> 140193037363712 + 140193037363808 [label=PermuteBackward0] + 140193037363904 -> 140193037363808 + 140193037363904 [label=UnsafeViewBackward0] + 140193037364000 -> 140193037363904 + 140193037364000 [label=BmmBackward0] + 140193037364096 -> 140193037364000 + 140193037364096 [label=ReshapeAliasBackward0] + 140193037364240 -> 140193037364096 + 140193037364240 [label=ExpandBackward0] + 140193037364336 -> 140193037364240 + 140193037364336 [label=ToCopyBackward0] + 140193037364432 -> 140193037364336 + 140193037364432 [label=NativeDropoutBackward0] + 140193037364528 -> 140193037364432 + 140193037364528 [label=SoftmaxBackward0] + 140193037364624 -> 140193037364528 + 140193037364624 [label=AddBackward0] + 140193037364720 -> 140193037364624 + 140193037364720 [label=DivBackward0] + 140193037364816 -> 140193037364720 + 140193037364816 [label=UnsafeViewBackward0] + 140193037364912 -> 140193037364816 + 140193037364912 [label=BmmBackward0] + 140193037365008 -> 140193037364912 + 140193037365008 [label=UnsafeViewBackward0] + 140193037365152 -> 140193037365008 + 140193037365152 [label=CloneBackward0] + 140193037365200 -> 140193037365152 + 140193037365200 [label=ExpandBackward0] + 140193037275296 -> 140193037365200 + 140193037275296 [label=PermuteBackward0] + 140193037275392 -> 140193037275296 + 140193037275392 [label=ViewBackward0] + 140193037275488 -> 140193037275392 + 140193037275488 [label=ViewBackward0] + 140193037275584 -> 140193037275488 + 140193037275584 [label=AddmmBackward0] + 140193037275680 -> 140193037275584 + 140193037275680 [label=ToCopyBackward0] + 140193037275872 -> 140193037275680 + 140193039519312 [label="encoder.layer.4.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140193039519312 -> 140193037275872 + 140193037275872 [label=AccumulateGrad] + 140193037275632 -> 140193037275584 + 140193037275632 [label=ViewBackward0] + 140193037275920 -> 140193037275632 + 140193037275920 [label=ToCopyBackward0] + 140193037362992 -> 140193037275920 + 140193037362992 [label=SliceBackward0] + 140193037276064 -> 140193037362992 + 140193037276064 [label=SliceBackward0] + 140193037276160 -> 140193037276064 + 140193037276160 [label=SliceBackward0] + 140193037276256 -> 140193037276160 + 140193037276256 [label=NativeLayerNormBackward0] + 140193037276352 -> 140193037276256 + 140193037276352 [label=AddBackward0] + 140193037276544 -> 140193037276352 + 140193037276544 [label=NativeDropoutBackward0] + 140193037276688 -> 140193037276544 + 140193037276688 [label=ViewBackward0] + 140193037276784 -> 140193037276688 + 140193037276784 [label=AddmmBackward0] + 140193037276880 -> 140193037276784 + 140193037276880 [label=ToCopyBackward0] + 140193037277072 -> 140193037276880 + 140193039536272 [label="encoder.layer.4.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039536272 -> 140193037277072 + 140193037277072 [label=AccumulateGrad] + 140193037276832 -> 140193037276784 + 140193037276832 [label=ViewBackward0] + 140193037277120 -> 140193037276832 + 140193037277120 [label=ViewBackward0] + 140193037277216 -> 140193037277120 + 140193037277216 [label=CloneBackward0] + 140193037277312 -> 140193037277216 + 140193037277312 [label=PermuteBackward0] + 140193037277408 -> 140193037277312 + 140193037277408 [label=UnsafeViewBackward0] + 140193037277504 -> 140193037277408 + 140193037277504 [label=BmmBackward0] + 140193037277600 -> 140193037277504 + 140193037277600 [label=ReshapeAliasBackward0] + 140193037277744 -> 140193037277600 + 140193037277744 [label=ExpandBackward0] + 140193037277840 -> 140193037277744 + 140193037277840 [label=ToCopyBackward0] + 140193037277936 -> 140193037277840 + 140193037277936 [label=NativeDropoutBackward0] + 140193037278032 -> 140193037277936 + 140193037278032 [label=SoftmaxBackward0] + 140193037278128 -> 140193037278032 + 140193037278128 [label=AddBackward0] + 140193037278224 -> 140193037278128 + 140193037278224 [label=DivBackward0] + 140193037278320 -> 140193037278224 + 140193037278320 [label=UnsafeViewBackward0] + 140193037278416 -> 140193037278320 + 140193037278416 [label=BmmBackward0] + 140193037278512 -> 140193037278416 + 140193037278512 [label=UnsafeViewBackward0] + 140193037278656 -> 140193037278512 + 140193037278656 [label=CloneBackward0] + 140193037278752 -> 140193037278656 + 140193037278752 [label=ExpandBackward0] + 140193037278848 -> 140193037278752 + 140193037278848 [label=PermuteBackward0] + 140193037278944 -> 140193037278848 + 140193037278944 [label=ViewBackward0] + 140193037279040 -> 140193037278944 + 140193037279040 [label=ViewBackward0] + 140193037279136 -> 140193037279040 + 140193037279136 [label=AddmmBackward0] + 140193037279184 -> 140193037279136 + 140193037279184 [label=ToCopyBackward0] + 140193037259008 -> 140193037279184 + 140193039536992 [label="encoder.layer.4.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039536992 -> 140193037259008 + 140193037259008 [label=AccumulateGrad] + 140193037278560 -> 140193037279136 + 140193037278560 [label=ViewBackward0] + 140193037259056 -> 140193037278560 + 140193037259056 [label=ToCopyBackward0] + 140193037276496 -> 140193037259056 + 140193037276496 [label=CatBackward0] + 140193037259200 -> 140193037276496 + 140193037259200 [label=NativeLayerNormBackward0] + 140193037259344 -> 140193037259200 + 140193037259344 [label=AddBackward0] + 140193037259536 -> 140193037259344 + 140193037259536 [label=NativeDropoutBackward0] + 140193037259680 -> 140193037259536 + 140193037259680 [label=ViewBackward0] + 140193037259776 -> 140193037259680 + 140193037259776 [label=AddmmBackward0] + 140193037259872 -> 140193037259776 + 140193037259872 [label=ToCopyBackward0] + 140193037260064 -> 140193037259872 + 140193039537472 [label="encoder.layer.3.experts.dense2.bias + (768)" fillcolor=lightblue] + 140193039537472 -> 140193037260064 + 140193037260064 [label=AccumulateGrad] + 140193037259824 -> 140193037259776 + 140193037259824 [label=ViewBackward0] + 140193037260112 -> 140193037259824 + 140193037260112 [label=GeluBackward0] + 140193037260208 -> 140193037260112 + 140193037260208 [label=ViewBackward0] + 140193037260304 -> 140193037260208 + 140193037260304 [label=AddmmBackward0] + 140193037260400 -> 140193037260304 + 140193037260400 [label=ToCopyBackward0] + 140193037260592 -> 140193037260400 + 140193039537712 [label="encoder.layer.3.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140193039537712 -> 140193037260592 + 140193037260592 [label=AccumulateGrad] + 140193037260352 -> 140193037260304 + 140193037260352 [label=ViewBackward0] + 140193037260640 -> 140193037260352 + 140193037260640 [label=ToCopyBackward0] + 140193037259488 -> 140193037260640 + 140193037259488 [label=SliceBackward0] + 140193037260784 -> 140193037259488 + 140193037260784 [label=SliceBackward0] + 140193037260880 -> 140193037260784 + 140193037260880 [label=SliceBackward0] + 140193037260976 -> 140193037260880 + 140193037260976 [label=SliceBackward0] + 140193037261072 -> 140193037260976 + 140193037261072 [label=SliceBackward0] + 140193037261168 -> 140193037261072 + 140193037261168 [label=NativeLayerNormBackward0] + 140193037261264 -> 140193037261168 + 140193037261264 [label=AddBackward0] + 140193037261456 -> 140193037261264 + 140193037261456 [label=NativeDropoutBackward0] + 140193037261600 -> 140193037261456 + 140193037261600 [label=ViewBackward0] + 140193037261696 -> 140193037261600 + 140193037261696 [label=AddmmBackward0] + 140193037261792 -> 140193037261696 + 140193037261792 [label=ToCopyBackward0] + 140193037261984 -> 140193037261792 + 140193039539632 [label="encoder.layer.3.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039539632 -> 140193037261984 + 140193037261984 [label=AccumulateGrad] + 140193037261744 -> 140193037261696 + 140193037261744 [label=ViewBackward0] + 140193037262032 -> 140193037261744 + 140193037262032 [label=ViewBackward0] + 140193037262128 -> 140193037262032 + 140193037262128 [label=CloneBackward0] + 140193037262224 -> 140193037262128 + 140193037262224 [label=PermuteBackward0] + 140193037262320 -> 140193037262224 + 140193037262320 [label=UnsafeViewBackward0] + 140193037262416 -> 140193037262320 + 140193037262416 [label=BmmBackward0] + 140193037262512 -> 140193037262416 + 140193037262512 [label=ReshapeAliasBackward0] + 140193037262656 -> 140193037262512 + 140193037262656 [label=ExpandBackward0] + 140193037262752 -> 140193037262656 + 140193037262752 [label=ToCopyBackward0] + 140193037262800 -> 140193037262752 + 140193037262800 [label=NativeDropoutBackward0] + 140193035952288 -> 140193037262800 + 140193035952288 [label=SoftmaxBackward0] + 140193035952384 -> 140193035952288 + 140193035952384 [label=AddBackward0] + 140193035952480 -> 140193035952384 + 140193035952480 [label=DivBackward0] + 140193035952576 -> 140193035952480 + 140193035952576 [label=UnsafeViewBackward0] + 140193035952672 -> 140193035952576 + 140193035952672 [label=BmmBackward0] + 140193035952768 -> 140193035952672 + 140193035952768 [label=UnsafeViewBackward0] + 140193035952912 -> 140193035952768 + 140193035952912 [label=CloneBackward0] + 140193035953008 -> 140193035952912 + 140193035953008 [label=ExpandBackward0] + 140193035953104 -> 140193035953008 + 140193035953104 [label=PermuteBackward0] + 140193035953200 -> 140193035953104 + 140193035953200 [label=ViewBackward0] + 140193035953296 -> 140193035953200 + 140193035953296 [label=ViewBackward0] + 140193035953392 -> 140193035953296 + 140193035953392 [label=AddmmBackward0] + 140193035953488 -> 140193035953392 + 140193035953488 [label=ToCopyBackward0] + 140193035953680 -> 140193035953488 + 140193039548640 [label="encoder.layer.3.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039548640 -> 140193035953680 + 140193035953680 [label=AccumulateGrad] + 140193035953440 -> 140193035953392 + 140193035953440 [label=ViewBackward0] + 140193035953728 -> 140193035953440 + 140193035953728 [label=ToCopyBackward0] + 140193037261408 -> 140193035953728 + 140193037261408 [label=CatBackward0] + 140193035953872 -> 140193037261408 + 140193035953872 [label=NativeLayerNormBackward0] + 140193035954016 -> 140193035953872 + 140193035954016 [label=AddBackward0] + 140193035954208 -> 140193035954016 + 140193035954208 [label=NativeDropoutBackward0] + 140193035954352 -> 140193035954208 + 140193035954352 [label=ViewBackward0] + 140193035954448 -> 140193035954352 + 140193035954448 [label=AddmmBackward0] + 140193035954544 -> 140193035954448 + 140193035954544 [label=ToCopyBackward0] + 140193035954736 -> 140193035954544 + 140193039549120 [label="encoder.layer.2.experts.dense2.bias + (768)" fillcolor=lightblue] + 140193039549120 -> 140193035954736 + 140193035954736 [label=AccumulateGrad] + 140193035954496 -> 140193035954448 + 140193035954496 [label=ViewBackward0] + 140193035954784 -> 140193035954496 + 140193035954784 [label=GeluBackward0] + 140193035954880 -> 140193035954784 + 140193035954880 [label=ViewBackward0] + 140193035954976 -> 140193035954880 + 140193035954976 [label=AddmmBackward0] + 140193035955072 -> 140193035954976 + 140193035955072 [label=ToCopyBackward0] + 140193035955264 -> 140193035955072 + 140193039549360 [label="encoder.layer.2.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140193039549360 -> 140193035955264 + 140193035955264 [label=AccumulateGrad] + 140193035955024 -> 140193035954976 + 140193035955024 [label=ViewBackward0] + 140193035955312 -> 140193035955024 + 140193035955312 [label=ToCopyBackward0] + 140193035954160 -> 140193035955312 + 140193035954160 [label=SliceBackward0] + 140193035955456 -> 140193035954160 + 140193035955456 [label=SliceBackward0] + 140193035955552 -> 140193035955456 + 140193035955552 [label=NativeLayerNormBackward0] + 140193035955648 -> 140193035955552 + 140193035955648 [label=AddBackward0] + 140193035955840 -> 140193035955648 + 140193035955840 [label=NativeDropoutBackward0] + 140193035955984 -> 140193035955840 + 140193035955984 [label=ViewBackward0] + 140193035956080 -> 140193035955984 + 140193035956080 [label=AddmmBackward0] + 140193035956176 -> 140193035956080 + 140193035956176 [label=ToCopyBackward0] + 140193035960528 -> 140193035956176 + 140193039551280 [label="encoder.layer.2.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039551280 -> 140193035960528 + 140193035960528 [label=AccumulateGrad] + 140193035956128 -> 140193035956080 + 140193035956128 [label=ViewBackward0] + 140193035960576 -> 140193035956128 + 140193035960576 [label=ViewBackward0] + 140193035960672 -> 140193035960576 + 140193035960672 [label=CloneBackward0] + 140193035960768 -> 140193035960672 + 140193035960768 [label=PermuteBackward0] + 140193035960864 -> 140193035960768 + 140193035960864 [label=UnsafeViewBackward0] + 140193035960960 -> 140193035960864 + 140193035960960 [label=BmmBackward0] + 140193035961056 -> 140193035960960 + 140193035961056 [label=ReshapeAliasBackward0] + 140193035961200 -> 140193035961056 + 140193035961200 [label=ExpandBackward0] + 140193035961296 -> 140193035961200 + 140193035961296 [label=ToCopyBackward0] + 140193035961392 -> 140193035961296 + 140193035961392 [label=NativeDropoutBackward0] + 140193035961488 -> 140193035961392 + 140193035961488 [label=SoftmaxBackward0] + 140193035961584 -> 140193035961488 + 140193035961584 [label=AddBackward0] + 140193035961680 -> 140193035961584 + 140193035961680 [label=DivBackward0] + 140193035961776 -> 140193035961680 + 140193035961776 [label=UnsafeViewBackward0] + 140193035961872 -> 140193035961776 + 140193035961872 [label=BmmBackward0] + 140193035961968 -> 140193035961872 + 140193035961968 [label=UnsafeViewBackward0] + 140193035962112 -> 140193035961968 + 140193035962112 [label=CloneBackward0] + 140193035962208 -> 140193035962112 + 140193035962208 [label=ExpandBackward0] + 140193035962304 -> 140193035962208 + 140193035962304 [label=PermuteBackward0] + 140193035962400 -> 140193035962304 + 140193035962400 [label=ViewBackward0] + 140193035962496 -> 140193035962400 + 140193035962496 [label=ViewBackward0] + 140193035962592 -> 140193035962496 + 140193035962592 [label=AddmmBackward0] + 140193035962688 -> 140193035962592 + 140193035962688 [label=ToCopyBackward0] + 140193035962880 -> 140193035962688 + 140193039552000 [label="encoder.layer.2.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140193039552000 -> 140193035962880 + 140193035962880 [label=AccumulateGrad] + 140193035962640 -> 140193035962592 + 140193035962640 [label=ViewBackward0] + 140193035962928 -> 140193035962640 + 140193035962928 [label=ToCopyBackward0] + 140193035955792 -> 140193035962928 + 140193035955792 [label=SliceBackward0] + 140193035963072 -> 140193035955792 + 140193035963072 [label=SliceBackward0] + 140193035963168 -> 140193035963072 + 140193035963168 [label=SliceBackward0] + 140193035963264 -> 140193035963168 + 140193035963264 [label=NativeLayerNormBackward0] + 140193035963360 -> 140193035963264 + 140193035963360 [label=AddBackward0] + 140193035963552 -> 140193035963360 + 140193035963552 [label=NativeDropoutBackward0] + 140193035963696 -> 140193035963552 + 140193035963696 [label=ViewBackward0] + 140193035963792 -> 140193035963696 + 140193035963792 [label=AddmmBackward0] + 140193035963888 -> 140193035963792 + 140193035963888 [label=ToCopyBackward0] + 140193035964080 -> 140193035963888 + 140193039556672 [label="encoder.layer.2.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039556672 -> 140193035964080 + 140193035964080 [label=AccumulateGrad] + 140193035963840 -> 140193035963792 + 140193035963840 [label=ViewBackward0] + 140193035964128 -> 140193035963840 + 140193035964128 [label=ViewBackward0] + 140193035964224 -> 140193035964128 + 140193035964224 [label=CloneBackward0] + 140193035964320 -> 140193035964224 + 140193035964320 [label=PermuteBackward0] + 140193035964368 -> 140193035964320 + 140193035964368 [label=UnsafeViewBackward0] + 140193035985056 -> 140193035964368 + 140193035985056 [label=BmmBackward0] + 140193035985152 -> 140193035985056 + 140193035985152 [label=ReshapeAliasBackward0] + 140193035985296 -> 140193035985152 + 140193035985296 [label=ExpandBackward0] + 140193035985392 -> 140193035985296 + 140193035985392 [label=ToCopyBackward0] + 140193035985488 -> 140193035985392 + 140193035985488 [label=NativeDropoutBackward0] + 140193035985584 -> 140193035985488 + 140193035985584 [label=SoftmaxBackward0] + 140193035985680 -> 140193035985584 + 140193035985680 [label=AddBackward0] + 140193035985776 -> 140193035985680 + 140193035985776 [label=DivBackward0] + 140193035985872 -> 140193035985776 + 140193035985872 [label=UnsafeViewBackward0] + 140193035985968 -> 140193035985872 + 140193035985968 [label=BmmBackward0] + 140193035986064 -> 140193035985968 + 140193035986064 [label=UnsafeViewBackward0] + 140193035986208 -> 140193035986064 + 140193035986208 [label=CloneBackward0] + 140193035986304 -> 140193035986208 + 140193035986304 [label=ExpandBackward0] + 140193035986400 -> 140193035986304 + 140193035986400 [label=PermuteBackward0] + 140193035986496 -> 140193035986400 + 140193035986496 [label=ViewBackward0] + 140193035986592 -> 140193035986496 + 140193035986592 [label=ViewBackward0] + 140193035986688 -> 140193035986592 + 140193035986688 [label=AddmmBackward0] + 140193035986784 -> 140193035986688 + 140193035986784 [label=ToCopyBackward0] + 140193035986976 -> 140193035986784 + 140193039557392 [label="encoder.layer.2.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039557392 -> 140193035986976 + 140193035986976 [label=AccumulateGrad] + 140193035986736 -> 140193035986688 + 140193035986736 [label=ViewBackward0] + 140193035987024 -> 140193035986736 + 140193035987024 [label=ToCopyBackward0] + 140193035963504 -> 140193035987024 + 140193035963504 [label=CatBackward0] + 140193035987168 -> 140193035963504 + 140193035987168 [label=NativeLayerNormBackward0] + 140193035987312 -> 140193035987168 + 140193035987312 [label=AddBackward0] + 140193035987504 -> 140193035987312 + 140193035987504 [label=NativeDropoutBackward0] + 140193035987648 -> 140193035987504 + 140193035987648 [label=ViewBackward0] + 140193035987744 -> 140193035987648 + 140193035987744 [label=AddmmBackward0] + 140193035987840 -> 140193035987744 + 140193035987840 [label=ToCopyBackward0] + 140193035988032 -> 140193035987840 + 140193039557872 [label="encoder.layer.1.experts.dense2.bias + (768)" fillcolor=lightblue] + 140193039557872 -> 140193035988032 + 140193035988032 [label=AccumulateGrad] + 140193035987792 -> 140193035987744 + 140193035987792 [label=ViewBackward0] + 140193035988080 -> 140193035987792 + 140193035988080 [label=GeluBackward0] + 140193035988176 -> 140193035988080 + 140193035988176 [label=ViewBackward0] + 140193035988272 -> 140193035988176 + 140193035988272 [label=AddmmBackward0] + 140193035988368 -> 140193035988272 + 140193035988368 [label=ToCopyBackward0] + 140193035988560 -> 140193035988368 + 140193039558112 [label="encoder.layer.1.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140193039558112 -> 140193035988560 + 140193035988560 [label=AccumulateGrad] + 140193035988320 -> 140193035988272 + 140193035988320 [label=ViewBackward0] + 140193035988608 -> 140193035988320 + 140193035988608 [label=ToCopyBackward0] + 140193035987456 -> 140193035988608 + 140193035987456 [label=SliceBackward0] + 140193035988752 -> 140193035987456 + 140193035988752 [label=SliceBackward0] + 140193035988848 -> 140193035988752 + 140193035988848 [label=SliceBackward0] + 140193035988944 -> 140193035988848 + 140193035988944 [label=SliceBackward0] + 140193035988464 -> 140193035988944 + 140193035988464 [label=SliceBackward0] + 140193036001488 -> 140193035988464 + 140193036001488 [label=NativeLayerNormBackward0] + 140193036001584 -> 140193036001488 + 140193036001584 [label=AddBackward0] + 140193036001776 -> 140193036001584 + 140193036001776 [label=NativeDropoutBackward0] + 140193036001920 -> 140193036001776 + 140193036001920 [label=ViewBackward0] + 140193036002016 -> 140193036001920 + 140193036002016 [label=AddmmBackward0] + 140193036002112 -> 140193036002016 + 140193036002112 [label=ToCopyBackward0] + 140193036002304 -> 140193036002112 + 140193039560032 [label="encoder.layer.1.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039560032 -> 140193036002304 + 140193036002304 [label=AccumulateGrad] + 140193036002064 -> 140193036002016 + 140193036002064 [label=ViewBackward0] + 140193036002352 -> 140193036002064 + 140193036002352 [label=ViewBackward0] + 140193036002448 -> 140193036002352 + 140193036002448 [label=CloneBackward0] + 140193036002544 -> 140193036002448 + 140193036002544 [label=PermuteBackward0] + 140193036002640 -> 140193036002544 + 140193036002640 [label=UnsafeViewBackward0] + 140193036002736 -> 140193036002640 + 140193036002736 [label=BmmBackward0] + 140193036002832 -> 140193036002736 + 140193036002832 [label=ReshapeAliasBackward0] + 140193036002976 -> 140193036002832 + 140193036002976 [label=ExpandBackward0] + 140193036003072 -> 140193036002976 + 140193036003072 [label=ToCopyBackward0] + 140193036003168 -> 140193036003072 + 140193036003168 [label=NativeDropoutBackward0] + 140193036003264 -> 140193036003168 + 140193036003264 [label=SoftmaxBackward0] + 140193036003360 -> 140193036003264 + 140193036003360 [label=AddBackward0] + 140193036003456 -> 140193036003360 + 140193036003456 [label=DivBackward0] + 140193036003552 -> 140193036003456 + 140193036003552 [label=UnsafeViewBackward0] + 140193036003648 -> 140193036003552 + 140193036003648 [label=BmmBackward0] + 140193036003744 -> 140193036003648 + 140193036003744 [label=UnsafeViewBackward0] + 140193036003888 -> 140193036003744 + 140193036003888 [label=CloneBackward0] + 140193036003984 -> 140193036003888 + 140193036003984 [label=ExpandBackward0] + 140193036004080 -> 140193036003984 + 140193036004080 [label=PermuteBackward0] + 140193036004176 -> 140193036004080 + 140193036004176 [label=ViewBackward0] + 140193036004272 -> 140193036004176 + 140193036004272 [label=ViewBackward0] + 140193036004368 -> 140193036004272 + 140193036004368 [label=AddmmBackward0] + 140193036004464 -> 140193036004368 + 140193036004464 [label=ToCopyBackward0] + 140193036004656 -> 140193036004464 + 140193039577232 [label="encoder.layer.1.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039577232 -> 140193036004656 + 140193036004656 [label=AccumulateGrad] + 140193036004416 -> 140193036004368 + 140193036004416 [label=ViewBackward0] + 140193036004704 -> 140193036004416 + 140193036004704 [label=ToCopyBackward0] + 140193036001728 -> 140193036004704 + 140193036001728 [label=CatBackward0] + 140193036004848 -> 140193036001728 + 140193036004848 [label=NativeLayerNormBackward0] + 140193036004992 -> 140193036004848 + 140193036004992 [label=AddBackward0] + 140193036005184 -> 140193036004992 + 140193036005184 [label=NativeDropoutBackward0] + 140193036005328 -> 140193036005184 + 140193036005328 [label=ViewBackward0] + 140193036005232 -> 140193036005328 + 140193036005232 [label=AddmmBackward0] + 140193036017872 -> 140193036005232 + 140193036017872 [label=ToCopyBackward0] + 140193036018064 -> 140193036017872 + 140193039577712 [label="encoder.layer.0.experts.dense2.bias + (768)" fillcolor=lightblue] + 140193039577712 -> 140193036018064 + 140193036018064 [label=AccumulateGrad] + 140193036017824 -> 140193036005232 + 140193036017824 [label=ViewBackward0] + 140193036018112 -> 140193036017824 + 140193036018112 [label=GeluBackward0] + 140193036018208 -> 140193036018112 + 140193036018208 [label=ViewBackward0] + 140193036018304 -> 140193036018208 + 140193036018304 [label=AddmmBackward0] + 140193036018400 -> 140193036018304 + 140193036018400 [label=ToCopyBackward0] + 140193036018592 -> 140193036018400 + 140193039577952 [label="encoder.layer.0.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140193039577952 -> 140193036018592 + 140193036018592 [label=AccumulateGrad] + 140193036018352 -> 140193036018304 + 140193036018352 [label=ViewBackward0] + 140193036018640 -> 140193036018352 + 140193036018640 [label=ToCopyBackward0] + 140193036005136 -> 140193036018640 + 140193036005136 [label=SliceBackward0] + 140193036018784 -> 140193036005136 + 140193036018784 [label=SliceBackward0] + 140193036018880 -> 140193036018784 + 140193036018880 [label=NativeLayerNormBackward0] + 140193036018976 -> 140193036018880 + 140193036018976 [label=AddBackward0] + 140193036019168 -> 140193036018976 + 140193036019168 [label=NativeDropoutBackward0] + 140193036019312 -> 140193036019168 + 140193036019312 [label=ViewBackward0] + 140193036019408 -> 140193036019312 + 140193036019408 [label=AddmmBackward0] + 140193036019504 -> 140193036019408 + 140193036019504 [label=ToCopyBackward0] + 140193036019696 -> 140193036019504 + 140193039579952 [label="encoder.layer.0.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039579952 -> 140193036019696 + 140193036019696 [label=AccumulateGrad] + 140193036019456 -> 140193036019408 + 140193036019456 [label=ViewBackward0] + 140193036019744 -> 140193036019456 + 140193036019744 [label=ViewBackward0] + 140193036019840 -> 140193036019744 + 140193036019840 [label=CloneBackward0] + 140193036019936 -> 140193036019840 + 140193036019936 [label=PermuteBackward0] + 140193036020032 -> 140193036019936 + 140193036020032 [label=UnsafeViewBackward0] + 140193036020128 -> 140193036020032 + 140193036020128 [label=BmmBackward0] + 140193036020224 -> 140193036020128 + 140193036020224 [label=ReshapeAliasBackward0] + 140193036020368 -> 140193036020224 + 140193036020368 [label=ExpandBackward0] + 140193036020464 -> 140193036020368 + 140193036020464 [label=ToCopyBackward0] + 140193036020560 -> 140193036020464 + 140193036020560 [label=NativeDropoutBackward0] + 140193036020656 -> 140193036020560 + 140193036020656 [label=SoftmaxBackward0] + 140193036020752 -> 140193036020656 + 140193036020752 [label=AddBackward0] + 140193036020848 -> 140193036020752 + 140193036020848 [label=DivBackward0] + 140193036020944 -> 140193036020848 + 140193036020944 [label=UnsafeViewBackward0] + 140193036021040 -> 140193036020944 + 140193036021040 [label=BmmBackward0] + 140193036021136 -> 140193036021040 + 140193036021136 [label=UnsafeViewBackward0] + 140193036021280 -> 140193036021136 + 140193036021280 [label=CloneBackward0] + 140193036021376 -> 140193036021280 + 140193036021376 [label=ExpandBackward0] + 140193036021472 -> 140193036021376 + 140193036021472 [label=PermuteBackward0] + 140193036021568 -> 140193036021472 + 140193036021568 [label=ViewBackward0] + 140193036021664 -> 140193036021568 + 140193036021664 [label=ViewBackward0] + 140193036021712 -> 140193036021664 + 140193036021712 [label=AddmmBackward0] + 140193036034208 -> 140193036021712 + 140193036034208 [label=ToCopyBackward0] + 140193036034400 -> 140193036034208 + 140193039580672 [label="encoder.layer.0.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140193039580672 -> 140193036034400 + 140193036034400 [label=AccumulateGrad] + 140193036034160 -> 140193036021712 + 140193036034160 [label=ViewBackward0] + 140193036034448 -> 140193036034160 + 140193036034448 [label=ToCopyBackward0] + 140193036019120 -> 140193036034448 + 140193036019120 [label=SliceBackward0] + 140193036034592 -> 140193036019120 + 140193036034592 [label=SliceBackward0] + 140193036034688 -> 140193036034592 + 140193036034688 [label=SliceBackward0] + 140193036034784 -> 140193036034688 + 140193036034784 [label=NativeLayerNormBackward0] + 140193036034880 -> 140193036034784 + 140193036034880 [label=AddBackward0] + 140193036035072 -> 140193036034880 + 140193036035072 [label=NativeDropoutBackward0] + 140193036035216 -> 140193036035072 + 140193036035216 [label=ViewBackward0] + 140193036035312 -> 140193036035216 + 140193036035312 [label=AddmmBackward0] + 140193036035408 -> 140193036035312 + 140193036035408 [label=ToCopyBackward0] + 140193036035600 -> 140193036035408 + 140193039581072 [label="encoder.layer.0.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140193039581072 -> 140193036035600 + 140193036035600 [label=AccumulateGrad] + 140193036035360 -> 140193036035312 + 140193036035360 [label=ViewBackward0] + 140193036035648 -> 140193036035360 + 140193036035648 [label=ViewBackward0] + 140193036035744 -> 140193036035648 + 140193036035744 [label=CloneBackward0] + 140193036035840 -> 140193036035744 + 140193036035840 [label=PermuteBackward0] + 140193036035936 -> 140193036035840 + 140193036035936 [label=UnsafeViewBackward0] + 140193036036032 -> 140193036035936 + 140193036036032 [label=BmmBackward0] + 140193036036128 -> 140193036036032 + 140193036036128 [label=ReshapeAliasBackward0] + 140193036036272 -> 140193036036128 + 140193036036272 [label=ExpandBackward0] + 140193036036368 -> 140193036036272 + 140193036036368 [label=ToCopyBackward0] + 140193036036464 -> 140193036036368 + 140193036036464 [label=NativeDropoutBackward0] + 140193036036560 -> 140193036036464 + 140193036036560 [label=SoftmaxBackward0] + 140193036036656 -> 140193036036560 + 140193036036656 [label=AddBackward0] + 140193036036752 -> 140193036036656 + 140193036036752 [label=DivBackward0] + 140193036036848 -> 140193036036752 + 140193036036848 [label=UnsafeViewBackward0] + 140193036036944 -> 140193036036848 + 140193036036944 [label=BmmBackward0] + 140193036037040 -> 140193036036944 + 140193036037040 [label=UnsafeViewBackward0] + 140193036037184 -> 140193036037040 + 140193036037184 [label=CloneBackward0] + 140193036037280 -> 140193036037184 + 140193036037280 [label=ExpandBackward0] + 140193036037376 -> 140193036037280 + 140193036037376 [label=PermuteBackward0] + 140193036037472 -> 140193036037376 + 140193036037472 [label=ViewBackward0] + 140193036037568 -> 140193036037472 + 140193036037568 [label=ViewBackward0] + 140193036037664 -> 140193036037568 + 140193036037664 [label=AddmmBackward0] + 140193036037760 -> 140193036037664 + 140193036037760 [label=ToCopyBackward0] + 140193036037952 -> 140193036037760 + 140193039248416 [label="encoder.layer.0.attention.self.query.bias + (768)" fillcolor=lightblue] + 140193039248416 -> 140193036037952 + 140193036037952 [label=AccumulateGrad] + 140193036037712 -> 140193036037664 + 140193036037712 [label=ViewBackward0] + 140193036038000 -> 140193036037712 + 140193036038000 [label=ToCopyBackward0] + 140193036035024 -> 140193036038000 + 140193036035024 [label=NativeDropoutBackward0] + 140193036037904 -> 140193036035024 + 140193036037904 [label=NativeLayerNormBackward0] + 140193036046400 -> 140193036037904 + 140193036046400 [label=CatBackward0] + 140193036046688 -> 140193036046400 + 140193036046688 [label=ExpandBackward0] + 140193036046832 -> 140193036046688 + 140194225446800 [label=" + (1, 32, 768)" fillcolor=lightblue] + 140194225446800 -> 140193036046832 + 140193036046832 [label=AccumulateGrad] + 140193036046640 -> 140193036046400 + 140193036046640 [label=AddBackward0] + 140193036046880 -> 140193036046640 + 140193036046880 [label=EmbeddingBackward0] + 140193036047024 -> 140193036046880 + 140193039591120 [label="embeddings.word_embeddings.weight + (30523, 768)" fillcolor=lightblue] + 140193039591120 -> 140193036047024 + 140193036047024 [label=AccumulateGrad] + 140193036046928 -> 140193036046640 + 140193036046928 [label=EmbeddingBackward0] + 140193036047072 -> 140193036046928 + 140194041968896 [label="embeddings.position_embeddings.weight + (512, 768)" fillcolor=lightblue] + 140194041968896 -> 140193036047072 + 140193036047072 [label=AccumulateGrad] + 140193036046448 -> 140193036037904 + 140193039245376 [label="embeddings.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039245376 -> 140193036046448 + 140193036046448 [label=AccumulateGrad] + 140193036046496 -> 140193036037904 + 140193039605824 [label="embeddings.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039605824 -> 140193036046496 + 140193036046496 [label=AccumulateGrad] + 140193036037088 -> 140193036037664 + 140193036037088 [label=TBackward0] + 140193036037856 -> 140193036037088 + 140193036037856 [label=ToCopyBackward0] + 140193036038048 -> 140193036037856 + 140193039247696 [label="encoder.layer.0.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039247696 -> 140193036038048 + 140193036038048 [label=AccumulateGrad] + 140193036036992 -> 140193036036944 + 140193036036992 [label=UnsafeViewBackward0] + 140193036037328 -> 140193036036992 + 140193036037328 [label=CloneBackward0] + 140193036037520 -> 140193036037328 + 140193036037520 [label=ExpandBackward0] + 140193036037808 -> 140193036037520 + 140193036037808 [label=TransposeBackward0] + 140193036038096 -> 140193036037808 + 140193036038096 [label=PermuteBackward0] + 140193036046784 -> 140193036038096 + 140193036046784 [label=ViewBackward0] + 140193036047168 -> 140193036046784 + 140193036047168 [label=ViewBackward0] + 140193036046976 -> 140193036047168 + 140193036046976 [label=AddmmBackward0] + 140193036047264 -> 140193036046976 + 140193036047264 [label=ToCopyBackward0] + 140193036047456 -> 140193036047264 + 140193039589920 [label="encoder.layer.0.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039589920 -> 140193036047456 + 140193036047456 [label=AccumulateGrad] + 140193036047216 -> 140193036046976 + 140193036047216 [label=ViewBackward0] + 140193036047504 -> 140193036047216 + 140193036047504 [label=ToCopyBackward0] + 140193036035024 -> 140193036047504 + 140193036046592 -> 140193036046976 + 140193036046592 [label=TBackward0] + 140193036047360 -> 140193036046592 + 140193036047360 [label=ToCopyBackward0] + 140193036047648 -> 140193036047360 + 140193039589840 [label="encoder.layer.0.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039589840 -> 140193036047648 + 140193036047648 [label=AccumulateGrad] + 140193036036080 -> 140193036036032 + 140193036036080 [label=UnsafeViewBackward0] + 140193036036416 -> 140193036036080 + 140193036036416 [label=CloneBackward0] + 140193036036608 -> 140193036036416 + 140193036036608 [label=ExpandBackward0] + 140193036036800 -> 140193036036608 + 140193036036800 [label=PermuteBackward0] + 140193036036176 -> 140193036036800 + 140193036036176 [label=ViewBackward0] + 140193036037424 -> 140193036036176 + 140193036037424 [label=ViewBackward0] + 140193036037136 -> 140193036037424 + 140193036037136 [label=AddmmBackward0] + 140193036036224 -> 140193036037136 + 140193036036224 [label=ToCopyBackward0] + 140193036047408 -> 140193036036224 + 140193039589680 [label="encoder.layer.0.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039589680 -> 140193036047408 + 140193036047408 [label=AccumulateGrad] + 140193036046544 -> 140193036037136 + 140193036046544 [label=ViewBackward0] + 140193036047744 -> 140193036046544 + 140193036047744 [label=ToCopyBackward0] + 140193036035024 -> 140193036047744 + 140193036047120 -> 140193036037136 + 140193036047120 [label=TBackward0] + 140193036047312 -> 140193036047120 + 140193036047312 [label=ToCopyBackward0] + 140193036047792 -> 140193036047312 + 140193039589600 [label="encoder.layer.0.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039589600 -> 140193036047792 + 140193036047792 [label=AccumulateGrad] + 140193036035120 -> 140193036035312 + 140193036035120 [label=TBackward0] + 140193036035792 -> 140193036035120 + 140193036035792 [label=ToCopyBackward0] + 140193036035984 -> 140193036035792 + 140193039589440 [label="encoder.layer.0.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039589440 -> 140193036035984 + 140193036035984 [label=AccumulateGrad] + 140193036035024 -> 140193036034880 + 140193036034832 -> 140193036034784 + 140193039580832 [label="encoder.layer.0.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039580832 -> 140193036034832 + 140193036034832 [label=AccumulateGrad] + 140193036034304 -> 140193036034784 + 140193039580912 [label="encoder.layer.0.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039580912 -> 140193036034304 + 140193036034304 [label=AccumulateGrad] + 140193036034112 -> 140193036021712 + 140193036034112 [label=TBackward0] + 140193036034352 -> 140193036034112 + 140193036034352 [label=ToCopyBackward0] + 140193036034736 -> 140193036034352 + 140193039580592 [label="encoder.layer.0.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039580592 -> 140193036034736 + 140193036034736 [label=AccumulateGrad] + 140193036021088 -> 140193036021040 + 140193036021088 [label=UnsafeViewBackward0] + 140193036021424 -> 140193036021088 + 140193036021424 [label=CloneBackward0] + 140193036021616 -> 140193036021424 + 140193036021616 [label=ExpandBackward0] + 140193036021184 -> 140193036021616 + 140193036021184 [label=TransposeBackward0] + 140193036034640 -> 140193036021184 + 140193036034640 [label=PermuteBackward0] + 140193036034928 -> 140193036034640 + 140193036034928 [label=ViewBackward0] + 140193036035168 -> 140193036034928 + 140193036035168 [label=ViewBackward0] + 140193036035456 -> 140193036035168 + 140193036035456 [label=AddmmBackward0] + 140193036035888 -> 140193036035456 + 140193036035888 [label=ToCopyBackward0] + 140193036036512 -> 140193036035888 + 140193039580432 [label="encoder.layer.0.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140193039580432 -> 140193036036512 + 140193036036512 [label=AccumulateGrad] + 140193036035504 -> 140193036035456 + 140193036035504 [label=ViewBackward0] + 140193036036704 -> 140193036035504 + 140193036036704 [label=ToCopyBackward0] + 140193036037232 -> 140193036036704 + 140193036037232 [label=NativeLayerNormBackward0] + 140193036037616 -> 140193036037232 + 140193039246816 [label=" + (1408)" fillcolor=lightblue] + 140193039246816 -> 140193036037616 + 140193036037616 [label=AccumulateGrad] + 140193036035696 -> 140193036037232 + 140193039247056 [label=" + (1408)" fillcolor=lightblue] + 140193039247056 -> 140193036035696 + 140193036035696 [label=AccumulateGrad] + 140193036034256 -> 140193036035456 + 140193036034256 [label=TBackward0] + 140193036035552 -> 140193036034256 + 140193036035552 [label=ToCopyBackward0] + 140193036047600 -> 140193036035552 + 140193039580352 [label="encoder.layer.0.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140193039580352 -> 140193036047600 + 140193036047600 [label=AccumulateGrad] + 140193036020176 -> 140193036020128 + 140193036020176 [label=UnsafeViewBackward0] + 140193036020512 -> 140193036020176 + 140193036020512 [label=CloneBackward0] + 140193036020704 -> 140193036020512 + 140193036020704 [label=ExpandBackward0] + 140193036020896 -> 140193036020704 + 140193036020896 [label=PermuteBackward0] + 140193036020272 -> 140193036020896 + 140193036020272 [label=ViewBackward0] + 140193036021520 -> 140193036020272 + 140193036021520 [label=ViewBackward0] + 140193036021232 -> 140193036021520 + 140193036021232 [label=AddmmBackward0] + 140193036034496 -> 140193036021232 + 140193036034496 [label=ToCopyBackward0] + 140193036036896 -> 140193036034496 + 140193039580192 [label="encoder.layer.0.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140193039580192 -> 140193036036896 + 140193036036896 [label=AccumulateGrad] + 140193036034976 -> 140193036021232 + 140193036034976 [label=ViewBackward0] + 140193036036320 -> 140193036034976 + 140193036036320 [label=ToCopyBackward0] + 140193036037232 -> 140193036036320 + 140193036034544 -> 140193036021232 + 140193036034544 [label=TBackward0] + 140193036047552 -> 140193036034544 + 140193036047552 [label=ToCopyBackward0] + 140193036047696 -> 140193036047552 + 140193039580112 [label="encoder.layer.0.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140193039580112 -> 140193036047696 + 140193036047696 [label=AccumulateGrad] + 140193036019216 -> 140193036019408 + 140193036019216 [label=TBackward0] + 140193036019888 -> 140193036019216 + 140193036019888 [label=ToCopyBackward0] + 140193036020080 -> 140193036019888 + 140193039579872 [label="encoder.layer.0.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039579872 -> 140193036020080 + 140193036020080 [label=AccumulateGrad] + 140193036019120 -> 140193036018976 + 140193036018928 -> 140193036018880 + 140193039579632 [label="encoder.layer.0.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039579632 -> 140193036018928 + 140193036018928 [label=AccumulateGrad] + 140193036018496 -> 140193036018880 + 140193039579712 [label="encoder.layer.0.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039579712 -> 140193036018496 + 140193036018496 [label=AccumulateGrad] + 140193036018016 -> 140193036018304 + 140193036018016 [label=TBackward0] + 140193036018544 -> 140193036018016 + 140193036018544 [label=ToCopyBackward0] + 140193036019024 -> 140193036018544 + 140194226510320 [label="encoder.layer.0.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140194226510320 -> 140193036019024 + 140193036019024 [label=AccumulateGrad] + 140193036017728 -> 140193036005232 + 140193036017728 [label=TBackward0] + 140193036018256 -> 140193036017728 + 140193036018256 [label=ToCopyBackward0] + 140193036018736 -> 140193036018256 + 140193039578032 [label="encoder.layer.0.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039578032 -> 140193036018736 + 140193036018736 [label=AccumulateGrad] + 140193036005136 -> 140193036004992 + 140193036004944 -> 140193036004848 + 140193039577792 [label="encoder.layer.0.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039577792 -> 140193036004944 + 140193036004944 [label=AccumulateGrad] + 140193036004896 -> 140193036004848 + 140193039577472 [label="encoder.layer.0.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039577472 -> 140193036004896 + 140193036004896 [label=AccumulateGrad] + 140193036004608 -> 140193036001728 + 140193036004608 [label=NativeLayerNormBackward0] + 140193036005280 -> 140193036004608 + 140193036005280 [label=AddBackward0] + 140193036018448 -> 140193036005280 + 140193036018448 [label=NativeDropoutBackward0] + 140193036018160 -> 140193036018448 + 140193036018160 [label=ViewBackward0] + 140193036018688 -> 140193036018160 + 140193036018688 [label=AddmmBackward0] + 140193036019552 -> 140193036018688 + 140193036019552 [label=ToCopyBackward0] + 140193036019648 -> 140193036019552 + 140193039579232 [label="encoder.layer.0.output.dense.bias + (768)" fillcolor=lightblue] + 140193039579232 -> 140193036019648 + 140193036019648 [label=AccumulateGrad] + 140193036019360 -> 140193036018688 + 140193036019360 [label=ViewBackward0] + 140193036019792 -> 140193036019360 + 140193036019792 [label=GeluBackward0] + 140193036020800 -> 140193036019792 + 140193036020800 [label=ViewBackward0] + 140193036021328 -> 140193036020800 + 140193036021328 [label=AddmmBackward0] + 140193036035264 -> 140193036021328 + 140193036035264 [label=ToCopyBackward0] + 140193036047936 -> 140193036035264 + 140193039579472 [label="encoder.layer.0.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039579472 -> 140193036047936 + 140193036047936 [label=AccumulateGrad] + 140193036020320 -> 140193036021328 + 140193036020320 [label=ViewBackward0] + 140193036047840 -> 140193036020320 + 140193036047840 [label=ToCopyBackward0] + 140193036017968 -> 140193036047840 + 140193036017968 [label=SliceBackward0] + 140193036048128 -> 140193036017968 + 140193036048128 [label=SliceBackward0] + 140193036048224 -> 140193036048128 + 140193036048224 [label=SliceBackward0] + 140193036034784 -> 140193036048224 + 140193036020416 -> 140193036021328 + 140193036020416 [label=TBackward0] + 140193036047984 -> 140193036020416 + 140193036047984 [label=ToCopyBackward0] + 140193036048320 -> 140193036047984 + 140193039579392 [label="encoder.layer.0.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039579392 -> 140193036048320 + 140193036048320 [label=AccumulateGrad] + 140193036019264 -> 140193036018688 + 140193036019264 [label=TBackward0] + 140193036020992 -> 140193036019264 + 140193036020992 [label=ToCopyBackward0] + 140193036020608 -> 140193036020992 + 140193039579152 [label="encoder.layer.0.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039579152 -> 140193036020608 + 140193036020608 [label=AccumulateGrad] + 140193036017968 -> 140193036005280 + 140193036005088 -> 140193036004608 + 140193039578912 [label="encoder.layer.0.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039578912 -> 140193036005088 + 140193036005088 [label=AccumulateGrad] + 140193036005040 -> 140193036004608 + 140193039578992 [label="encoder.layer.0.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039578992 -> 140193036005040 + 140193036005040 [label=AccumulateGrad] + 140193036003792 -> 140193036004368 + 140193036003792 [label=TBackward0] + 140193036004560 -> 140193036003792 + 140193036004560 [label=ToCopyBackward0] + 140193036004752 -> 140193036004560 + 140193039577552 [label="encoder.layer.1.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039577552 -> 140193036004752 + 140193036004752 [label=AccumulateGrad] + 140193036003696 -> 140193036003648 + 140193036003696 [label=UnsafeViewBackward0] + 140193036004032 -> 140193036003696 + 140193036004032 [label=CloneBackward0] + 140193036004224 -> 140193036004032 + 140193036004224 [label=ExpandBackward0] + 140193036004512 -> 140193036004224 + 140193036004512 [label=TransposeBackward0] + 140193036004800 -> 140193036004512 + 140193036004800 [label=PermuteBackward0] + 140193036018832 -> 140193036004800 + 140193036018832 [label=ViewBackward0] + 140193036019600 -> 140193036018832 + 140193036019600 [label=ViewBackward0] + 140193036019984 -> 140193036019600 + 140193036019984 [label=AddmmBackward0] + 140193036048176 -> 140193036019984 + 140193036048176 [label=ToCopyBackward0] + 140193036048368 -> 140193036048176 + 140193039560512 [label="encoder.layer.1.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039560512 -> 140193036048368 + 140193036048368 [label=AccumulateGrad] + 140193036047888 -> 140193036019984 + 140193036047888 [label=ViewBackward0] + 140193036048416 -> 140193036047888 + 140193036048416 [label=ToCopyBackward0] + 140193036001728 -> 140193036048416 + 140193036048080 -> 140193036019984 + 140193036048080 [label=TBackward0] + 140193036048272 -> 140193036048080 + 140193036048272 [label=ToCopyBackward0] + 140193036048560 -> 140193036048272 + 140193039577312 [label="encoder.layer.1.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039577312 -> 140193036048560 + 140193036048560 [label=AccumulateGrad] + 140193036002784 -> 140193036002736 + 140193036002784 [label=UnsafeViewBackward0] + 140193036003120 -> 140193036002784 + 140193036003120 [label=CloneBackward0] + 140193036003312 -> 140193036003120 + 140193036003312 [label=ExpandBackward0] + 140193036003504 -> 140193036003312 + 140193036003504 [label=PermuteBackward0] + 140193036002880 -> 140193036003504 + 140193036002880 [label=ViewBackward0] + 140193036004128 -> 140193036002880 + 140193036004128 [label=ViewBackward0] + 140193036003840 -> 140193036004128 + 140193036003840 [label=AddmmBackward0] + 140193036002928 -> 140193036003840 + 140193036002928 [label=ToCopyBackward0] + 140193036048032 -> 140193036002928 + 140193039560272 [label="encoder.layer.1.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039560272 -> 140193036048032 + 140193036048032 [label=AccumulateGrad] + 140193036017776 -> 140193036003840 + 140193036017776 [label=ViewBackward0] + 140193036048656 -> 140193036017776 + 140193036048656 [label=ToCopyBackward0] + 140193036001728 -> 140193036048656 + 140193036019072 -> 140193036003840 + 140193036019072 [label=TBackward0] + 140193036046736 -> 140193036019072 + 140193036046736 [label=ToCopyBackward0] + 140193036048704 -> 140193036046736 + 140193039560592 [label="encoder.layer.1.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039560592 -> 140193036048704 + 140193036048704 [label=AccumulateGrad] + 140193036001824 -> 140193036002016 + 140193036001824 [label=TBackward0] + 140193036002496 -> 140193036001824 + 140193036002496 [label=ToCopyBackward0] + 140193036002688 -> 140193036002496 + 140193039560352 [label="encoder.layer.1.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039560352 -> 140193036002688 + 140193036002688 [label=AccumulateGrad] + 140193036001728 -> 140193036001584 + 140193036001536 -> 140193036001488 + 140193039560112 [label="encoder.layer.1.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039560112 -> 140193036001536 + 140193036001536 [label=AccumulateGrad] + 140193036001344 -> 140193036001488 + 140193039559792 [label="encoder.layer.1.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039559792 -> 140193036001344 + 140193036001344 [label=AccumulateGrad] + 140193035987984 -> 140193035988272 + 140193035987984 [label=TBackward0] + 140193035988512 -> 140193035987984 + 140193035988512 [label=ToCopyBackward0] + 140193035988896 -> 140193035988512 + 140193039558432 [label="encoder.layer.1.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039558432 -> 140193035988896 + 140193035988896 [label=AccumulateGrad] + 140193035987552 -> 140193035987744 + 140193035987552 [label=TBackward0] + 140193035988224 -> 140193035987552 + 140193035988224 [label=ToCopyBackward0] + 140193035988704 -> 140193035988224 + 140193039558192 [label="encoder.layer.1.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039558192 -> 140193035988704 + 140193035988704 [label=AccumulateGrad] + 140193035987456 -> 140193035987312 + 140193035987264 -> 140193035987168 + 140193039557952 [label="encoder.layer.1.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039557952 -> 140193035987264 + 140193035987264 [label=AccumulateGrad] + 140193035987216 -> 140193035987168 + 140193039557632 [label="encoder.layer.1.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039557632 -> 140193035987216 + 140193035987216 [label=AccumulateGrad] + 140193035986928 -> 140193035963504 + 140193035986928 [label=NativeLayerNormBackward0] + 140193035987600 -> 140193035986928 + 140193035987600 [label=AddBackward0] + 140193035988416 -> 140193035987600 + 140193035988416 [label=NativeDropoutBackward0] + 140193035988128 -> 140193035988416 + 140193035988128 [label=ViewBackward0] + 140193035988656 -> 140193035988128 + 140193035988656 [label=AddmmBackward0] + 140193036001680 -> 140193035988656 + 140193036001680 [label=ToCopyBackward0] + 140193036002208 -> 140193036001680 + 140193039559312 [label="encoder.layer.1.output.dense.bias + (768)" fillcolor=lightblue] + 140193039559312 -> 140193036002208 + 140193036002208 [label=AccumulateGrad] + 140193036001632 -> 140193035988656 + 140193036001632 [label=ViewBackward0] + 140193036002592 -> 140193036001632 + 140193036002592 [label=GeluBackward0] + 140193036002256 -> 140193036002592 + 140193036002256 [label=ViewBackward0] + 140193036003216 -> 140193036002256 + 140193036003216 [label=AddmmBackward0] + 140193036003600 -> 140193036003216 + 140193036003600 [label=ToCopyBackward0] + 140193036017920 -> 140193036003600 + 140193039559552 [label="encoder.layer.1.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039559552 -> 140193036017920 + 140193036017920 [label=AccumulateGrad] + 140193036003408 -> 140193036003216 + 140193036003408 [label=ViewBackward0] + 140193036004320 -> 140193036003408 + 140193036004320 [label=ToCopyBackward0] + 140193035987936 -> 140193036004320 + 140193035987936 [label=SliceBackward0] + 140193036048608 -> 140193035987936 + 140193036048608 [label=SliceBackward0] + 140193036048896 -> 140193036048608 + 140193036048896 [label=SliceBackward0] + 140193036001488 -> 140193036048896 + 140193036002160 -> 140193036003216 + 140193036002160 [label=TBackward0] + 140193036048800 -> 140193036002160 + 140193036048800 [label=ToCopyBackward0] + 140193036048992 -> 140193036048800 + 140193039559872 [label="encoder.layer.1.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039559872 -> 140193036048992 + 140193036048992 [label=AccumulateGrad] + 140193036001440 -> 140193035988656 + 140193036001440 [label=TBackward0] + 140193036002400 -> 140193036001440 + 140193036002400 [label=ToCopyBackward0] + 140193036003936 -> 140193036002400 + 140193039559632 [label="encoder.layer.1.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039559632 -> 140193036003936 + 140193036003936 [label=AccumulateGrad] + 140193035987936 -> 140193035987600 + 140193035987408 -> 140193035986928 + 140193039559392 [label="encoder.layer.1.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039559392 -> 140193035987408 + 140193035987408 [label=AccumulateGrad] + 140193035987360 -> 140193035986928 + 140193039559072 [label="encoder.layer.1.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039559072 -> 140193035987360 + 140193035987360 [label=AccumulateGrad] + 140193035986112 -> 140193035986688 + 140193035986112 [label=TBackward0] + 140193035986880 -> 140193035986112 + 140193035986880 [label=ToCopyBackward0] + 140193035987888 -> 140193035986880 + 140193039557712 [label="encoder.layer.2.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039557712 -> 140193035987888 + 140193035987888 [label=AccumulateGrad] + 140193035986016 -> 140193035985968 + 140193035986016 [label=UnsafeViewBackward0] + 140193035986352 -> 140193035986016 + 140193035986352 [label=CloneBackward0] + 140193035986544 -> 140193035986352 + 140193035986544 [label=ExpandBackward0] + 140193035986832 -> 140193035986544 + 140193035986832 [label=TransposeBackward0] + 140193035987696 -> 140193035986832 + 140193035987696 [label=PermuteBackward0] + 140193035987072 -> 140193035987696 + 140193035987072 [label=ViewBackward0] + 140193035986160 -> 140193035987072 + 140193035986160 [label=ViewBackward0] + 140193036003024 -> 140193035986160 + 140193036003024 [label=AddmmBackward0] + 140193036001392 -> 140193036003024 + 140193036001392 [label=ToCopyBackward0] + 140193036049040 -> 140193036001392 + 140193039557152 [label="encoder.layer.2.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039557152 -> 140193036049040 + 140193036049040 [label=AccumulateGrad] + 140193036048848 -> 140193036003024 + 140193036048848 [label=ViewBackward0] + 140193036049088 -> 140193036048848 + 140193036049088 [label=ToCopyBackward0] + 140193035963504 -> 140193036049088 + 140193036048464 -> 140193036003024 + 140193036048464 [label=TBackward0] + 140193036048944 -> 140193036048464 + 140193036048944 [label=ToCopyBackward0] + 140193036049232 -> 140193036048944 + 140193039557472 [label="encoder.layer.2.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039557472 -> 140193036049232 + 140193036049232 [label=AccumulateGrad] + 140193035985104 -> 140193035985056 + 140193035985104 [label=UnsafeViewBackward0] + 140193035985440 -> 140193035985104 + 140193035985440 [label=CloneBackward0] + 140193035985632 -> 140193035985440 + 140193035985632 [label=ExpandBackward0] + 140193035985824 -> 140193035985632 + 140193035985824 [label=PermuteBackward0] + 140193035985200 -> 140193035985824 + 140193035985200 [label=ViewBackward0] + 140193035986448 -> 140193035985200 + 140193035986448 [label=ViewBackward0] + 140193035987120 -> 140193035986448 + 140193035987120 [label=AddmmBackward0] + 140193035988800 -> 140193035987120 + 140193035988800 [label=ToCopyBackward0] + 140193036048512 -> 140193035988800 + 140193039556912 [label="encoder.layer.2.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039556912 -> 140193036048512 + 140193036048512 [label=AccumulateGrad] + 140193035985248 -> 140193035987120 + 140193035985248 [label=ViewBackward0] + 140193036049328 -> 140193035985248 + 140193036049328 [label=ToCopyBackward0] + 140193035963504 -> 140193036049328 + 140193036001872 -> 140193035987120 + 140193036001872 [label=TBackward0] + 140193036048752 -> 140193036001872 + 140193036048752 [label=ToCopyBackward0] + 140193036049376 -> 140193036048752 + 140193039557232 [label="encoder.layer.2.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039557232 -> 140193036049376 + 140193036049376 [label=AccumulateGrad] + 140193035963600 -> 140193035963792 + 140193035963600 [label=TBackward0] + 140193035964272 -> 140193035963600 + 140193035964272 [label=ToCopyBackward0] + 140193035964032 -> 140193035964272 + 140193039556992 [label="encoder.layer.2.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039556992 -> 140193035964032 + 140193035964032 [label=AccumulateGrad] + 140193035963504 -> 140193035963360 + 140193035963312 -> 140193035963264 + 140193039556752 [label="encoder.layer.2.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039556752 -> 140193035963312 + 140193035963312 [label=AccumulateGrad] + 140193035962784 -> 140193035963264 + 140193039552240 [label="encoder.layer.2.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039552240 -> 140193035962784 + 140193035962784 [label=AccumulateGrad] + 140193035962016 -> 140193035962592 + 140193035962016 [label=TBackward0] + 140193035962832 -> 140193035962016 + 140193035962832 [label=ToCopyBackward0] + 140193035963216 -> 140193035962832 + 140193039552320 [label="encoder.layer.2.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039552320 -> 140193035963216 + 140193035963216 [label=AccumulateGrad] + 140193035961920 -> 140193035961872 + 140193035961920 [label=UnsafeViewBackward0] + 140193035962256 -> 140193035961920 + 140193035962256 [label=CloneBackward0] + 140193035962448 -> 140193035962256 + 140193035962448 [label=ExpandBackward0] + 140193035962736 -> 140193035962448 + 140193035962736 [label=TransposeBackward0] + 140193035963120 -> 140193035962736 + 140193035963120 [label=PermuteBackward0] + 140193035963408 -> 140193035963120 + 140193035963408 [label=ViewBackward0] + 140193035963648 -> 140193035963408 + 140193035963648 [label=ViewBackward0] + 140193035963936 -> 140193035963648 + 140193035963936 [label=AddmmBackward0] + 140193035964176 -> 140193035963936 + 140193035964176 [label=ToCopyBackward0] + 140193035985536 -> 140193035964176 + 140193039551760 [label="encoder.layer.2.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140193039551760 -> 140193035985536 + 140193035985536 [label=AccumulateGrad] + 140193035963984 -> 140193035963936 + 140193035963984 [label=ViewBackward0] + 140193035985728 -> 140193035963984 + 140193035985728 [label=ToCopyBackward0] + 140193036037232 -> 140193035985728 + 140193035962064 -> 140193035963936 + 140193035962064 [label=TBackward0] + 140193035984960 -> 140193035962064 + 140193035984960 [label=ToCopyBackward0] + 140193035986640 -> 140193035984960 + 140193039552080 [label="encoder.layer.2.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140193039552080 -> 140193035986640 + 140193035986640 [label=AccumulateGrad] + 140193035961008 -> 140193035960960 + 140193035961008 [label=UnsafeViewBackward0] + 140193035961344 -> 140193035961008 + 140193035961344 [label=CloneBackward0] + 140193035961536 -> 140193035961344 + 140193035961536 [label=ExpandBackward0] + 140193035961728 -> 140193035961536 + 140193035961728 [label=PermuteBackward0] + 140193035961104 -> 140193035961728 + 140193035961104 [label=ViewBackward0] + 140193035962352 -> 140193035961104 + 140193035962352 [label=ViewBackward0] + 140193035963024 -> 140193035962352 + 140193035963024 [label=AddmmBackward0] + 140193035962976 -> 140193035963024 + 140193035962976 [label=ToCopyBackward0] + 140193036001968 -> 140193035962976 + 140193039551520 [label="encoder.layer.2.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140193039551520 -> 140193036001968 + 140193036001968 [label=AccumulateGrad] + 140193035963456 -> 140193035963024 + 140193035963456 [label=ViewBackward0] + 140193035985008 -> 140193035963456 + 140193035985008 [label=ToCopyBackward0] + 140193036037232 -> 140193035985008 + 140193035961152 -> 140193035963024 + 140193035961152 [label=TBackward0] + 140193035986256 -> 140193035961152 + 140193035986256 [label=ToCopyBackward0] + 140193035985920 -> 140193035986256 + 140193039551840 [label="encoder.layer.2.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140193039551840 -> 140193035985920 + 140193035985920 [label=AccumulateGrad] + 140193035955888 -> 140193035956080 + 140193035955888 [label=TBackward0] + 140193035960720 -> 140193035955888 + 140193035960720 [label=ToCopyBackward0] + 140193035960912 -> 140193035960720 + 140193039551600 [label="encoder.layer.2.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039551600 -> 140193035960912 + 140193035960912 [label=AccumulateGrad] + 140193035955792 -> 140193035955648 + 140193035955600 -> 140193035955552 + 140193039551360 [label="encoder.layer.2.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039551360 -> 140193035955600 + 140193035955600 [label=AccumulateGrad] + 140193035955168 -> 140193035955552 + 140193039551040 [label="encoder.layer.2.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039551040 -> 140193035955168 + 140193035955168 [label=AccumulateGrad] + 140193035954688 -> 140193035954976 + 140193035954688 [label=TBackward0] + 140193035955216 -> 140193035954688 + 140193035955216 [label=ToCopyBackward0] + 140193035955696 -> 140193035955216 + 140193039549680 [label="encoder.layer.2.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039549680 -> 140193035955696 + 140193035955696 [label=AccumulateGrad] + 140193035954256 -> 140193035954448 + 140193035954256 [label=TBackward0] + 140193035954928 -> 140193035954256 + 140193035954928 [label=ToCopyBackward0] + 140193035955408 -> 140193035954928 + 140193039549440 [label="encoder.layer.2.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039549440 -> 140193035955408 + 140193035955408 [label=AccumulateGrad] + 140193035954160 -> 140193035954016 + 140193035953968 -> 140193035953872 + 140193039549200 [label="encoder.layer.2.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039549200 -> 140193035953968 + 140193035953968 [label=AccumulateGrad] + 140193035953920 -> 140193035953872 + 140193039548880 [label="encoder.layer.2.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039548880 -> 140193035953920 + 140193035953920 [label=AccumulateGrad] + 140193035953632 -> 140193037261408 + 140193035953632 [label=NativeLayerNormBackward0] + 140193035954304 -> 140193035953632 + 140193035954304 [label=AddBackward0] + 140193035955120 -> 140193035954304 + 140193035955120 [label=NativeDropoutBackward0] + 140193035954832 -> 140193035955120 + 140193035954832 [label=ViewBackward0] + 140193035955360 -> 140193035954832 + 140193035955360 [label=AddmmBackward0] + 140193035956032 -> 140193035955360 + 140193035956032 [label=ToCopyBackward0] + 140193035960480 -> 140193035956032 + 140193039550560 [label="encoder.layer.2.output.dense.bias + (768)" fillcolor=lightblue] + 140193039550560 -> 140193035960480 + 140193035960480 [label=AccumulateGrad] + 140193035955936 -> 140193035955360 + 140193035955936 [label=ViewBackward0] + 140193035960624 -> 140193035955936 + 140193035960624 [label=GeluBackward0] + 140193035961632 -> 140193035960624 + 140193035961632 [label=ViewBackward0] + 140193035962160 -> 140193035961632 + 140193035962160 [label=AddmmBackward0] + 140193035963744 -> 140193035962160 + 140193035963744 [label=ToCopyBackward0] + 140193036049472 -> 140193035963744 + 140193039550800 [label="encoder.layer.2.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039550800 -> 140193036049472 + 140193036049472 [label=AccumulateGrad] + 140193035962544 -> 140193035962160 + 140193035962544 [label=ViewBackward0] + 140193036049136 -> 140193035962544 + 140193036049136 [label=ToCopyBackward0] + 140193035954640 -> 140193036049136 + 140193035954640 [label=SliceBackward0] + 140193036049568 -> 140193035954640 + 140193036049568 [label=SliceBackward0] + 140193036049664 -> 140193036049568 + 140193036049664 [label=SliceBackward0] + 140193035963264 -> 140193036049664 + 140193035961248 -> 140193035962160 + 140193035961248 [label=TBackward0] + 140193036049424 -> 140193035961248 + 140193036049424 [label=ToCopyBackward0] + 140193036049760 -> 140193036049424 + 140193039551120 [label="encoder.layer.2.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039551120 -> 140193036049760 + 140193036049760 [label=AccumulateGrad] + 140193035960384 -> 140193035955360 + 140193035960384 [label=TBackward0] + 140193035961824 -> 140193035960384 + 140193035961824 [label=ToCopyBackward0] + 140193035985344 -> 140193035961824 + 140193039550880 [label="encoder.layer.2.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039550880 -> 140193035985344 + 140193035985344 [label=AccumulateGrad] + 140193035954640 -> 140193035954304 + 140193035954112 -> 140193035953632 + 140193039550640 [label="encoder.layer.2.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039550640 -> 140193035954112 + 140193035954112 [label=AccumulateGrad] + 140193035954064 -> 140193035953632 + 140193039550320 [label="encoder.layer.2.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039550320 -> 140193035954064 + 140193035954064 [label=AccumulateGrad] + 140193035952816 -> 140193035953392 + 140193035952816 [label=TBackward0] + 140193035953584 -> 140193035952816 + 140193035953584 [label=ToCopyBackward0] + 140193035954592 -> 140193035953584 + 140193039548960 [label="encoder.layer.3.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039548960 -> 140193035954592 + 140193035954592 [label=AccumulateGrad] + 140193035952720 -> 140193035952672 + 140193035952720 [label=UnsafeViewBackward0] + 140193035953056 -> 140193035952720 + 140193035953056 [label=CloneBackward0] + 140193035953248 -> 140193035953056 + 140193035953248 [label=ExpandBackward0] + 140193035953536 -> 140193035953248 + 140193035953536 [label=TransposeBackward0] + 140193035954400 -> 140193035953536 + 140193035954400 [label=PermuteBackward0] + 140193035955504 -> 140193035954400 + 140193035955504 [label=ViewBackward0] + 140193035953776 -> 140193035955504 + 140193035953776 [label=ViewBackward0] + 140193035961440 -> 140193035953776 + 140193035961440 [label=AddmmBackward0] + 140193035960432 -> 140193035961440 + 140193035960432 [label=ToCopyBackward0] + 140193036049808 -> 140193035960432 + 140193039548480 [label="encoder.layer.3.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039548480 -> 140193036049808 + 140193036049808 [label=AccumulateGrad] + 140193036049616 -> 140193035961440 + 140193036049616 [label=ViewBackward0] + 140193036049856 -> 140193036049616 + 140193036049856 [label=ToCopyBackward0] + 140193037261408 -> 140193036049856 + 140193036049184 -> 140193035961440 + 140193036049184 [label=TBackward0] + 140193036049712 -> 140193036049184 + 140193036049712 [label=ToCopyBackward0] + 140193036050000 -> 140193036049712 + 140193039548720 [label="encoder.layer.3.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039548720 -> 140193036050000 + 140193036050000 [label=AccumulateGrad] + 140193037262464 -> 140193037262416 + 140193037262464 [label=UnsafeViewBackward0] + 140193037262560 -> 140193037262464 + 140193037262560 [label=CloneBackward0] + 140193037262608 -> 140193037262560 + 140193037262608 [label=ExpandBackward0] + 140193035952528 -> 140193037262608 + 140193035952528 [label=PermuteBackward0] + 140193035952192 -> 140193035952528 + 140193035952192 [label=ViewBackward0] + 140193035953152 -> 140193035952192 + 140193035953152 [label=ViewBackward0] + 140193035953824 -> 140193035953152 + 140193035953824 [label=AddmmBackward0] + 140193035952864 -> 140193035953824 + 140193035952864 [label=ToCopyBackward0] + 140193036049280 -> 140193035952864 + 140193039539872 [label="encoder.layer.3.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039539872 -> 140193036049280 + 140193036049280 [label=AccumulateGrad] + 140193035955744 -> 140193035953824 + 140193035955744 [label=ViewBackward0] + 140193036050096 -> 140193035955744 + 140193036050096 [label=ToCopyBackward0] + 140193037261408 -> 140193036050096 + 140193035952240 -> 140193035953824 + 140193035952240 [label=TBackward0] + 140193036049520 -> 140193035952240 + 140193036049520 [label=ToCopyBackward0] + 140193036050144 -> 140193036049520 + 140193039540112 [label="encoder.layer.3.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039540112 -> 140193036050144 + 140193036050144 [label=AccumulateGrad] + 140193037261504 -> 140193037261696 + 140193037261504 [label=TBackward0] + 140193037262176 -> 140193037261504 + 140193037262176 [label=ToCopyBackward0] + 140193037262368 -> 140193037262176 + 140193039539952 [label="encoder.layer.3.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039539952 -> 140193037262368 + 140193037262368 [label=AccumulateGrad] + 140193037261408 -> 140193037261264 + 140193037261216 -> 140193037261168 + 140193039539712 [label="encoder.layer.3.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039539712 -> 140193037261216 + 140193037261216 [label=AccumulateGrad] + 140193037260496 -> 140193037261168 + 140193039539392 [label="encoder.layer.3.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039539392 -> 140193037260496 + 140193037260496 [label=AccumulateGrad] + 140193037260016 -> 140193037260304 + 140193037260016 [label=TBackward0] + 140193037260544 -> 140193037260016 + 140193037260544 [label=ToCopyBackward0] + 140193037260928 -> 140193037260544 + 140193039538032 [label="encoder.layer.3.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039538032 -> 140193037260928 + 140193037260928 [label=AccumulateGrad] + 140193037259584 -> 140193037259776 + 140193037259584 [label=TBackward0] + 140193037260256 -> 140193037259584 + 140193037260256 [label=ToCopyBackward0] + 140193037260736 -> 140193037260256 + 140193039537792 [label="encoder.layer.3.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039537792 -> 140193037260736 + 140193037260736 [label=AccumulateGrad] + 140193037259488 -> 140193037259344 + 140193037259296 -> 140193037259200 + 140193039537552 [label="encoder.layer.3.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039537552 -> 140193037259296 + 140193037259296 [label=AccumulateGrad] + 140193037259248 -> 140193037259200 + 140193039537232 [label="encoder.layer.3.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039537232 -> 140193037259248 + 140193037259248 [label=AccumulateGrad] + 140193037258960 -> 140193037276496 + 140193037258960 [label=NativeLayerNormBackward0] + 140193037259632 -> 140193037258960 + 140193037259632 [label=AddBackward0] + 140193037260448 -> 140193037259632 + 140193037260448 [label=NativeDropoutBackward0] + 140193037260160 -> 140193037260448 + 140193037260160 [label=ViewBackward0] + 140193037260688 -> 140193037260160 + 140193037260688 [label=AddmmBackward0] + 140193037261360 -> 140193037260688 + 140193037261360 [label=ToCopyBackward0] + 140193037261888 -> 140193037261360 + 140193039538912 [label="encoder.layer.3.output.dense.bias + (768)" fillcolor=lightblue] + 140193039538912 -> 140193037261888 + 140193037261888 [label=AccumulateGrad] + 140193037261312 -> 140193037260688 + 140193037261312 [label=ViewBackward0] + 140193037262272 -> 140193037261312 + 140193037262272 [label=GeluBackward0] + 140193037261936 -> 140193037262272 + 140193037261936 [label=ViewBackward0] + 140193037262080 -> 140193037261936 + 140193037262080 [label=AddmmBackward0] + 140193035952624 -> 140193037262080 + 140193035952624 [label=ToCopyBackward0] + 140193035960816 -> 140193035952624 + 140193039539152 [label="encoder.layer.3.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039539152 -> 140193035960816 + 140193035960816 [label=AccumulateGrad] + 140193035952432 -> 140193037262080 + 140193035952432 [label=ViewBackward0] + 140193035953344 -> 140193035952432 + 140193035953344 [label=ToCopyBackward0] + 140193037259968 -> 140193035953344 + 140193037259968 [label=SliceBackward0] + 140193036050048 -> 140193037259968 + 140193036050048 [label=SliceBackward0] + 140193036050336 -> 140193036050048 + 140193036050336 [label=SliceBackward0] + 140193037261168 -> 140193036050336 + 140193035952336 -> 140193037262080 + 140193035952336 [label=TBackward0] + 140193036050240 -> 140193035952336 + 140193036050240 [label=ToCopyBackward0] + 140193036050384 -> 140193036050240 + 140193039539472 [label="encoder.layer.3.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039539472 -> 140193036050384 + 140193036050384 [label=AccumulateGrad] + 140193037261120 -> 140193037260688 + 140193037261120 [label=TBackward0] + 140193037261840 -> 140193037261120 + 140193037261840 [label=ToCopyBackward0] + 140193035952960 -> 140193037261840 + 140193039539232 [label="encoder.layer.3.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039539232 -> 140193035952960 + 140193035952960 [label=AccumulateGrad] + 140193037259968 -> 140193037259632 + 140193037259440 -> 140193037258960 + 140193039538992 [label="encoder.layer.3.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039538992 -> 140193037259440 + 140193037259440 [label=AccumulateGrad] + 140193037259392 -> 140193037258960 + 140193039538672 [label="encoder.layer.3.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039538672 -> 140193037259392 + 140193037259392 [label=AccumulateGrad] + 140193037258816 -> 140193037279136 + 140193037258816 [label=TBackward0] + 140193037258912 -> 140193037258816 + 140193037258912 [label=ToCopyBackward0] + 140193037259920 -> 140193037258912 + 140193039537312 [label="encoder.layer.4.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039537312 -> 140193037259920 + 140193037259920 [label=AccumulateGrad] + 140193037278464 -> 140193037278416 + 140193037278464 [label=UnsafeViewBackward0] + 140193037278800 -> 140193037278464 + 140193037278800 [label=CloneBackward0] + 140193037278992 -> 140193037278800 + 140193037278992 [label=ExpandBackward0] + 140193037279088 -> 140193037278992 + 140193037279088 [label=TransposeBackward0] + 140193037259728 -> 140193037279088 + 140193037259728 [label=PermuteBackward0] + 140193037260832 -> 140193037259728 + 140193037260832 [label=ViewBackward0] + 140193037261552 -> 140193037260832 + 140193037261552 [label=ViewBackward0] + 140193037262704 -> 140193037261552 + 140193037262704 [label=AddmmBackward0] + 140193037258864 -> 140193037262704 + 140193037258864 [label=ToCopyBackward0] + 140193036049952 -> 140193037258864 + 140193039536752 [label="encoder.layer.4.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039536752 -> 140193036049952 + 140193036049952 [label=AccumulateGrad] + 140193036050288 -> 140193037262704 + 140193036050288 [label=ViewBackward0] + 140193036152992 -> 140193036050288 + 140193036152992 [label=ToCopyBackward0] + 140193037276496 -> 140193036152992 + 140193036049904 -> 140193037262704 + 140193036049904 [label=TBackward0] + 140193036152896 -> 140193036049904 + 140193036152896 [label=ToCopyBackward0] + 140193036153136 -> 140193036152896 + 140193039537072 [label="encoder.layer.4.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039537072 -> 140193036153136 + 140193036153136 [label=AccumulateGrad] + 140193037277552 -> 140193037277504 + 140193037277552 [label=UnsafeViewBackward0] + 140193037277888 -> 140193037277552 + 140193037277888 [label=CloneBackward0] + 140193037278080 -> 140193037277888 + 140193037278080 [label=ExpandBackward0] + 140193037278272 -> 140193037278080 + 140193037278272 [label=PermuteBackward0] + 140193037277648 -> 140193037278272 + 140193037277648 [label=ViewBackward0] + 140193037278896 -> 140193037277648 + 140193037278896 [label=ViewBackward0] + 140193037278608 -> 140193037278896 + 140193037278608 [label=AddmmBackward0] + 140193037259104 -> 140193037278608 + 140193037259104 [label=ToCopyBackward0] + 140193036050192 -> 140193037259104 + 140193039536512 [label="encoder.layer.4.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039536512 -> 140193036050192 + 140193036050192 [label=AccumulateGrad] + 140193037261024 -> 140193037278608 + 140193037261024 [label=ViewBackward0] + 140193036153232 -> 140193037261024 + 140193036153232 [label=ToCopyBackward0] + 140193037276496 -> 140193036153232 + 140193037259152 -> 140193037278608 + 140193037259152 [label=TBackward0] + 140193036153088 -> 140193037259152 + 140193036153088 [label=ToCopyBackward0] + 140193036153280 -> 140193036153088 + 140193039536832 [label="encoder.layer.4.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039536832 -> 140193036153280 + 140193036153280 [label=AccumulateGrad] + 140193037276592 -> 140193037276784 + 140193037276592 [label=TBackward0] + 140193037277264 -> 140193037276592 + 140193037277264 [label=ToCopyBackward0] + 140193037277456 -> 140193037277264 + 140193039536592 [label="encoder.layer.4.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039536592 -> 140193037277456 + 140193037277456 [label=AccumulateGrad] + 140193037276496 -> 140193037276352 + 140193037276304 -> 140193037276256 + 140193039536352 [label="encoder.layer.4.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039536352 -> 140193037276304 + 140193037276304 [label=AccumulateGrad] + 140193037275776 -> 140193037276256 + 140193039519552 [label="encoder.layer.4.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039519552 -> 140193037275776 + 140193037275776 [label=AccumulateGrad] + 140193037275200 -> 140193037275584 + 140193037275200 [label=TBackward0] + 140193037275824 -> 140193037275200 + 140193037275824 [label=ToCopyBackward0] + 140193037276208 -> 140193037275824 + 140193039519632 [label="encoder.layer.4.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039519632 -> 140193037276208 + 140193037276208 [label=AccumulateGrad] + 140193037364960 -> 140193037364912 + 140193037364960 [label=UnsafeViewBackward0] + 140193037365056 -> 140193037364960 + 140193037365056 [label=CloneBackward0] + 140193037275440 -> 140193037365056 + 140193037275440 [label=ExpandBackward0] + 140193037275728 -> 140193037275440 + 140193037275728 [label=TransposeBackward0] + 140193037276112 -> 140193037275728 + 140193037276112 [label=PermuteBackward0] + 140193037276400 -> 140193037276112 + 140193037276400 [label=ViewBackward0] + 140193037276640 -> 140193037276400 + 140193037276640 [label=ViewBackward0] + 140193037276928 -> 140193037276640 + 140193037276928 [label=AddmmBackward0] + 140193037277360 -> 140193037276928 + 140193037277360 [label=ToCopyBackward0] + 140193037277984 -> 140193037277360 + 140193039519072 [label="encoder.layer.4.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140193039519072 -> 140193037277984 + 140193037277984 [label=AccumulateGrad] + 140193037276976 -> 140193037276928 + 140193037276976 [label=ViewBackward0] + 140193037278176 -> 140193037276976 + 140193037278176 [label=ToCopyBackward0] + 140193036037232 -> 140193037278176 + 140193037275248 -> 140193037276928 + 140193037275248 [label=TBackward0] + 140193037277024 -> 140193037275248 + 140193037277024 [label=ToCopyBackward0] + 140193037277696 -> 140193037277024 + 140193039519392 [label="encoder.layer.4.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140193039519392 -> 140193037277696 + 140193037277696 [label=AccumulateGrad] + 140193037364048 -> 140193037364000 + 140193037364048 [label=UnsafeViewBackward0] + 140193037364384 -> 140193037364048 + 140193037364384 [label=CloneBackward0] + 140193037364576 -> 140193037364384 + 140193037364576 [label=ExpandBackward0] + 140193037364768 -> 140193037364576 + 140193037364768 [label=PermuteBackward0] + 140193037364144 -> 140193037364768 + 140193037364144 [label=ViewBackward0] + 140193037261648 -> 140193037364144 + 140193037261648 [label=ViewBackward0] + 140193037364192 -> 140193037261648 + 140193037364192 [label=AddmmBackward0] + 140193037276448 -> 140193037364192 + 140193037276448 [label=ToCopyBackward0] + 140193037278704 -> 140193037276448 + 140193039518832 [label="encoder.layer.4.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140193039518832 -> 140193037278704 + 140193037278704 [label=AccumulateGrad] + 140193037276016 -> 140193037364192 + 140193037276016 [label=ViewBackward0] + 140193037277168 -> 140193037276016 + 140193037277168 [label=ToCopyBackward0] + 140193036037232 -> 140193037277168 + 140193037275344 -> 140193037364192 + 140193037275344 [label=TBackward0] + 140193037277792 -> 140193037275344 + 140193037277792 [label=ToCopyBackward0] + 140193037278368 -> 140193037277792 + 140193039519152 [label="encoder.layer.4.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140193039519152 -> 140193037278368 + 140193037278368 [label=AccumulateGrad] + 140193037363088 -> 140193037363280 + 140193037363088 [label=TBackward0] + 140193037363760 -> 140193037363088 + 140193037363760 [label=ToCopyBackward0] + 140193037363952 -> 140193037363760 + 140193039518912 [label="encoder.layer.4.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039518912 -> 140193037363952 + 140193037363952 [label=AccumulateGrad] + 140193037362992 -> 140193037362848 + 140193037362800 -> 140193037362752 + 140193039518672 [label="encoder.layer.4.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039518672 -> 140193037362800 + 140193037362800 [label=AccumulateGrad] + 140193037362368 -> 140193037362752 + 140193039518352 [label="encoder.layer.4.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039518352 -> 140193037362368 + 140193037362368 [label=AccumulateGrad] + 140193037361888 -> 140193037362176 + 140193037361888 [label=TBackward0] + 140193037362416 -> 140193037361888 + 140193037362416 [label=ToCopyBackward0] + 140193037362896 -> 140193037362416 + 140193039516912 [label="encoder.layer.4.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039516912 -> 140193037362896 + 140193037362896 [label=AccumulateGrad] + 140193037361456 -> 140193037361648 + 140193037361456 [label=TBackward0] + 140193037362128 -> 140193037361456 + 140193037362128 [label=ToCopyBackward0] + 140193037362608 -> 140193037362128 + 140193039516672 [label="encoder.layer.4.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039516672 -> 140193037362608 + 140193037362608 [label=AccumulateGrad] + 140193037361360 -> 140193037275088 + 140193037275040 -> 140193037274992 + 140193039516432 [label="encoder.layer.4.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039516432 -> 140193037275040 + 140193037275040 [label=AccumulateGrad] + 140193037361216 -> 140193037274992 + 140193039516512 [label="encoder.layer.4.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039516512 -> 140193037361216 + 140193037361216 [label=AccumulateGrad] + 140193037274752 -> 140193037272592 + 140193037274752 [label=NativeLayerNormBackward0] + 140193037361504 -> 140193037274752 + 140193037361504 [label=AddBackward0] + 140193037362320 -> 140193037361504 + 140193037362320 [label=NativeDropoutBackward0] + 140193037362032 -> 140193037362320 + 140193037362032 [label=ViewBackward0] + 140193037362560 -> 140193037362032 + 140193037362560 [label=AddmmBackward0] + 140193037363424 -> 140193037362560 + 140193037363424 [label=ToCopyBackward0] + 140193037363520 -> 140193037363424 + 140193039518192 [label="encoder.layer.4.output.dense.bias + (768)" fillcolor=lightblue] + 140193039518192 -> 140193037363520 + 140193037363520 [label=AccumulateGrad] + 140193037363232 -> 140193037362560 + 140193037363232 [label=ViewBackward0] + 140193037363664 -> 140193037363232 + 140193037363664 [label=GeluBackward0] + 140193037364672 -> 140193037363664 + 140193037364672 [label=ViewBackward0] + 140193037365104 -> 140193037364672 + 140193037365104 [label=AddmmBackward0] + 140193037364288 -> 140193037365104 + 140193037364288 [label=ToCopyBackward0] + 140193036153376 -> 140193037364288 + 140193039518432 [label="encoder.layer.4.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039518432 -> 140193036153376 + 140193036153376 [label=AccumulateGrad] + 140193037275968 -> 140193037365104 + 140193037275968 [label=ViewBackward0] + 140193036153040 -> 140193037275968 + 140193036153040 [label=ToCopyBackward0] + 140193037361840 -> 140193036153040 + 140193037361840 [label=SliceBackward0] + 140193036153472 -> 140193037361840 + 140193036153472 [label=SliceBackward0] + 140193036153568 -> 140193036153472 + 140193036153568 [label=SliceBackward0] + 140193037276256 -> 140193036153568 + 140193037275536 -> 140193037365104 + 140193037275536 [label=TBackward0] + 140193036153328 -> 140193037275536 + 140193036153328 [label=ToCopyBackward0] + 140193036153664 -> 140193036153328 + 140194225614656 [label="encoder.layer.4.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140194225614656 -> 140193036153664 + 140193036153664 [label=AccumulateGrad] + 140193037363136 -> 140193037362560 + 140193037363136 [label=TBackward0] + 140193037364864 -> 140193037363136 + 140193037364864 [label=ToCopyBackward0] + 140193037276736 -> 140193037364864 + 140193039518112 [label="encoder.layer.4.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039518112 -> 140193037276736 + 140193037276736 [label=AccumulateGrad] + 140193037361840 -> 140193037361504 + 140193037361312 -> 140193037274752 + 140193039517872 [label="encoder.layer.4.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039517872 -> 140193037361312 + 140193037361312 [label=AccumulateGrad] + 140193037361264 -> 140193037274752 + 140193039517952 [label="encoder.layer.4.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039517952 -> 140193037361264 + 140193037361264 [label=AccumulateGrad] + 140193037273936 -> 140193037274512 + 140193037273936 [label=TBackward0] + 140193037274704 -> 140193037273936 + 140193037274704 [label=ToCopyBackward0] + 140193037274896 -> 140193037274704 + 140193039516192 [label="encoder.layer.5.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039516192 -> 140193037274896 + 140193037274896 [label=AccumulateGrad] + 140193037273840 -> 140193037273792 + 140193037273840 [label=UnsafeViewBackward0] + 140193037274176 -> 140193037273840 + 140193037274176 [label=CloneBackward0] + 140193037274368 -> 140193037274176 + 140193037274368 [label=ExpandBackward0] + 140193037274656 -> 140193037274368 + 140193037274656 [label=TransposeBackward0] + 140193037274944 -> 140193037274656 + 140193037274944 [label=PermuteBackward0] + 140193037362704 -> 140193037274944 + 140193037362704 [label=ViewBackward0] + 140193037363472 -> 140193037362704 + 140193037363472 [label=ViewBackward0] + 140193037364480 -> 140193037363472 + 140193037364480 [label=AddmmBackward0] + 140193037361792 -> 140193037364480 + 140193037361792 [label=ToCopyBackward0] + 140193036153712 -> 140193037361792 + 140193039516032 [label="encoder.layer.5.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039516032 -> 140193036153712 + 140193036153712 [label=AccumulateGrad] + 140193036153520 -> 140193037364480 + 140193036153520 [label=ViewBackward0] + 140193036153760 -> 140193036153520 + 140193036153760 [label=ToCopyBackward0] + 140193037272592 -> 140193036153760 + 140193036152944 -> 140193037364480 + 140193036152944 [label=TBackward0] + 140193036153616 -> 140193036152944 + 140193036153616 [label=ToCopyBackward0] + 140193036153904 -> 140193036153616 + 140193039515952 [label="encoder.layer.5.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039515952 -> 140193036153904 + 140193036153904 [label=AccumulateGrad] + 140193037271536 -> 140193037271488 + 140193037271536 [label=UnsafeViewBackward0] + 140193037271104 -> 140193037271536 + 140193037271104 [label=CloneBackward0] + 140193037273456 -> 140193037271104 + 140193037273456 [label=ExpandBackward0] + 140193037273648 -> 140193037273456 + 140193037273648 [label=PermuteBackward0] + 140193037271440 -> 140193037273648 + 140193037271440 [label=ViewBackward0] + 140193037274272 -> 140193037271440 + 140193037274272 [label=ViewBackward0] + 140193037273984 -> 140193037274272 + 140193037273984 [label=AddmmBackward0] + 140193037271296 -> 140193037273984 + 140193037271296 [label=ToCopyBackward0] + 140193036153184 -> 140193037271296 + 140193039515792 [label="encoder.layer.5.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039515792 -> 140193036153184 + 140193036153184 [label=AccumulateGrad] + 140193037361600 -> 140193037273984 + 140193037361600 [label=ViewBackward0] + 140193036154000 -> 140193037361600 + 140193036154000 [label=ToCopyBackward0] + 140193037272592 -> 140193036154000 + 140193037362944 -> 140193037273984 + 140193037362944 [label=TBackward0] + 140193036153424 -> 140193037362944 + 140193036153424 [label=ToCopyBackward0] + 140193036154048 -> 140193036153424 + 140193039515712 [label="encoder.layer.5.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039515712 -> 140193036154048 + 140193036154048 [label=AccumulateGrad] + 140193037272496 -> 140193037272304 + 140193037272496 [label=TBackward0] + 140193037271824 -> 140193037272496 + 140193037271824 [label=ToCopyBackward0] + 140193037271632 -> 140193037271824 + 140193039498992 [label="encoder.layer.5.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039498992 -> 140193037271632 + 140193037271632 [label=AccumulateGrad] + 140193037272592 -> 140193037272640 + 140193037272784 -> 140193037272736 + 140193039498752 [label="encoder.layer.5.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039498752 -> 140193037272784 + 140193037272784 [label=AccumulateGrad] + 140193037273264 -> 140193037272736 + 140193039498832 [label="encoder.layer.5.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039498832 -> 140193037273264 + 140193037273264 [label=AccumulateGrad] + 140193036303952 -> 140193036302896 + 140193036303952 [label=TBackward0] + 140193036303472 -> 140193036303952 + 140193036303472 [label=ToCopyBackward0] + 140193037273072 -> 140193036303472 + 140193039497072 [label="encoder.layer.5.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039497072 -> 140193037273072 + 140193037273072 [label=AccumulateGrad] + 140193036301744 -> 140193036302032 + 140193036301744 [label=TBackward0] + 140193036304336 -> 140193036301744 + 140193036304336 [label=ToCopyBackward0] + 140193036302704 -> 140193036304336 + 140193039496832 [label="encoder.layer.5.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039496832 -> 140193036302704 + 140193036302704 [label=AccumulateGrad] + 140193036301552 -> 140193036301456 + 140193036301264 -> 140193036301072 + 140193039496592 [label="encoder.layer.5.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039496592 -> 140193036301264 + 140193036301264 [label=AccumulateGrad] + 140193036301360 -> 140193036301072 + 140193039496672 [label="encoder.layer.5.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039496672 -> 140193036301360 + 140193036301360 [label=AccumulateGrad] + 140193036300880 -> 140193036271824 + 140193036300880 [label=NativeLayerNormBackward0] + 140193036301840 -> 140193036300880 + 140193036301840 [label=AddBackward0] + 140193036303088 -> 140193036301840 + 140193036303088 [label=NativeDropoutBackward0] + 140193036302752 -> 140193036303088 + 140193036302752 [label=ViewBackward0] + 140193037273216 -> 140193036302752 + 140193037273216 [label=AddmmBackward0] + 140193037272544 -> 140193037273216 + 140193037272544 [label=ToCopyBackward0] + 140193037272112 -> 140193037272544 + 140193039498352 [label="encoder.layer.5.output.dense.bias + (768)" fillcolor=lightblue] + 140193039498352 -> 140193037272112 + 140193037272112 [label=AccumulateGrad] + 140193037272688 -> 140193037273216 + 140193037272688 [label=ViewBackward0] + 140193037271728 -> 140193037272688 + 140193037271728 [label=GeluBackward0] + 140193037271968 -> 140193037271728 + 140193037271968 [label=ViewBackward0] + 140193037273360 -> 140193037271968 + 140193037273360 [label=AddmmBackward0] + 140193037273744 -> 140193037273360 + 140193037273744 [label=ToCopyBackward0] + 140193037363856 -> 140193037273744 + 140193039498592 [label="encoder.layer.5.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039498592 -> 140193037363856 + 140193037363856 [label=AccumulateGrad] + 140193037273552 -> 140193037273360 + 140193037273552 [label=ViewBackward0] + 140193037274464 -> 140193037273552 + 140193037274464 [label=ToCopyBackward0] + 140193036304192 -> 140193037274464 + 140193036304192 [label=SliceBackward0] + 140193036153952 -> 140193036304192 + 140193036153952 [label=SliceBackward0] + 140193036154240 -> 140193036153952 + 140193036154240 [label=SliceBackward0] + 140193037272736 -> 140193036154240 + 140193037272064 -> 140193037273360 + 140193037272064 [label=TBackward0] + 140193036154144 -> 140193037272064 + 140193036154144 [label=ToCopyBackward0] + 140193036154336 -> 140193036154144 + 140193039498512 [label="encoder.layer.5.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039498512 -> 140193036154336 + 140193036154336 [label=AccumulateGrad] + 140193037272880 -> 140193037273216 + 140193037272880 [label=TBackward0] + 140193037271920 -> 140193037272880 + 140193037271920 [label=ToCopyBackward0] + 140193037274080 -> 140193037271920 + 140193039498272 [label="encoder.layer.5.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039498272 -> 140193037274080 + 140193037274080 [label=AccumulateGrad] + 140193036304192 -> 140193036301840 + 140193036302800 -> 140193036300880 + 140193039498032 [label="encoder.layer.5.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039498032 -> 140193036302800 + 140193036302800 [label=AccumulateGrad] + 140193036301312 -> 140193036300880 + 140193039498112 [label="encoder.layer.5.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039498112 -> 140193036301312 + 140193036301312 [label=AccumulateGrad] + 140193036304048 -> 140193036300400 + 140193036304048 [label=TBackward0] + 140193036300592 -> 140193036304048 + 140193036300592 [label=ToCopyBackward0] + 140193036302320 -> 140193036300592 + 140193039496352 [label="encoder.layer.6.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039496352 -> 140193036302320 + 140193036302320 [label=AccumulateGrad] + 140193036275184 -> 140193036275280 + 140193036275184 [label=UnsafeViewBackward0] + 140193036275568 -> 140193036275184 + 140193036275568 [label=CloneBackward0] + 140193036302512 -> 140193036275568 + 140193036302512 [label=ExpandBackward0] + 140193036300688 -> 140193036302512 + 140193036300688 [label=TransposeBackward0] + 140193036301936 -> 140193036300688 + 140193036301936 [label=PermuteBackward0] + 140193036300832 -> 140193036301936 + 140193036300832 [label=ViewBackward0] + 140193037272352 -> 140193036300832 + 140193037272352 [label=ViewBackward0] + 140193037271200 -> 140193037272352 + 140193037271200 [label=AddmmBackward0] + 140193037272976 -> 140193037271200 + 140193037272976 [label=ToCopyBackward0] + 140193036154384 -> 140193037272976 + 140193039496192 [label="encoder.layer.6.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039496192 -> 140193036154384 + 140193036154384 [label=AccumulateGrad] + 140193036154192 -> 140193037271200 + 140193036154192 [label=ViewBackward0] + 140193036154432 -> 140193036154192 + 140193036154432 [label=ToCopyBackward0] + 140193036271824 -> 140193036154432 + 140193036153808 -> 140193037271200 + 140193036153808 [label=TBackward0] + 140193036154288 -> 140193036153808 + 140193036154288 [label=ToCopyBackward0] + 140193036154576 -> 140193036154288 + 140193039496112 [label="encoder.layer.6.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039496112 -> 140193036154576 + 140193036154576 [label=AccumulateGrad] + 140193036273552 -> 140193036273648 + 140193036273552 [label=UnsafeViewBackward0] + 140193036274320 -> 140193036273552 + 140193036274320 [label=CloneBackward0] + 140193036274512 -> 140193036274320 + 140193036274512 [label=ExpandBackward0] + 140193036274752 -> 140193036274512 + 140193036274752 [label=PermuteBackward0] + 140193036274416 -> 140193036274752 + 140193036274416 [label=ViewBackward0] + 140193036273936 -> 140193036274416 + 140193036273936 [label=ViewBackward0] + 140193036301168 -> 140193036273936 + 140193036301168 [label=AddmmBackward0] + 140193036303376 -> 140193036301168 + 140193036303376 [label=ToCopyBackward0] + 140193036153856 -> 140193036303376 + 140193039495952 [label="encoder.layer.6.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039495952 -> 140193036153856 + 140193036153856 [label=AccumulateGrad] + 140193036303856 -> 140193036301168 + 140193036303856 [label=ViewBackward0] + 140193036154672 -> 140193036303856 + 140193036154672 [label=ToCopyBackward0] + 140193036271824 -> 140193036154672 + 140193037273168 -> 140193036301168 + 140193037273168 [label=TBackward0] + 140193036154096 -> 140193037273168 + 140193036154096 [label=ToCopyBackward0] + 140193036154720 -> 140193036154096 + 140193039495872 [label="encoder.layer.6.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039495872 -> 140193036154720 + 140193036154720 [label=AccumulateGrad] + 140193036271872 -> 140193036272304 + 140193036271872 [label=TBackward0] + 140193036273072 -> 140193036271872 + 140193036273072 [label=ToCopyBackward0] + 140193036273312 -> 140193036273072 + 140193039495632 [label="encoder.layer.6.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039495632 -> 140193036273312 + 140193036273312 [label=AccumulateGrad] + 140193036271824 -> 140193036250848 + 140193036250224 -> 140193036250992 + 140193039495392 [label="encoder.layer.6.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039495392 -> 140193036250224 + 140193036250224 [label=AccumulateGrad] + 140193036271680 -> 140193036250992 + 140193039495472 [label="encoder.layer.6.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039495472 -> 140193036271680 + 140193036271680 [label=AccumulateGrad] + 140193036248976 -> 140193036249936 + 140193036248976 [label=TBackward0] + 140193036250128 -> 140193036248976 + 140193036250128 [label=ToCopyBackward0] + 140193036250800 -> 140193036250128 + 140193039495232 [label="encoder.layer.6.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039495232 -> 140193036250800 + 140193036250800 [label=AccumulateGrad] + 140193036248784 -> 140193036248448 + 140193036248784 [label=UnsafeViewBackward0] + 140193036249168 -> 140193036248784 + 140193036249168 [label=CloneBackward0] + 140193036249408 -> 140193036249168 + 140193036249408 [label=ExpandBackward0] + 140193036249888 -> 140193036249408 + 140193036249888 [label=TransposeBackward0] + 140193036250608 -> 140193036249888 + 140193036250608 [label=PermuteBackward0] + 140193036250512 -> 140193036250608 + 140193036250512 [label=ViewBackward0] + 140193036272208 -> 140193036250512 + 140193036272208 [label=ViewBackward0] + 140193036272688 -> 140193036272208 + 140193036272688 [label=AddmmBackward0] + 140193036273264 -> 140193036272688 + 140193036273264 [label=ToCopyBackward0] + 140193036274272 -> 140193036273264 + 140193039490800 [label="encoder.layer.6.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140193039490800 -> 140193036274272 + 140193036274272 [label=AccumulateGrad] + 140193036272592 -> 140193036272688 + 140193036272592 [label=ViewBackward0] + 140193036274704 -> 140193036272592 + 140193036274704 [label=ToCopyBackward0] + 140193036037232 -> 140193036274704 + 140193036271920 -> 140193036272688 + 140193036271920 [label=TBackward0] + 140193036272832 -> 140193036271920 + 140193036272832 [label=ToCopyBackward0] + 140193036300352 -> 140193036272832 + 140193039490720 [label="encoder.layer.6.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140193039490720 -> 140193036300352 + 140193036300352 [label=AccumulateGrad] + 140193036247152 -> 140193036214128 + 140193036247152 [label=UnsafeViewBackward0] + 140193036247824 -> 140193036247152 + 140193036247824 [label=CloneBackward0] + 140193036248112 -> 140193036247824 + 140193036248112 [label=ExpandBackward0] + 140193036248496 -> 140193036248112 + 140193036248496 [label=PermuteBackward0] + 140193036247248 -> 140193036248496 + 140193036247248 [label=ViewBackward0] + 140193036249360 -> 140193036247248 + 140193036249360 [label=ViewBackward0] + 140193036250368 -> 140193036249360 + 140193036250368 [label=AddmmBackward0] + 140193036303232 -> 140193036250368 + 140193036303232 [label=ToCopyBackward0] + 140193036274128 -> 140193036303232 + 140193039490560 [label="encoder.layer.6.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140193039490560 -> 140193036274128 + 140193036274128 [label=AccumulateGrad] + 140193036248880 -> 140193036250368 + 140193036248880 [label=ViewBackward0] + 140193036272880 -> 140193036248880 + 140193036272880 [label=ToCopyBackward0] + 140193036037232 -> 140193036272880 + 140193036247536 -> 140193036250368 + 140193036247536 [label=TBackward0] + 140193036272400 -> 140193036247536 + 140193036272400 [label=ToCopyBackward0] + 140193036274992 -> 140193036272400 + 140193039490480 [label="encoder.layer.6.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140193039490480 -> 140193036274992 + 140193036274992 [label=AccumulateGrad] + 140193036212688 -> 140193036212976 + 140193036212688 [label=TBackward0] + 140193036213696 -> 140193036212688 + 140193036213696 [label=ToCopyBackward0] + 140193036213552 -> 140193036213696 + 140193039490240 [label="encoder.layer.6.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039490240 -> 140193036213552 + 140193036213552 [label=AccumulateGrad] + 140193036212496 -> 140193036212400 + 140193036212208 -> 140193036212304 + 140193039490000 [label="encoder.layer.6.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039490000 -> 140193036212208 + 140193036212208 [label=AccumulateGrad] + 140193036211632 -> 140193036212304 + 140193039490080 [label="encoder.layer.6.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039490080 -> 140193036211632 + 140193036211632 [label=AccumulateGrad] + 140193036210576 -> 140193036211056 + 140193036210576 [label=TBackward0] + 140193036211776 -> 140193036210576 + 140193036211776 [label=ToCopyBackward0] + 140193036212256 -> 140193036211776 + 140193039487120 [label="encoder.layer.6.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039487120 -> 140193036212256 + 140193036212256 [label=AccumulateGrad] + 140193036210384 -> 140193036713840 + 140193036210384 [label=TBackward0] + 140193036211152 -> 140193036210384 + 140193036211152 [label=ToCopyBackward0] + 140193036211296 -> 140193036211152 + 140193039487440 [label="encoder.layer.6.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039487440 -> 140193036211296 + 140193036211296 [label=AccumulateGrad] + 140193036713456 -> 140193036713120 + 140193036713456 [label=UnsqueezeBackward0] + 140193036713648 -> 140193036713456 + 140193036713648 [label=NativeDropoutBackward0] + 140193036713552 -> 140193036713648 + 140193036713552 [label=ViewBackward0] + 140193036212880 -> 140193036713552 + 140193036212880 [label=AddmmBackward0] + 140193036210960 -> 140193036212880 + 140193036210960 [label=ToCopyBackward0] + 140193036213360 -> 140193036210960 + 140193039469680 [label="encoder.layer.6.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140193039469680 -> 140193036213360 + 140193036213360 [label=AccumulateGrad] + 140193036212016 -> 140193036212880 + 140193036212016 [label=ViewBackward0] + 140193036213216 -> 140193036212016 + 140193036213216 [label=GeluBackward0] + 140193036213648 -> 140193036213216 + 140193036213648 [label=ViewBackward0] + 140193036213072 -> 140193036213648 + 140193036213072 [label=AddmmBackward0] + 140193036248304 -> 140193036213072 + 140193036248304 [label=ToCopyBackward0] + 140193037272256 -> 140193036248304 + 140193039469920 [label="encoder.layer.6.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140193039469920 -> 140193037272256 + 140193037272256 [label=AccumulateGrad] + 140193036248016 -> 140193036213072 + 140193036248016 [label=ViewBackward0] + 140193036248928 -> 140193036248016 + 140193036248928 [label=ToCopyBackward0] + 140193036211920 -> 140193036248928 + 140193036247632 -> 140193036213072 + 140193036247632 [label=TBackward0] + 140193036271728 -> 140193036247632 + 140193036271728 [label=ToCopyBackward0] + 140193036154816 -> 140193036271728 + 140193039470240 [label="encoder.layer.6.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039470240 -> 140193036154816 + 140193036154816 [label=AccumulateGrad] + 140193036210336 -> 140193036212880 + 140193036210336 [label=TBackward0] + 140193036211728 -> 140193036210336 + 140193036211728 [label=ToCopyBackward0] + 140193036249648 -> 140193036211728 + 140193039470000 [label="encoder.layer.6.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039470000 -> 140193036249648 + 140193036249648 [label=AccumulateGrad] + 140193036712112 -> 140193036712208 + 140193036712112 [label=UnsqueezeBackward0] + 140193036712784 -> 140193036712112 + 140193036712784 [label=UnsqueezeBackward0] + 140193036713168 -> 140193036712784 + 140193036713168 [label=MulBackward0] + 140193036712160 -> 140193036713168 + 140193036712160 [label=IndexBackward0] + 140193036713600 -> 140193036712160 + 140193036713600 [label=ViewBackward0] + 140193036212592 -> 140193036713600 + 140193036212592 [label=CloneBackward0] + 140193036213936 -> 140193036212592 + 140193036213936 [label=ExpandBackward0] + 140193036247104 -> 140193036213936 + 140193036247104 [label=UnsqueezeBackward0] + 140193036154480 -> 140193036247104 + 140193036154480 [label=SoftmaxBackward0] + 140193036154864 -> 140193036154480 + 140193036154864 [label=CatBackward0] + 140193036154960 -> 140193036154864 + 140193036154960 [label=MmBackward0] + 140193036155104 -> 140193036154960 + 140193036155104 [label=MeanBackward1] + 140193036713744 -> 140193036155104 + 140193036155056 -> 140193036154960 + 140193036155056 [label=TBackward0] + 140193036155152 -> 140193036155056 + 140193036155152 [label=ToCopyBackward0] + 140193036155344 -> 140193036155152 + 140193039487840 [label="encoder.layer.6.experts.gate.weight + (1, 768)" fillcolor=lightblue] + 140193039487840 -> 140193036155344 + 140193036155344 [label=AccumulateGrad] + 140193036154912 -> 140193036154864 + 140193036154912 [label=MmBackward0] + 140193036155296 -> 140193036154912 + 140193036155296 [label=MeanBackward1] + 140193036713648 -> 140193036155296 + 140193036155200 -> 140193036154912 + 140193036155200 [label=TBackward0] + 140193036155152 -> 140193036155200 + 140193036711344 -> 140193036711056 + 140193036711344 [label=UnsqueezeBackward0] + 140193036711680 -> 140193036711344 + 140193036711680 [label=SelectBackward0] + 140193036711536 -> 140193036711680 + 140193036711536 [label=NativeDropoutBackward0] + 140193036712976 -> 140193036711536 + 140193036712976 [label=ViewBackward0] + 140193036712496 -> 140193036712976 + 140193036712496 [label=AddmmBackward0] + 140193036210672 -> 140193036712496 + 140193036210672 [label=ToCopyBackward0] + 140193036154624 -> 140193036210672 + 140193039469200 [label="encoder.layer.6.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140193039469200 -> 140193036154624 + 140193036154624 [label=AccumulateGrad] + 140193036211440 -> 140193036712496 + 140193036211440 [label=ViewBackward0] + 140193036155008 -> 140193036211440 + 140193036155008 [label=GeluBackward0] + 140193036155392 -> 140193036155008 + 140193036155392 [label=ViewBackward0] + 140193036155488 -> 140193036155392 + 140193036155488 [label=AddmmBackward0] + 140193036155584 -> 140193036155488 + 140193036155584 [label=ToCopyBackward0] + 140193036155776 -> 140193036155584 + 140193039469440 [label="encoder.layer.6.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140193039469440 -> 140193036155776 + 140193036155776 [label=AccumulateGrad] + 140193036155536 -> 140193036155488 + 140193036155536 [label=ViewBackward0] + 140193036155824 -> 140193036155536 + 140193036155824 [label=ToCopyBackward0] + 140193036211920 -> 140193036155824 + 140193036154768 -> 140193036155488 + 140193036154768 [label=TBackward0] + 140193036155680 -> 140193036154768 + 140193036155680 [label=ToCopyBackward0] + 140193036155968 -> 140193036155680 + 140193039469760 [label="encoder.layer.6.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039469760 -> 140193036155968 + 140193036155968 [label=AccumulateGrad] + 140193036711440 -> 140193036712496 + 140193036711440 [label=TBackward0] + 140193036155248 -> 140193036711440 + 140193036155248 [label=ToCopyBackward0] + 140193036155920 -> 140193036155248 + 140193039469520 [label="encoder.layer.6.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039469520 -> 140193036155920 + 140193036155920 [label=AccumulateGrad] + 140193036710720 -> 140193036710672 + 140193036710720 [label=CatBackward0] + 140193036711920 -> 140193036710720 + 140193036711920 [label=SliceBackward0] + 140193036248592 -> 140193036711920 + 140193036248592 [label=SliceBackward0] + 140193036712688 -> 140193036248592 + 140193036712688 [label=SliceBackward0] + 140193036712016 -> 140193036712688 + 140193036711632 -> 140193036710720 + 140193036711632 [label=UnsqueezeBackward0] + 140193036713264 -> 140193036711632 + 140193036713264 [label=SelectBackward0] + 140193036711536 -> 140193036713264 + 140193036710864 -> 140193036710672 + 140193036710864 [label=CatBackward0] + 140193036711152 -> 140193036710864 + 140193036711152 [label=SliceBackward0] + 140193036155440 -> 140193036711152 + 140193036155440 [label=SliceBackward0] + 140193036155872 -> 140193036155440 + 140193036155872 [label=SliceBackward0] + 140193036712016 -> 140193036155872 + 140193036154528 -> 140193036710864 + 140193036154528 [label=UnsqueezeBackward0] + 140193036156160 -> 140193036154528 + 140193036156160 [label=SelectBackward0] + 140193036711536 -> 140193036156160 + 140193036710960 -> 140193036710672 + 140193036710960 [label=CatBackward0] + 140193036156064 -> 140193036710960 + 140193036156064 [label=SliceBackward0] + 140193036156208 -> 140193036156064 + 140193036156208 [label=SliceBackward0] + 140193036156304 -> 140193036156208 + 140193036156304 [label=SliceBackward0] + 140193036712016 -> 140193036156304 + 140193036156016 -> 140193036710960 + 140193036156016 [label=UnsqueezeBackward0] + 140193036156400 -> 140193036156016 + 140193036156400 [label=SelectBackward0] + 140193036711536 -> 140193036156400 + 140193036710768 -> 140193036710240 + 140193036710768 [label=ViewBackward0] + 140193036711248 -> 140193036710768 + 140193036711248 [label=CloneBackward0] + 140193036156352 -> 140193036711248 + 140193036156352 [label=ExpandBackward0] + 140193036156448 -> 140193036156352 + 140193036156448 [label=UnsqueezeBackward0] + 140193036211920 -> 140193036156448 + 140193036710384 -> 140193036710288 + 140194225780112 [label="encoder.layer.6.expert_ln.weight + (768)" fillcolor=lightblue] + 140194225780112 -> 140193036710384 + 140193036710384 [label=AccumulateGrad] + 140193036710192 -> 140193036710288 + 140193039487360 [label="encoder.layer.6.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039487360 -> 140193036710192 + 140193036710192 [label=AccumulateGrad] + 140193036710096 -> 140193036659728 + 140193036710096 [label=ViewBackward0] + 140193036710480 -> 140193036710096 + 140193036710480 [label=CloneBackward0] + 140193036155632 -> 140193036710480 + 140193036155632 [label=ExpandBackward0] + 140193036156496 -> 140193036155632 + 140193036156496 [label=UnsqueezeBackward0] + 140193036156592 -> 140193036156496 + 140193036156592 [label=NativeLayerNormBackward0] + 140193036156688 -> 140193036156592 + 140193036156688 [label=AddBackward0] + 140193036156880 -> 140193036156688 + 140193036156880 [label=NativeDropoutBackward0] + 140201394335904 -> 140193036156880 + 140201394335904 [label=ViewBackward0] + 140201394336000 -> 140201394335904 + 140201394336000 [label=AddmmBackward0] + 140201394336096 -> 140201394336000 + 140201394336096 [label=ToCopyBackward0] + 140201394336288 -> 140201394336096 + 140193039489600 [label="encoder.layer.6.output.dense.bias + (768)" fillcolor=lightblue] + 140193039489600 -> 140201394336288 + 140201394336288 [label=AccumulateGrad] + 140201394336048 -> 140201394336000 + 140201394336048 [label=ViewBackward0] + 140201394336336 -> 140201394336048 + 140201394336336 [label=GeluBackward0] + 140201394336432 -> 140201394336336 + 140201394336432 [label=ViewBackward0] + 140201394336528 -> 140201394336432 + 140201394336528 [label=AddmmBackward0] + 140201394336624 -> 140201394336528 + 140201394336624 [label=ToCopyBackward0] + 140201394336816 -> 140201394336624 + 140193039489840 [label="encoder.layer.6.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039489840 -> 140201394336816 + 140201394336816 [label=AccumulateGrad] + 140201394336576 -> 140201394336528 + 140201394336576 [label=ViewBackward0] + 140201394336864 -> 140201394336576 + 140201394336864 [label=ToCopyBackward0] + 140193036156832 -> 140201394336864 + 140193036156832 [label=SliceBackward0] + 140201394337008 -> 140193036156832 + 140201394337008 [label=SliceBackward0] + 140201394337104 -> 140201394337008 + 140201394337104 [label=SliceBackward0] + 140193036250992 -> 140201394337104 + 140201394336240 -> 140201394336528 + 140201394336240 [label=TBackward0] + 140201394336768 -> 140201394336240 + 140201394336768 [label=ToCopyBackward0] + 140201394337200 -> 140201394336768 + 140193039489760 [label="encoder.layer.6.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039489760 -> 140201394337200 + 140201394337200 [label=AccumulateGrad] + 140201394335808 -> 140201394336000 + 140201394335808 [label=TBackward0] + 140201394336480 -> 140201394335808 + 140201394336480 [label=ToCopyBackward0] + 140201394336960 -> 140201394336480 + 140193039489520 [label="encoder.layer.6.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039489520 -> 140201394336960 + 140201394336960 [label=AccumulateGrad] + 140193036156832 -> 140193036156688 + 140193036156640 -> 140193036156592 + 140193039489280 [label="encoder.layer.6.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039489280 -> 140193036156640 + 140193036156640 [label=AccumulateGrad] + 140193036156256 -> 140193036156592 + 140193039489360 [label="encoder.layer.6.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039489360 -> 140193036156256 + 140193036156256 [label=AccumulateGrad] + 140193036679728 -> 140193036680688 + 140193036679728 [label=TBackward0] + 140193036680976 -> 140193036679728 + 140193036680976 [label=ToCopyBackward0] + 140193036710576 -> 140193036680976 + 140193039488160 [label="encoder.layer.7.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039488160 -> 140193036710576 + 140193036710576 [label=AccumulateGrad] + 140193036679632 -> 140193036679344 + 140193036679632 [label=UnsafeViewBackward0] + 140193036680016 -> 140193036679632 + 140193036680016 [label=CloneBackward0] + 140193036680304 -> 140193036680016 + 140193036680304 [label=ExpandBackward0] + 140193036680784 -> 140193036680304 + 140193036680784 [label=TransposeBackward0] + 140193036710000 -> 140193036680784 + 140193036710000 [label=PermuteBackward0] + 140193036679584 -> 140193036710000 + 140193036679584 [label=ViewBackward0] + 140193036156736 -> 140193036679584 + 140193036156736 [label=ViewBackward0] + 140193036156784 -> 140193036156736 + 140193036156784 [label=AddmmBackward0] + 140201394336144 -> 140193036156784 + 140201394336144 [label=ToCopyBackward0] + 140201394337056 -> 140201394336144 + 140193039488320 [label="encoder.layer.7.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039488320 -> 140201394337056 + 140201394337056 [label=AccumulateGrad] + 140201394335952 -> 140193036156784 + 140201394335952 [label=ViewBackward0] + 140201394336384 -> 140201394335952 + 140201394336384 [label=ToCopyBackward0] + 140193036659728 -> 140201394336384 + 140201394335856 -> 140193036156784 + 140201394335856 [label=TBackward0] + 140201394336672 -> 140201394335856 + 140201394336672 [label=ToCopyBackward0] + 140201394337248 -> 140201394336672 + 140193039488400 [label="encoder.layer.7.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039488400 -> 140201394337248 + 140201394337248 [label=AccumulateGrad] + 140193036677904 -> 140193036678000 + 140193036677904 [label=UnsafeViewBackward0] + 140193036678672 -> 140193036677904 + 140193036678672 [label=CloneBackward0] + 140193036678960 -> 140193036678672 + 140193036678960 [label=ExpandBackward0] + 140193036679248 -> 140193036678960 + 140193036679248 [label=PermuteBackward0] + 140193036678096 -> 140193036679248 + 140193036678096 [label=ViewBackward0] + 140193036680064 -> 140193036678096 + 140193036680064 [label=ViewBackward0] + 140193036681024 -> 140193036680064 + 140193036681024 [label=AddmmBackward0] + 140193036678288 -> 140193036681024 + 140193036678288 [label=ToCopyBackward0] + 140201394336720 -> 140193036678288 + 140193039487920 [label="encoder.layer.7.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039487920 -> 140201394336720 + 140201394336720 [label=AccumulateGrad] + 140193036156112 -> 140193036681024 + 140193036156112 [label=ViewBackward0] + 140201394337344 -> 140193036156112 + 140201394337344 [label=ToCopyBackward0] + 140193036659728 -> 140201394337344 + 140193036156544 -> 140193036681024 + 140193036156544 [label=TBackward0] + 140201394336192 -> 140193036156544 + 140201394336192 [label=ToCopyBackward0] + 140201394337392 -> 140201394336192 + 140193039487600 [label="encoder.layer.7.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039487600 -> 140201394337392 + 140201394337392 [label=AccumulateGrad] + 140193036659776 -> 140193036660208 + 140193036659776 [label=TBackward0] + 140193036677424 -> 140193036659776 + 140193036677424 [label=ToCopyBackward0] + 140193036677664 -> 140193036677424 + 140193039487680 [label="encoder.layer.7.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039487680 -> 140193036677664 + 140193036677664 [label=AccumulateGrad] + 140193036659728 -> 140193036659632 + 140193036659296 -> 140193036659440 + 140193039468960 [label="encoder.layer.7.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039468960 -> 140193036659296 + 140193036659296 [label=AccumulateGrad] + 140193036658672 -> 140193036659440 + 140193039470160 [label="encoder.layer.7.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039470160 -> 140193036658672 + 140193036658672 [label=AccumulateGrad] + 140193036657424 -> 140193036657904 + 140193036657424 [label=TBackward0] + 140193036658576 -> 140193036657424 + 140193036658576 [label=ToCopyBackward0] + 140193036658768 -> 140193036658576 + 140193039461968 [label="encoder.layer.7.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039461968 -> 140193036658768 + 140193036658768 [label=AccumulateGrad] + 140193036656704 -> 140193036656848 + 140193036656704 [label=TBackward0] + 140193036657616 -> 140193036656704 + 140193036657616 [label=ToCopyBackward0] + 140193036658192 -> 140193036657616 + 140193039461648 [label="encoder.layer.7.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039461648 -> 140193036658192 + 140193036658192 [label=AccumulateGrad] + 140193036631536 -> 140193036631632 + 140193036631536 [label=UnsqueezeBackward0] + 140193036631920 -> 140193036631536 + 140193036631920 [label=NativeDropoutBackward0] + 140193036657136 -> 140193036631920 + 140193036657136 [label=ViewBackward0] + 140193036659056 -> 140193036657136 + 140193036659056 [label=AddmmBackward0] + 140193036657376 -> 140193036659056 + 140193036657376 [label=ToCopyBackward0] + 140193036659536 -> 140193036657376 + 140193039461728 [label="encoder.layer.7.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140193039461728 -> 140193036659536 + 140193036659536 [label=AccumulateGrad] + 140193036658096 -> 140193036659056 + 140193036658096 [label=ViewBackward0] + 140193036659824 -> 140193036658096 + 140193036659824 [label=GeluBackward0] + 140193036660304 -> 140193036659824 + 140193036660304 [label=ViewBackward0] + 140193036660592 -> 140193036660304 + 140193036660592 [label=AddmmBackward0] + 140193036678480 -> 140193036660592 + 140193036678480 [label=ToCopyBackward0] + 140193036679152 -> 140193036678480 + 140193039462128 [label="encoder.layer.7.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140193039462128 -> 140193036679152 + 140193036679152 [label=AccumulateGrad] + 140193036677616 -> 140193036660592 + 140193036677616 [label=ViewBackward0] + 140193036679440 -> 140193036677616 + 140193036679440 [label=ToCopyBackward0] + 140193036658336 -> 140193036679440 + 140193036677184 -> 140193036660592 + 140193036677184 [label=TBackward0] + 140193036677328 -> 140193036677184 + 140193036677328 [label=ToCopyBackward0] + 140193036155728 -> 140193036677328 + 140193039461408 [label="encoder.layer.7.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039461408 -> 140193036155728 + 140193036155728 [label=AccumulateGrad] + 140193036656752 -> 140193036659056 + 140193036656752 [label=TBackward0] + 140193036659248 -> 140193036656752 + 140193036659248 [label=ToCopyBackward0] + 140193036660112 -> 140193036659248 + 140193039461168 [label="encoder.layer.7.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039461168 -> 140193036660112 + 140193036660112 [label=AccumulateGrad] + 140193036631152 -> 140193036630816 + 140193036631152 [label=UnsqueezeBackward0] + 140193036631344 -> 140193036631152 + 140193036631344 [label=UnsqueezeBackward0] + 140193036631248 -> 140193036631344 + 140193036631248 [label=MulBackward0] + 140193036658816 -> 140193036631248 + 140193036658816 [label=IndexBackward0] + 140193036658480 -> 140193036658816 + 140193036658480 [label=SoftmaxBackward0] + 140193036678768 -> 140193036658480 + 140193036678768 [label=CatBackward0] + 140193036677232 -> 140193036678768 + 140193036677232 [label=MmBackward0] + 140201394337440 -> 140193036677232 + 140201394337440 [label=MeanBackward1] + 140193036631776 -> 140201394337440 + 140201394337488 -> 140193036677232 + 140201394337488 [label=TBackward0] + 140201394337296 -> 140201394337488 + 140201394337296 [label=ToCopyBackward0] + 140201394337680 -> 140201394337296 + 140193039467120 [label="encoder.layer.7.experts.gate.weight + (1, 768)" fillcolor=lightblue] + 140193039467120 -> 140201394337680 + 140201394337680 [label=AccumulateGrad] + 140193036680496 -> 140193036678768 + 140193036680496 [label=MmBackward0] + 140201394337632 -> 140193036680496 + 140201394337632 [label=MeanBackward1] + 140193036631920 -> 140201394337632 + 140201394337536 -> 140193036680496 + 140201394337536 [label=TBackward0] + 140201394337296 -> 140201394337536 + 140193036630096 -> 140193036629808 + 140193036630096 [label=UnsqueezeBackward0] + 140193036630864 -> 140193036630096 + 140193036630864 [label=SelectBackward0] + 140193036630288 -> 140193036630864 + 140193036630288 [label=NativeDropoutBackward0] + 140193036631296 -> 140193036630288 + 140193036631296 [label=ViewBackward0] + 140193036656944 -> 140193036631296 + 140193036656944 [label=AddmmBackward0] + 140193036679824 -> 140193036656944 + 140193036679824 [label=ToCopyBackward0] + 140201394337584 -> 140193036679824 + 140193039461248 [label="encoder.layer.7.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140193039461248 -> 140201394337584 + 140201394337584 [label=AccumulateGrad] + 140193036657232 -> 140193036656944 + 140193036657232 [label=ViewBackward0] + 140201394337824 -> 140193036657232 + 140201394337824 [label=GeluBackward0] + 140201394337920 -> 140201394337824 + 140201394337920 [label=ViewBackward0] + 140201394338016 -> 140201394337920 + 140201394338016 [label=AddmmBackward0] + 140201394338112 -> 140201394338016 + 140201394338112 [label=ToCopyBackward0] + 140201394338304 -> 140201394338112 + 140193039461488 [label="encoder.layer.7.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140193039461488 -> 140201394338304 + 140201394338304 [label=AccumulateGrad] + 140201394338064 -> 140201394338016 + 140201394338064 [label=ViewBackward0] + 140201394338352 -> 140201394338064 + 140201394338352 [label=ToCopyBackward0] + 140201394338448 -> 140201394338352 + 140201394338448 [label=IndexBackward0] + 140193036629376 -> 140201394338448 + 140201394337728 -> 140201394338016 + 140201394337728 [label=TBackward0] + 140201394338544 -> 140201394337728 + 140201394338544 [label=ToCopyBackward0] + 140201394338256 -> 140201394338544 + 140193039460928 [label="encoder.layer.7.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039460928 -> 140201394338256 + 140201394338256 [label=AccumulateGrad] + 140201394336912 -> 140193036656944 + 140201394336912 [label=TBackward0] + 140201394337968 -> 140201394336912 + 140201394337968 [label=ToCopyBackward0] + 140201394338208 -> 140201394337968 + 140193039460688 [label="encoder.layer.7.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039460688 -> 140201394338208 + 140201394338208 [label=AccumulateGrad] + 140193036629904 -> 140193036629712 + 140193036629904 [label=CatBackward0] + 140193036630960 -> 140193036629904 + 140193036630960 [label=SliceBackward0] + 140193036657856 -> 140193036630960 + 140193036657856 [label=SliceBackward0] + 140201394338160 -> 140193036657856 + 140201394338160 [label=SliceBackward0] + 140193036630768 -> 140201394338160 + 140193036630672 -> 140193036629904 + 140193036630672 [label=UnsqueezeBackward0] + 140193036630480 -> 140193036630672 + 140193036630480 [label=SelectBackward0] + 140193036630288 -> 140193036630480 + 140193036629616 -> 140193036629712 + 140193036629616 [label=CatBackward0] + 140193036630192 -> 140193036629616 + 140193036630192 [label=SliceBackward0] + 140201394337872 -> 140193036630192 + 140201394337872 [label=SliceBackward0] + 140201394338400 -> 140201394337872 + 140201394338400 [label=SliceBackward0] + 140193036630768 -> 140201394338400 + 140201394337152 -> 140193036629616 + 140201394337152 [label=UnsqueezeBackward0] + 140201394338736 -> 140201394337152 + 140201394338736 [label=SelectBackward0] + 140193036630288 -> 140201394338736 + 140193036630000 -> 140193036629712 + 140193036630000 [label=CatBackward0] + 140201394338640 -> 140193036630000 + 140201394338640 [label=SliceBackward0] + 140201394338784 -> 140201394338640 + 140201394338784 [label=SliceBackward0] + 140201394338880 -> 140201394338784 + 140201394338880 [label=SliceBackward0] + 140193036630768 -> 140201394338880 + 140201394338592 -> 140193036630000 + 140201394338592 [label=UnsqueezeBackward0] + 140201394338976 -> 140201394338592 + 140201394338976 [label=SelectBackward0] + 140193036630288 -> 140201394338976 + 140193036629376 -> 140193036629424 + 140193036629136 -> 140193036628896 + 140193039467040 [label="encoder.layer.7.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039467040 -> 140193036629136 + 140193036629136 [label=AccumulateGrad] + 140193036629232 -> 140193036628896 + 140193039466800 [label="encoder.layer.7.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039466800 -> 140193036629232 + 140193036629232 [label=AccumulateGrad] + 140193036628752 -> 140193036570576 + 140193036628752 [label=IndexBackward0] + 140193036629856 -> 140193036628752 + 140193036629856 [label=NativeLayerNormBackward0] + 140193036629328 -> 140193036629856 + 140193036629328 [label=AddBackward0] + 140201394339024 -> 140193036629328 + 140201394339024 [label=NativeDropoutBackward0] + 140201394339168 -> 140201394339024 + 140201394339168 [label=ViewBackward0] + 140201394339264 -> 140201394339168 + 140201394339264 [label=AddmmBackward0] + 140201394339360 -> 140201394339264 + 140201394339360 [label=ToCopyBackward0] + 140201394339552 -> 140201394339360 + 140193039468800 [label="encoder.layer.7.output.dense.bias + (768)" fillcolor=lightblue] + 140193039468800 -> 140201394339552 + 140201394339552 [label=AccumulateGrad] + 140201394339312 -> 140201394339264 + 140201394339312 [label=ViewBackward0] + 140201394339600 -> 140201394339312 + 140201394339600 [label=GeluBackward0] + 140201394339696 -> 140201394339600 + 140201394339696 [label=ViewBackward0] + 140201394339792 -> 140201394339696 + 140201394339792 [label=AddmmBackward0] + 140201394339504 -> 140201394339792 + 140201394339504 [label=ToCopyBackward0] + 140201394377008 -> 140201394339504 + 140193039469040 [label="encoder.layer.7.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039469040 -> 140201394377008 + 140201394377008 [label=AccumulateGrad] + 140201394376816 -> 140201394339792 + 140201394376816 [label=ViewBackward0] + 140201394377056 -> 140201394376816 + 140201394377056 [label=ToCopyBackward0] + 140201394337776 -> 140201394377056 + 140201394337776 [label=SliceBackward0] + 140201394377200 -> 140201394337776 + 140201394377200 [label=SliceBackward0] + 140201394377296 -> 140201394377200 + 140201394377296 [label=SliceBackward0] + 140193036659440 -> 140201394377296 + 140201394376768 -> 140201394339792 + 140201394376768 [label=TBackward0] + 140201394376960 -> 140201394376768 + 140201394376960 [label=ToCopyBackward0] + 140201394377392 -> 140201394376960 + 140193039469280 [label="encoder.layer.7.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039469280 -> 140201394377392 + 140201394377392 [label=AccumulateGrad] + 140201394339072 -> 140201394339264 + 140201394339072 [label=TBackward0] + 140201394339744 -> 140201394339072 + 140201394339744 [label=ToCopyBackward0] + 140201394339648 -> 140201394339744 + 140193039468720 [label="encoder.layer.7.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039468720 -> 140201394339648 + 140201394339648 [label=AccumulateGrad] + 140201394337776 -> 140193036629328 + 140201394338832 -> 140193036629856 + 140193039468480 [label="encoder.layer.7.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039468480 -> 140201394338832 + 140201394338832 [label=AccumulateGrad] + 140201394338496 -> 140193036629856 + 140193039468560 [label="encoder.layer.7.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039468560 -> 140201394338496 + 140201394338496 [label=AccumulateGrad] + 140193036628032 -> 140193036628176 + 140193036628032 [label=TBackward0] + 140193036628416 -> 140193036628032 + 140193036628416 [label=ToCopyBackward0] + 140193036629520 -> 140193036628416 + 140193039467360 [label="encoder.layer.8.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039467360 -> 140193036629520 + 140193036629520 [label=AccumulateGrad] + 140193036594144 -> 140193036594288 + 140193036594144 [label=UnsafeViewBackward0] + 140193036594960 -> 140193036594144 + 140193036594960 [label=CloneBackward0] + 140193036595152 -> 140193036594960 + 140193036595152 [label=ExpandBackward0] + 140193036628560 -> 140193036595152 + 140193036628560 [label=TransposeBackward0] + 140193036628848 -> 140193036628560 + 140193036628848 [label=PermuteBackward0] + 140193036628080 -> 140193036628848 + 140193036628080 [label=ViewBackward0] + 140201394339216 -> 140193036628080 + 140201394339216 [label=ViewBackward0] + 140201394339456 -> 140201394339216 + 140201394339456 [label=AddmmBackward0] + 140201394339120 -> 140201394339456 + 140201394339120 [label=ToCopyBackward0] + 140201394377104 -> 140201394339120 + 140193039467520 [label="encoder.layer.8.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039467520 -> 140201394377104 + 140201394377104 [label=AccumulateGrad] + 140201394376912 -> 140201394339456 + 140201394376912 [label=ViewBackward0] + 140201394377440 -> 140201394376912 + 140201394377440 [label=ToCopyBackward0] + 140193036570576 -> 140201394377440 + 140201394377152 -> 140201394339456 + 140201394377152 [label=TBackward0] + 140201394376864 -> 140201394377152 + 140201394376864 [label=ToCopyBackward0] + 140201394377584 -> 140201394376864 + 140193039467600 [label="encoder.layer.8.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039467600 -> 140201394377584 + 140201394377584 [label=AccumulateGrad] + 140193036592848 -> 140193036592656 + 140193036592848 [label=UnsafeViewBackward0] + 140193036593184 -> 140193036592848 + 140193036593184 [label=CloneBackward0] + 140193036593616 -> 140193036593184 + 140193036593616 [label=ExpandBackward0] + 140193036593904 -> 140193036593616 + 140193036593904 [label=PermuteBackward0] + 140193036593040 -> 140193036593904 + 140193036593040 [label=ViewBackward0] + 140193036594672 -> 140193036593040 + 140193036594672 [label=ViewBackward0] + 140193036629040 -> 140193036594672 + 140193036629040 [label=AddmmBackward0] + 140193036592944 -> 140193036629040 + 140193036592944 [label=ToCopyBackward0] + 140201394377344 -> 140193036592944 + 140193039466560 [label="encoder.layer.8.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039466560 -> 140201394377344 + 140201394377344 [label=AccumulateGrad] + 140201394338688 -> 140193036629040 + 140201394338688 [label=ViewBackward0] + 140201394377680 -> 140201394338688 + 140201394377680 [label=ToCopyBackward0] + 140193036570576 -> 140201394377680 + 140201394338928 -> 140193036629040 + 140201394338928 [label=TBackward0] + 140201394377248 -> 140201394338928 + 140201394377248 [label=ToCopyBackward0] + 140201394377728 -> 140201394377248 + 140193039466880 [label="encoder.layer.8.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039466880 -> 140201394377728 + 140201394377728 [label=AccumulateGrad] + 140193036591312 -> 140193036591600 + 140193036591312 [label=TBackward0] + 140193036592368 -> 140193036591312 + 140193036592368 [label=ToCopyBackward0] + 140193036592752 -> 140193036592368 + 140193039466640 [label="encoder.layer.8.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039466640 -> 140193036592752 + 140193036592752 [label=AccumulateGrad] + 140193036570576 -> 140193036570192 + 140193036570288 -> 140193036570000 + 140193039461008 [label="encoder.layer.8.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039461008 -> 140193036570288 + 140193036570288 [label=AccumulateGrad] + 140193036569232 -> 140193036570000 + 140193039460768 [label="encoder.layer.8.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039460768 -> 140193036569232 + 140193036569232 [label=AccumulateGrad] + 140193036567840 -> 140193036568800 + 140193036567840 [label=TBackward0] + 140193036569424 -> 140193036567840 + 140193036569424 [label=ToCopyBackward0] + 140193036570096 -> 140193036569424 + 140193039460448 [label="encoder.layer.8.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039460448 -> 140193036570096 + 140193036570096 [label=AccumulateGrad] + 140193036567792 -> 140193036567888 + 140193036567792 [label=UnsafeViewBackward0] + 140193036568464 -> 140193036567792 + 140193036568464 [label=CloneBackward0] + 140193036568848 -> 140193036568464 + 140193036568848 [label=ExpandBackward0] + 140193036569328 -> 140193036568848 + 140193036569328 [label=TransposeBackward0] + 140193036569904 -> 140193036569328 + 140193036569904 [label=PermuteBackward0] + 140193036570384 -> 140193036569904 + 140193036570384 [label=ViewBackward0] + 140193036569520 -> 140193036570384 + 140193036569520 [label=ViewBackward0] + 140193036591696 -> 140193036569520 + 140193036591696 [label=AddmmBackward0] + 140193036592560 -> 140193036591696 + 140193036592560 [label=ToCopyBackward0] + 140193036593424 -> 140193036592560 + 140193039459968 [label="encoder.layer.8.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140193039459968 -> 140193036593424 + 140193036593424 [label=AccumulateGrad] + 140193036591888 -> 140193036591696 + 140193036591888 [label=ViewBackward0] + 140193036593664 -> 140193036591888 + 140193036593664 [label=ToCopyBackward0] + 140193036594768 -> 140193036593664 + 140193036594768 [label=ViewBackward0] + 140201394339408 -> 140193036594768 + 140201394339408 [label=CloneBackward0] + 140193036592272 -> 140201394339408 + 140193036592272 [label=ExpandBackward0] + 140201394377776 -> 140193036592272 + 140201394377776 [label=UnsqueezeBackward0] + 140193036037232 -> 140201394377776 + 140193036591216 -> 140193036591696 + 140193036591216 [label=TBackward0] + 140193036628272 -> 140193036591216 + 140193036628272 [label=ToCopyBackward0] + 140193036594096 -> 140193036628272 + 140193039460288 [label="encoder.layer.8.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140193039460288 -> 140193036594096 + 140193036594096 [label=AccumulateGrad] + 140193036537712 -> 140193036537424 + 140193036537712 [label=UnsafeViewBackward0] + 140193036537808 -> 140193036537712 + 140193036537808 [label=CloneBackward0] + 140193036567120 -> 140193036537808 + 140193036567120 [label=ExpandBackward0] + 140193036567360 -> 140193036567120 + 140193036567360 [label=PermuteBackward0] + 140193036566736 -> 140193036567360 + 140193036566736 [label=ViewBackward0] + 140193036568656 -> 140193036566736 + 140193036568656 [label=ViewBackward0] + 140193036569808 -> 140193036568656 + 140193036569808 [label=AddmmBackward0] + 140193036568176 -> 140193036569808 + 140193036568176 [label=ToCopyBackward0] + 140193036591744 -> 140193036568176 + 140193039459728 [label="encoder.layer.8.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140193039459728 -> 140193036591744 + 140193036591744 [label=AccumulateGrad] + 140193036570240 -> 140193036569808 + 140193036570240 [label=ViewBackward0] + 140193036593136 -> 140193036570240 + 140193036593136 [label=ToCopyBackward0] + 140193036594768 -> 140193036593136 + 140193036566640 -> 140193036569808 + 140193036566640 [label=TBackward0] + 140201394377824 -> 140193036566640 + 140201394377824 [label=ToCopyBackward0] + 140201394377488 -> 140201394377824 + 140193039460048 [label="encoder.layer.8.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140193039460048 -> 140201394377488 + 140201394377488 [label=AccumulateGrad] + 140193036536080 -> 140193036536368 + 140193036536080 [label=TBackward0] + 140193036537232 -> 140193036536080 + 140193036537232 [label=ToCopyBackward0] + 140193036537520 -> 140193036537232 + 140193039459808 [label="encoder.layer.8.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039459808 -> 140193036537520 + 140193036537520 [label=AccumulateGrad] + 140193036535888 -> 140193036535504 + 140193036535600 -> 140193036535264 + 140193039459568 [label="encoder.layer.8.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039459568 -> 140193036535600 + 140193036535600 [label=AccumulateGrad] + 140193036535024 -> 140193036535264 + 140193039459248 [label="encoder.layer.8.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039459248 -> 140193036535024 + 140193036535024 [label=AccumulateGrad] + 140193036533872 -> 140193036534256 + 140193036533872 [label=TBackward0] + 140193036535312 -> 140193036533872 + 140193036535312 [label=ToCopyBackward0] + 140193036535408 -> 140193036535312 + 140193039443744 [label="encoder.layer.8.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039443744 -> 140193036535408 + 140193036535408 [label=AccumulateGrad] + 140193036516560 -> 140193036517040 + 140193036516560 [label=TBackward0] + 140193036534352 -> 140193036516560 + 140193036534352 [label=ToCopyBackward0] + 140193036534544 -> 140193036534352 + 140193039443824 [label="encoder.layer.8.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039443824 -> 140193036534544 + 140193036534544 [label=AccumulateGrad] + 140193036516464 -> 140193036516176 + 140193036516464 [label=UnsqueezeBackward0] + 140193036516848 -> 140193036516464 + 140193036516848 [label=NativeDropoutBackward0] + 140193036517136 -> 140193036516848 + 140193036517136 [label=ViewBackward0] + 140193036535696 -> 140193036517136 + 140193036535696 [label=AddmmBackward0] + 140193036534160 -> 140193036535696 + 140193036534160 [label=ToCopyBackward0] + 140193036536176 -> 140193036534160 + 140193039443504 [label="encoder.layer.8.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140193039443504 -> 140193036536176 + 140193036536176 [label=AccumulateGrad] + 140193036534832 -> 140193036535696 + 140193036534832 [label=ViewBackward0] + 140193036536464 -> 140193036534832 + 140193036536464 [label=GeluBackward0] + 140193036537328 -> 140193036536464 + 140193036537328 [label=ViewBackward0] + 140193036537040 -> 140193036537328 + 140193036537040 [label=AddmmBackward0] + 140193036535984 -> 140193036537040 + 140193036535984 [label=ToCopyBackward0] + 140193036568944 -> 140193036535984 + 140193039444304 [label="encoder.layer.8.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140193039444304 -> 140193036568944 + 140193036568944 [label=AccumulateGrad] + 140193036566880 -> 140193036537040 + 140193036566880 [label=ViewBackward0] + 140193036567600 -> 140193036566880 + 140193036567600 [label=ToCopyBackward0] + 140193036535120 -> 140193036567600 + 140193036566832 -> 140193036537040 + 140193036566832 [label=TBackward0] + 140193036591264 -> 140193036566832 + 140193036591264 [label=ToCopyBackward0] + 140201394377968 -> 140193036591264 + 140193039443584 [label="encoder.layer.8.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039443584 -> 140201394377968 + 140201394377968 [label=AccumulateGrad] + 140193036533824 -> 140193036535696 + 140193036533824 [label=TBackward0] + 140193036536656 -> 140193036533824 + 140193036536656 [label=ToCopyBackward0] + 140193036568368 -> 140193036536656 + 140193039443344 [label="encoder.layer.8.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039443344 -> 140193036568368 + 140193036568368 [label=AccumulateGrad] + 140193036515696 -> 140193036515792 + 140193036515696 [label=UnsqueezeBackward0] + 140193036515888 -> 140193036515696 + 140193036515888 [label=UnsqueezeBackward0] + 140193036516416 -> 140193036515888 + 140193036516416 [label=MulBackward0] + 140193036567312 -> 140193036516416 + 140193036567312 [label=IndexBackward0] + 140193036534784 -> 140193036567312 + 140193036534784 [label=SoftmaxBackward0] + 140193036536752 -> 140193036534784 + 140193036536752 [label=CatBackward0] + 140201394377920 -> 140193036536752 + 140201394377920 [label=MmBackward0] + 140201394378064 -> 140201394377920 + 140201394378064 [label=MeanBackward1] + 140193036516752 -> 140201394378064 + 140201394378016 -> 140201394377920 + 140201394378016 [label=TBackward0] + 140201394378112 -> 140201394378016 + 140201394378112 [label=ToCopyBackward0] + 140201394378304 -> 140201394378112 + 140193039444704 [label="encoder.layer.8.experts.gate.weight + (1, 768)" fillcolor=lightblue] + 140193039444704 -> 140201394378304 + 140201394378304 [label=AccumulateGrad] + 140201394377536 -> 140193036536752 + 140201394377536 [label=MmBackward0] + 140201394378256 -> 140201394377536 + 140201394378256 [label=MeanBackward1] + 140193036516848 -> 140201394378256 + 140201394378160 -> 140201394377536 + 140201394378160 [label=TBackward0] + 140201394378112 -> 140201394378160 + 140193036515024 -> 140193036514640 + 140193036515024 [label=UnsqueezeBackward0] + 140193036515408 -> 140193036515024 + 140193036515408 [label=SelectBackward0] + 140193036515120 -> 140193036515408 + 140193036515120 [label=NativeDropoutBackward0] + 140193036516656 -> 140193036515120 + 140193036516656 [label=ViewBackward0] + 140193036516080 -> 140193036516656 + 140193036516080 [label=AddmmBackward0] + 140193036534640 -> 140193036516080 + 140193036534640 [label=ToCopyBackward0] + 140201394378208 -> 140193036534640 + 140193039443024 [label="encoder.layer.8.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140193039443024 -> 140201394378208 + 140201394378208 [label=AccumulateGrad] + 140193036535792 -> 140193036516080 + 140193036535792 [label=ViewBackward0] + 140201394378448 -> 140193036535792 + 140201394378448 [label=GeluBackward0] + 140201394378544 -> 140201394378448 + 140201394378544 [label=ViewBackward0] + 140201394378640 -> 140201394378544 + 140201394378640 [label=AddmmBackward0] + 140201394378736 -> 140201394378640 + 140201394378736 [label=ToCopyBackward0] + 140201394378928 -> 140201394378736 + 140193039443264 [label="encoder.layer.8.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140193039443264 -> 140201394378928 + 140201394378928 [label=AccumulateGrad] + 140201394378688 -> 140201394378640 + 140201394378688 [label=ViewBackward0] + 140201394378976 -> 140201394378688 + 140201394378976 [label=ToCopyBackward0] + 140201394379072 -> 140201394378976 + 140201394379072 [label=IndexBackward0] + 140193036514352 -> 140201394379072 + 140201394378352 -> 140201394378640 + 140201394378352 [label=TBackward0] + 140201394379168 -> 140201394378352 + 140201394379168 [label=ToCopyBackward0] + 140201394378880 -> 140201394379168 + 140193039443104 [label="encoder.layer.8.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039443104 -> 140201394378880 + 140201394378880 [label=AccumulateGrad] + 140201394377632 -> 140193036516080 + 140201394377632 [label=TBackward0] + 140201394378592 -> 140201394377632 + 140201394378592 [label=ToCopyBackward0] + 140201394378832 -> 140201394378592 + 140193039442864 [label="encoder.layer.8.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039442864 -> 140201394378832 + 140201394378832 [label=AccumulateGrad] + 140193036514448 -> 140193036514256 + 140193036514448 [label=CatBackward0] + 140193036515456 -> 140193036514448 + 140193036515456 [label=SliceBackward0] + 140193036514976 -> 140193036515456 + 140193036514976 [label=SliceBackward0] + 140201394378784 -> 140193036514976 + 140201394378784 [label=SliceBackward0] + 140193036515600 -> 140201394378784 + 140193036515216 -> 140193036514448 + 140193036515216 [label=UnsqueezeBackward0] + 140193036516272 -> 140193036515216 + 140193036516272 [label=SelectBackward0] + 140193036515120 -> 140193036516272 + 140193036514544 -> 140193036514256 + 140193036514544 [label=CatBackward0] + 140193036514736 -> 140193036514544 + 140193036514736 [label=SliceBackward0] + 140201394378496 -> 140193036514736 + 140201394378496 [label=SliceBackward0] + 140201394379024 -> 140201394378496 + 140201394379024 [label=SliceBackward0] + 140193036515600 -> 140201394379024 + 140201394377872 -> 140193036514544 + 140201394377872 [label=UnsqueezeBackward0] + 140201394379360 -> 140201394377872 + 140201394379360 [label=SelectBackward0] + 140193036515120 -> 140201394379360 + 140193036514496 -> 140193036514256 + 140193036514496 [label=CatBackward0] + 140201394379264 -> 140193036514496 + 140201394379264 [label=SliceBackward0] + 140201394379408 -> 140201394379264 + 140201394379408 [label=SliceBackward0] + 140201394379504 -> 140201394379408 + 140201394379504 [label=SliceBackward0] + 140193036515600 -> 140201394379504 + 140201394379216 -> 140193036514496 + 140201394379216 [label=UnsqueezeBackward0] + 140201394379600 -> 140201394379216 + 140201394379600 [label=SelectBackward0] + 140193036515120 -> 140201394379600 + 140193036514352 -> 140193036513968 + 140193036514064 -> 140193036513536 + 140193039445024 [label="encoder.layer.8.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039445024 -> 140193036514064 + 140193036514064 [label=AccumulateGrad] + 140193036513776 -> 140193036513536 + 140193039444784 [label="encoder.layer.8.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039444784 -> 140193036513776 + 140193036513776 [label=AccumulateGrad] + 140193036513392 -> 140193036975408 + 140193036513392 [label=IndexBackward0] + 140193036514832 -> 140193036513392 + 140193036514832 [label=NativeLayerNormBackward0] + 140193036514160 -> 140193036514832 + 140193036514160 [label=AddBackward0] + 140201394379648 -> 140193036514160 + 140201394379648 [label=NativeDropoutBackward0] + 140201394379792 -> 140201394379648 + 140201394379792 [label=ViewBackward0] + 140201394379888 -> 140201394379792 + 140201394379888 [label=AddmmBackward0] + 140201394379984 -> 140201394379888 + 140201394379984 [label=ToCopyBackward0] + 140201394380176 -> 140201394379984 + 140193039458768 [label="encoder.layer.8.output.dense.bias + (768)" fillcolor=lightblue] + 140193039458768 -> 140201394380176 + 140201394380176 [label=AccumulateGrad] + 140201394379936 -> 140201394379888 + 140201394379936 [label=ViewBackward0] + 140201394380224 -> 140201394379936 + 140201394380224 [label=GeluBackward0] + 140201394380320 -> 140201394380224 + 140201394380320 [label=ViewBackward0] + 140201394380416 -> 140201394380320 + 140201394380416 [label=AddmmBackward0] + 140201394380512 -> 140201394380416 + 140201394380512 [label=ToCopyBackward0] + 140201394380704 -> 140201394380512 + 140193039459008 [label="encoder.layer.8.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039459008 -> 140201394380704 + 140201394380704 [label=AccumulateGrad] + 140201394380464 -> 140201394380416 + 140201394380464 [label=ViewBackward0] + 140201394380608 -> 140201394380464 + 140201394380608 [label=ToCopyBackward0] + 140201394378400 -> 140201394380608 + 140201394378400 [label=SliceBackward0] + 140201394430112 -> 140201394378400 + 140201394430112 [label=SliceBackward0] + 140201394430208 -> 140201394430112 + 140201394430208 [label=SliceBackward0] + 140193036570000 -> 140201394430208 + 140201394380128 -> 140201394380416 + 140201394380128 [label=TBackward0] + 140201394380656 -> 140201394380128 + 140201394380656 [label=ToCopyBackward0] + 140201394430304 -> 140201394380656 + 140193039459328 [label="encoder.layer.8.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039459328 -> 140201394430304 + 140201394430304 [label=AccumulateGrad] + 140201394379696 -> 140201394379888 + 140201394379696 [label=TBackward0] + 140201394380368 -> 140201394379696 + 140201394380368 [label=ToCopyBackward0] + 140201394380752 -> 140201394380368 + 140193039459088 [label="encoder.layer.8.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039459088 -> 140201394380752 + 140201394380752 [label=AccumulateGrad] + 140201394378400 -> 140193036514160 + 140201394379456 -> 140193036514832 + 140193039458848 [label="encoder.layer.8.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039458848 -> 140201394379456 + 140201394379456 [label=AccumulateGrad] + 140201394379120 -> 140193036514832 + 140193039458528 [label="encoder.layer.8.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039458528 -> 140201394379120 + 140201394379120 [label=AccumulateGrad] + 140193036483120 -> 140193036484080 + 140193036483120 [label=TBackward0] + 140193036513344 -> 140193036483120 + 140193036513344 [label=ToCopyBackward0] + 140193036514016 -> 140193036513344 + 140193039444944 [label="encoder.layer.9.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039444944 -> 140193036514016 + 140193036514016 [label=AccumulateGrad] + 140193036482880 -> 140193036483024 + 140193036482880 [label=UnsafeViewBackward0] + 140193036483696 -> 140193036482880 + 140193036483696 [label=CloneBackward0] + 140193036483984 -> 140193036483696 + 140193036483984 [label=ExpandBackward0] + 140193036484464 -> 140193036483984 + 140193036484464 [label=TransposeBackward0] + 140193036513488 -> 140193036484464 + 140193036513488 [label=PermuteBackward0] + 140193036483408 -> 140193036513488 + 140193036483408 [label=ViewBackward0] + 140201394379840 -> 140193036483408 + 140201394379840 [label=ViewBackward0] + 140201394380080 -> 140201394379840 + 140201394380080 [label=AddmmBackward0] + 140201394380272 -> 140201394380080 + 140201394380272 [label=ToCopyBackward0] + 140201394430064 -> 140201394380272 + 140193039445504 [label="encoder.layer.9.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039445504 -> 140201394430064 + 140201394430064 [label=AccumulateGrad] + 140201394380560 -> 140201394380080 + 140201394380560 [label=ViewBackward0] + 140201394430352 -> 140201394380560 + 140201394430352 [label=ToCopyBackward0] + 140193036975408 -> 140201394430352 + 140201394379744 -> 140201394380080 + 140201394379744 [label=TBackward0] + 140201394430160 -> 140201394379744 + 140201394430160 [label=ToCopyBackward0] + 140201394430496 -> 140201394430160 + 140193039445184 [label="encoder.layer.9.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039445184 -> 140201394430496 + 140201394430496 [label=AccumulateGrad] + 140193036481584 -> 140193036481392 + 140193036481584 [label=UnsafeViewBackward0] + 140193036481920 -> 140193036481584 + 140193036481920 [label=CloneBackward0] + 140193036482352 -> 140193036481920 + 140193036482352 [label=ExpandBackward0] + 140193036482640 -> 140193036482352 + 140193036482640 [label=PermuteBackward0] + 140193036481776 -> 140193036482640 + 140193036481776 [label=ViewBackward0] + 140193036483888 -> 140193036481776 + 140193036483888 [label=ViewBackward0] + 140193036513680 -> 140193036483888 + 140193036513680 [label=AddmmBackward0] + 140193036481680 -> 140193036513680 + 140193036481680 [label=ToCopyBackward0] + 140201394430256 -> 140193036481680 + 140193039444544 [label="encoder.layer.9.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039444544 -> 140201394430256 + 140201394430256 [label=AccumulateGrad] + 140201394379312 -> 140193036513680 + 140201394379312 [label=ViewBackward0] + 140201394430592 -> 140201394379312 + 140201394430592 [label=ToCopyBackward0] + 140193036975408 -> 140201394430592 + 140201394379552 -> 140193036513680 + 140201394379552 [label=TBackward0] + 140201394430016 -> 140201394379552 + 140201394430016 [label=ToCopyBackward0] + 140201394430640 -> 140201394430016 + 140193039444464 [label="encoder.layer.9.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039444464 -> 140201394430640 + 140201394430640 [label=AccumulateGrad] + 140193036975600 -> 140193036975888 + 140193036975600 [label=TBackward0] + 140193036481104 -> 140193036975600 + 140193036481104 [label=ToCopyBackward0] + 140193036481488 -> 140193036481104 + 140193039444224 [label="encoder.layer.9.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039444224 -> 140193036481488 + 140193036481488 [label=AccumulateGrad] + 140193036975408 -> 140193036975024 + 140193036975120 -> 140193036974832 + 140193039442784 [label="encoder.layer.9.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039442784 -> 140193036975120 + 140193036975120 [label=AccumulateGrad] + 140193036974064 -> 140193036974832 + 140193039442544 [label="encoder.layer.9.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039442544 -> 140193036974064 + 140193036974064 [label=AccumulateGrad] + 140193036972672 -> 140193036973152 + 140193036972672 [label=TBackward0] + 140193036974256 -> 140193036972672 + 140193036974256 [label=ToCopyBackward0] + 140193036974448 -> 140193036974256 + 140193039431216 [label="encoder.layer.9.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039431216 -> 140193036974448 + 140193036974448 [label=AccumulateGrad] + 140193036972240 -> 140193036972528 + 140193036972240 [label=TBackward0] + 140193036973296 -> 140193036972240 + 140193036973296 [label=ToCopyBackward0] + 140193036973584 -> 140193036973296 + 140193039430896 [label="encoder.layer.9.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039430896 -> 140193036973584 + 140193036973584 [label=AccumulateGrad] + 140193036951312 -> 140193036951120 + 140193036951312 [label=UnsqueezeBackward0] + 140193036951504 -> 140193036951312 + 140193036951504 [label=NativeDropoutBackward0] + 140193036972816 -> 140193036951504 + 140193036972816 [label=ViewBackward0] + 140193036974736 -> 140193036972816 + 140193036974736 [label=AddmmBackward0] + 140193036973200 -> 140193036974736 + 140193036973200 [label=ToCopyBackward0] + 140193036975216 -> 140193036973200 + 140193039430976 [label="encoder.layer.9.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140193039430976 -> 140193036975216 + 140193036975216 [label=AccumulateGrad] + 140193036973776 -> 140193036974736 + 140193036973776 [label=ViewBackward0] + 140193036975072 -> 140193036973776 + 140193036975072 [label=GeluBackward0] + 140193036975552 -> 140193036975072 + 140193036975552 [label=ViewBackward0] + 140193036975984 -> 140193036975552 + 140193036975984 [label=AddmmBackward0] + 140193036481872 -> 140193036975984 + 140193036481872 [label=ToCopyBackward0] + 140193036482400 -> 140193036481872 + 140193039431376 [label="encoder.layer.9.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140193039431376 -> 140193036482400 + 140193036482400 [label=AccumulateGrad] + 140193036481296 -> 140193036975984 + 140193036481296 [label=ViewBackward0] + 140193036482832 -> 140193036481296 + 140193036482832 [label=ToCopyBackward0] + 140193036974160 -> 140193036482832 + 140193036480576 -> 140193036975984 + 140193036480576 [label=TBackward0] + 140193036481008 -> 140193036480576 + 140193036481008 [label=ToCopyBackward0] + 140201394380032 -> 140193036481008 + 140193039430656 [label="encoder.layer.9.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039430656 -> 140201394380032 + 140201394380032 [label=AccumulateGrad] + 140193036972144 -> 140193036974736 + 140193036972144 [label=TBackward0] + 140193036974928 -> 140193036972144 + 140193036974928 [label=ToCopyBackward0] + 140193036975504 -> 140193036974928 + 140193039430416 [label="encoder.layer.9.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039430416 -> 140193036975504 + 140193036975504 [label=AccumulateGrad] + 140193036950640 -> 140193036950736 + 140193036950640 [label=UnsqueezeBackward0] + 140193036950688 -> 140193036950640 + 140193036950688 [label=UnsqueezeBackward0] + 140193036951024 -> 140193036950688 + 140193036951024 [label=MulBackward0] + 140193036974640 -> 140193036951024 + 140193036974640 [label=IndexBackward0] + 140193036973872 -> 140193036974640 + 140193036973872 [label=SoftmaxBackward0] + 140193036482160 -> 140193036973872 + 140193036482160 [label=CatBackward0] + 140193036480624 -> 140193036482160 + 140193036480624 [label=MmBackward0] + 140201394430688 -> 140193036480624 + 140201394430688 [label=MeanBackward1] + 140193036951408 -> 140201394430688 + 140201394430736 -> 140193036480624 + 140201394430736 [label=TBackward0] + 140201394430544 -> 140201394430736 + 140201394430544 [label=ToCopyBackward0] + 140201394430928 -> 140201394430544 + 140193039432176 [label="encoder.layer.9.experts.gate.weight + (1, 768)" fillcolor=lightblue] + 140193039432176 -> 140201394430928 + 140201394430928 [label=AccumulateGrad] + 140193036484176 -> 140193036482160 + 140193036484176 [label=MmBackward0] + 140201394430880 -> 140193036484176 + 140201394430880 [label=MeanBackward1] + 140193036951504 -> 140201394430880 + 140201394430784 -> 140193036484176 + 140201394430784 [label=TBackward0] + 140201394430544 -> 140201394430784 + 140193036949872 -> 140193036949584 + 140193036949872 [label=UnsqueezeBackward0] + 140193036950208 -> 140193036949872 + 140193036950208 [label=SelectBackward0] + 140193036950064 -> 140193036950208 + 140193036950064 [label=NativeDropoutBackward0] + 140193036951216 -> 140193036950064 + 140193036951216 [label=ViewBackward0] + 140193036972192 -> 140193036951216 + 140193036972192 [label=AddmmBackward0] + 140193036483504 -> 140193036972192 + 140193036483504 [label=ToCopyBackward0] + 140201394430832 -> 140193036483504 + 140193039430496 [label="encoder.layer.9.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140193039430496 -> 140201394430832 + 140201394430832 [label=AccumulateGrad] + 140193036972624 -> 140193036972192 + 140193036972624 [label=ViewBackward0] + 140201394431072 -> 140193036972624 + 140201394431072 [label=GeluBackward0] + 140201394431168 -> 140201394431072 + 140201394431168 [label=ViewBackward0] + 140201394431264 -> 140201394431168 + 140201394431264 [label=AddmmBackward0] + 140201394431360 -> 140201394431264 + 140201394431360 [label=ToCopyBackward0] + 140201394431552 -> 140201394431360 + 140193039430736 [label="encoder.layer.9.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140193039430736 -> 140201394431552 + 140201394431552 [label=AccumulateGrad] + 140201394431312 -> 140201394431264 + 140201394431312 [label=ViewBackward0] + 140201394431600 -> 140201394431312 + 140201394431600 [label=ToCopyBackward0] + 140201394431696 -> 140201394431600 + 140201394431696 [label=IndexBackward0] + 140193036949296 -> 140201394431696 + 140201394430976 -> 140201394431264 + 140201394430976 [label=TBackward0] + 140201394431792 -> 140201394430976 + 140201394431792 [label=ToCopyBackward0] + 140201394431504 -> 140201394431792 + 140193039430176 [label="encoder.layer.9.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039430176 -> 140201394431504 + 140201394431504 [label=AccumulateGrad] + 140201394430448 -> 140193036972192 + 140201394430448 [label=TBackward0] + 140201394431216 -> 140201394430448 + 140201394431216 [label=ToCopyBackward0] + 140201394431456 -> 140201394431216 + 140193039429936 [label="encoder.layer.9.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039429936 -> 140201394431456 + 140201394431456 [label=AccumulateGrad] + 140193036949248 -> 140193036949200 + 140193036949248 [label=CatBackward0] + 140193036950448 -> 140193036949248 + 140193036950448 [label=SliceBackward0] + 140193036973680 -> 140193036950448 + 140193036973680 [label=SliceBackward0] + 140201394431408 -> 140193036973680 + 140201394431408 [label=SliceBackward0] + 140193036950544 -> 140201394431408 + 140193036950160 -> 140193036949248 + 140193036950160 [label=UnsqueezeBackward0] + 140193036949968 -> 140193036950160 + 140193036949968 [label=SelectBackward0] + 140193036950064 -> 140193036949968 + 140193036949392 -> 140193036949200 + 140193036949392 [label=CatBackward0] + 140193036949680 -> 140193036949392 + 140193036949680 [label=SliceBackward0] + 140201394431120 -> 140193036949680 + 140201394431120 [label=SliceBackward0] + 140201394431648 -> 140201394431120 + 140201394431648 [label=SliceBackward0] + 140193036950544 -> 140201394431648 + 140201394430400 -> 140193036949392 + 140201394430400 [label=UnsqueezeBackward0] + 140201394431984 -> 140201394430400 + 140201394431984 [label=SelectBackward0] + 140193036950064 -> 140201394431984 + 140193036949488 -> 140193036949200 + 140193036949488 [label=CatBackward0] + 140201394431888 -> 140193036949488 + 140201394431888 [label=SliceBackward0] + 140201394432032 -> 140201394431888 + 140201394432032 [label=SliceBackward0] + 140201394432128 -> 140201394432032 + 140201394432128 [label=SliceBackward0] + 140193036950544 -> 140201394432128 + 140201394431840 -> 140193036949488 + 140201394431840 [label=UnsqueezeBackward0] + 140201394432224 -> 140201394431840 + 140201394432224 [label=SelectBackward0] + 140193036950064 -> 140201394432224 + 140193036949296 -> 140193036948768 + 140193036948912 -> 140193036948816 + 140193039432096 [label="encoder.layer.9.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039432096 -> 140193036948912 + 140193036948912 [label=AccumulateGrad] + 140193036948720 -> 140193036948816 + 140193039431856 [label="encoder.layer.9.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039431856 -> 140193036948720 + 140193036948720 [label=AccumulateGrad] + 140193036948240 -> 140193036914800 + 140193036948240 [label=IndexBackward0] + 140193036949776 -> 140193036948240 + 140193036949776 [label=NativeLayerNormBackward0] + 140193036949104 -> 140193036949776 + 140193036949104 [label=AddBackward0] + 140201394432272 -> 140193036949104 + 140201394432272 [label=NativeDropoutBackward0] + 140201394432416 -> 140201394432272 + 140201394432416 [label=ViewBackward0] + 140201394432512 -> 140201394432416 + 140201394432512 [label=AddmmBackward0] + 140201394432608 -> 140201394432512 + 140201394432608 [label=ToCopyBackward0] + 140201394432800 -> 140201394432608 + 140193039442144 [label="encoder.layer.9.output.dense.bias + (768)" fillcolor=lightblue] + 140193039442144 -> 140201394432800 + 140201394432800 [label=AccumulateGrad] + 140201394432560 -> 140201394432512 + 140201394432560 [label=ViewBackward0] + 140201394432848 -> 140201394432560 + 140201394432848 [label=GeluBackward0] + 140201394432944 -> 140201394432848 + 140201394432944 [label=ViewBackward0] + 140201394433040 -> 140201394432944 + 140201394433040 [label=AddmmBackward0] + 140201394433136 -> 140201394433040 + 140201394433136 [label=ToCopyBackward0] + 140201394433328 -> 140201394433136 + 140193039442384 [label="encoder.layer.9.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039442384 -> 140201394433328 + 140201394433328 [label=AccumulateGrad] + 140201394433088 -> 140201394433040 + 140201394433088 [label=ViewBackward0] + 140201394433376 -> 140201394433088 + 140201394433376 [label=ToCopyBackward0] + 140201394431024 -> 140201394433376 + 140201394431024 [label=SliceBackward0] + 140201394433520 -> 140201394431024 + 140201394433520 [label=SliceBackward0] + 140201394433616 -> 140201394433520 + 140201394433616 [label=SliceBackward0] + 140193036974832 -> 140201394433616 + 140201394432752 -> 140201394433040 + 140201394432752 [label=TBackward0] + 140201394433280 -> 140201394432752 + 140201394433280 [label=ToCopyBackward0] + 140201394433712 -> 140201394433280 + 140193039442624 [label="encoder.layer.9.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039442624 -> 140201394433712 + 140201394433712 [label=AccumulateGrad] + 140201394432320 -> 140201394432512 + 140201394432320 [label=TBackward0] + 140201394432992 -> 140201394432320 + 140201394432992 [label=ToCopyBackward0] + 140201394433472 -> 140201394432992 + 140193039442064 [label="encoder.layer.9.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039442064 -> 140201394433472 + 140201394433472 [label=AccumulateGrad] + 140201394431024 -> 140193036949104 + 140201394432080 -> 140193036949776 + 140193039433536 [label="encoder.layer.9.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039433536 -> 140201394432080 + 140201394432080 [label=AccumulateGrad] + 140201394431744 -> 140193036949776 + 140193039433616 [label="encoder.layer.9.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039433616 -> 140201394431744 + 140201394431744 [label=AccumulateGrad] + 140193036947664 -> 140193036947952 + 140193036947664 [label=TBackward0] + 140193036948336 -> 140193036947664 + 140193036948336 [label=ToCopyBackward0] + 140193036949008 -> 140193036948336 + 140193039432416 [label="encoder.layer.10.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039432416 -> 140193036949008 + 140193036949008 [label=AccumulateGrad] + 140193036918160 -> 140193036917872 + 140193036918160 [label=UnsafeViewBackward0] + 140193036918544 -> 140193036918160 + 140193036918544 [label=CloneBackward0] + 140193036918256 -> 140193036918544 + 140193036918256 [label=ExpandBackward0] + 140193036948048 -> 140193036918256 + 140193036948048 [label=TransposeBackward0] + 140193036948624 -> 140193036948048 + 140193036948624 [label=PermuteBackward0] + 140193036947568 -> 140193036948624 + 140193036947568 [label=ViewBackward0] + 140201394432464 -> 140193036947568 + 140201394432464 [label=ViewBackward0] + 140201394432704 -> 140201394432464 + 140201394432704 [label=AddmmBackward0] + 140201394433232 -> 140201394432704 + 140201394433232 [label=ToCopyBackward0] + 140201394433424 -> 140201394433232 + 140193039432576 [label="encoder.layer.10.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039432576 -> 140201394433424 + 140201394433424 [label=AccumulateGrad] + 140201394433184 -> 140201394432704 + 140201394433184 [label=ViewBackward0] + 140201394433760 -> 140201394433184 + 140201394433760 [label=ToCopyBackward0] + 140193036914800 -> 140201394433760 + 140201394432368 -> 140201394432704 + 140201394432368 [label=TBackward0] + 140201394432896 -> 140201394432368 + 140201394432896 [label=ToCopyBackward0] + 140201394433904 -> 140201394432896 + 140193039432656 [label="encoder.layer.10.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039432656 -> 140201394433904 + 140201394433904 [label=AccumulateGrad] + 140193036916432 -> 140193036916528 + 140193036916432 [label=UnsafeViewBackward0] + 140193036917200 -> 140193036916432 + 140193036917200 [label=CloneBackward0] + 140193036917488 -> 140193036917200 + 140193036917488 [label=ExpandBackward0] + 140193036917776 -> 140193036917488 + 140193036917776 [label=PermuteBackward0] + 140193036916624 -> 140193036917776 + 140193036916624 [label=ViewBackward0] + 140193036918112 -> 140193036916624 + 140193036918112 [label=ViewBackward0] + 140193036948528 -> 140193036918112 + 140193036948528 [label=AddmmBackward0] + 140193036916816 -> 140193036948528 + 140193036916816 [label=ToCopyBackward0] + 140201394433664 -> 140193036916816 + 140193039431616 [label="encoder.layer.10.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039431616 -> 140201394433664 + 140201394433664 [label=AccumulateGrad] + 140201394431936 -> 140193036948528 + 140201394431936 [label=ViewBackward0] + 140201394434000 -> 140201394431936 + 140201394434000 [label=ToCopyBackward0] + 140193036914800 -> 140201394434000 + 140201394432176 -> 140193036948528 + 140201394432176 [label=TBackward0] + 140201394433568 -> 140201394432176 + 140201394433568 [label=ToCopyBackward0] + 140201394433856 -> 140201394433568 + 140193039431936 [label="encoder.layer.10.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039431936 -> 140201394433856 + 140201394433856 [label=AccumulateGrad] + 140193036914896 -> 140193036915184 + 140193036914896 [label=TBackward0] + 140193036915952 -> 140193036914896 + 140193036915952 [label=ToCopyBackward0] + 140193036916192 -> 140193036915952 + 140193039431696 [label="encoder.layer.10.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039431696 -> 140193036916192 + 140193036916192 [label=AccumulateGrad] + 140193036914800 -> 140193036894064 + 140193036893728 -> 140193036893872 + 140193039430256 [label="encoder.layer.10.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039430256 -> 140193036893728 + 140193036893728 [label=AccumulateGrad] + 140193036893104 -> 140193036893872 + 140193039430016 [label="encoder.layer.10.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039430016 -> 140193036893104 + 140193036893104 [label=AccumulateGrad] + 140193036891856 -> 140193036892816 + 140193036891856 [label=TBackward0] + 140193036893008 -> 140193036891856 + 140193036893008 [label=ToCopyBackward0] + 140193036893680 -> 140193036893008 + 140193039429696 [label="encoder.layer.10.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039429696 -> 140193036893680 + 140193036893680 [label=AccumulateGrad] + 140193036891664 -> 140193036891328 + 140193036891664 [label=UnsafeViewBackward0] + 140193036892048 -> 140193036891664 + 140193036892048 [label=CloneBackward0] + 140193036892288 -> 140193036892048 + 140193036892288 [label=ExpandBackward0] + 140193036892768 -> 140193036892288 + 140193036892768 [label=TransposeBackward0] + 140193036893488 -> 140193036892768 + 140193036893488 [label=PermuteBackward0] + 140193036893392 -> 140193036893488 + 140193036893392 [label=ViewBackward0] + 140193036891760 -> 140193036893392 + 140193036891760 [label=ViewBackward0] + 140193036915568 -> 140193036891760 + 140193036915568 [label=AddmmBackward0] + 140193036916144 -> 140193036915568 + 140193036916144 [label=ToCopyBackward0] + 140193036917296 -> 140193036916144 + 140193039420928 [label="encoder.layer.10.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140193039420928 -> 140193036917296 + 140193036917296 [label=AccumulateGrad] + 140193036915472 -> 140193036915568 + 140193036915472 [label=ViewBackward0] + 140193036917680 -> 140193036915472 + 140193036917680 [label=ToCopyBackward0] + 140193036918352 -> 140193036917680 + 140193036918352 [label=ViewBackward0] + 140193036947760 -> 140193036918352 + 140193036947760 [label=CloneBackward0] + 140201394433952 -> 140193036947760 + 140201394433952 [label=ExpandBackward0] + 140201394432656 -> 140201394433952 + 140201394432656 [label=UnsqueezeBackward0] + 140193036037232 -> 140201394432656 + 140193036914752 -> 140193036915568 + 140193036914752 [label=TBackward0] + 140193036915712 -> 140193036914752 + 140193036915712 [label=ToCopyBackward0] + 140201394433808 -> 140193036915712 + 140193039421248 [label="encoder.layer.10.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140193039421248 -> 140201394433808 + 140201394433808 [label=AccumulateGrad] + 140193036852288 -> 140193036853104 + 140193036852288 [label=UnsafeViewBackward0] + 140193036890704 -> 140193036852288 + 140193036890704 [label=CloneBackward0] + 140193036890992 -> 140193036890704 + 140193036890992 [label=ExpandBackward0] + 140193036891376 -> 140193036890992 + 140193036891376 [label=PermuteBackward0] + 140193036890224 -> 140193036891376 + 140193036890224 [label=ViewBackward0] + 140193036892240 -> 140193036890224 + 140193036892240 [label=ViewBackward0] + 140193036893248 -> 140193036892240 + 140193036893248 [label=AddmmBackward0] + 140193036893968 -> 140193036893248 + 140193036893968 [label=ToCopyBackward0] + 140193036917968 -> 140193036893968 + 140193039420688 [label="encoder.layer.10.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140193039420688 -> 140193036917968 + 140193036917968 [label=AccumulateGrad] + 140193036890416 -> 140193036893248 + 140193036890416 [label=ViewBackward0] + 140193036917008 -> 140193036890416 + 140193036917008 [label=ToCopyBackward0] + 140193036918352 -> 140193036917008 + 140193036915088 -> 140193036893248 + 140193036915088 [label=TBackward0] + 140193036915760 -> 140193036915088 + 140193036915760 [label=ToCopyBackward0] + 140201394495696 -> 140193036915760 + 140193039421008 [label="encoder.layer.10.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140193039421008 -> 140201394495696 + 140201394495696 [label=AccumulateGrad] + 140193036851856 -> 140193036852048 + 140193036851856 [label=TBackward0] + 140193036852480 -> 140193036851856 + 140193036852480 [label=ToCopyBackward0] + 140193036852912 -> 140193036852480 + 140193039420768 [label="encoder.layer.10.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039420768 -> 140193036852912 + 140193036852912 [label=AccumulateGrad] + 140193036851760 -> 140193036851616 + 140193036851568 -> 140193036851520 + 140193039420528 [label="encoder.layer.10.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039420528 -> 140193036851568 + 140193036851568 [label=AccumulateGrad] + 140193036851328 -> 140193036851520 + 140193039420208 [label="encoder.layer.10.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039420208 -> 140193036851328 + 140193036851328 [label=AccumulateGrad] + 140193036850560 -> 140193036850848 + 140193036850560 [label=TBackward0] + 140193036851376 -> 140193036850560 + 140193036851376 [label=ToCopyBackward0] + 140193036851472 -> 140193036851376 + 140193039417408 [label="encoder.layer.10.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039417408 -> 140193036851472 + 140193036851472 [label=AccumulateGrad] + 140193036850032 -> 140193036850320 + 140193036850032 [label=TBackward0] + 140193036850800 -> 140193036850032 + 140193036850800 [label=ToCopyBackward0] + 140193036851040 -> 140193036850800 + 140193039404784 [label="encoder.layer.10.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039404784 -> 140193036851040 + 140193036851040 [label=AccumulateGrad] + 140193036849936 -> 140193036849888 + 140193036849936 [label=UnsqueezeBackward0] + 140193036850272 -> 140193036849936 + 140193036850272 [label=NativeDropoutBackward0] + 140193036850512 -> 140193036850272 + 140193036850512 [label=ViewBackward0] + 140193036851712 -> 140193036850512 + 140193036851712 [label=AddmmBackward0] + 140193036850704 -> 140193036851712 + 140193036850704 [label=ToCopyBackward0] + 140193036852000 -> 140193036850704 + 140193039403984 [label="encoder.layer.10.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140193039403984 -> 140193036852000 + 140193036852000 [label=AccumulateGrad] + 140193036851088 -> 140193036851712 + 140193036851088 [label=ViewBackward0] + 140193036852192 -> 140193036851088 + 140193036852192 [label=GeluBackward0] + 140193036852720 -> 140193036852192 + 140193036852720 [label=ViewBackward0] + 140193036852432 -> 140193036852720 + 140193036852432 [label=AddmmBackward0] + 140193036890896 -> 140193036852432 + 140193036890896 [label=ToCopyBackward0] + 140193036892528 -> 140193036890896 + 140193039404224 [label="encoder.layer.10.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140193039404224 -> 140193036892528 + 140193036892528 [label=AccumulateGrad] + 140193036890176 -> 140193036852432 + 140193036890176 [label=ViewBackward0] + 140193036891472 -> 140193036890176 + 140193036891472 [label=ToCopyBackward0] + 140193036851280 -> 140193036891472 + 140193036890512 -> 140193036852432 + 140193036890512 [label=TBackward0] + 140193036915280 -> 140193036890512 + 140193036915280 [label=ToCopyBackward0] + 140201394495792 -> 140193036915280 + 140193039404544 [label="encoder.layer.10.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039404544 -> 140201394495792 + 140201394495792 [label=AccumulateGrad] + 140193036850080 -> 140193036851712 + 140193036850080 [label=TBackward0] + 140193036851904 -> 140193036850080 + 140193036851904 [label=ToCopyBackward0] + 140193036891808 -> 140193036851904 + 140193039404304 [label="encoder.layer.10.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039404304 -> 140193036891808 + 140193036891808 [label=AccumulateGrad] + 140193036849600 -> 140193036849552 + 140193036849600 [label=UnsqueezeBackward0] + 140193036891184 -> 140193036849600 + 140193036891184 [label=UnsqueezeBackward0] + 140193036850176 -> 140193036891184 + 140193036850176 [label=MulBackward0] + 140193036850992 -> 140193036850176 + 140193036850992 [label=IndexBackward0] + 140193036851232 -> 140193036850992 + 140193036851232 [label=SoftmaxBackward0] + 140193036852240 -> 140193036851232 + 140193036852240 [label=CatBackward0] + 140201394495744 -> 140193036852240 + 140201394495744 [label=MmBackward0] + 140201394495888 -> 140201394495744 + 140201394495888 [label=MeanBackward1] + 140193036850128 -> 140201394495888 + 140201394495840 -> 140201394495744 + 140201394495840 [label=TBackward0] + 140201394495936 -> 140201394495840 + 140201394495936 [label=ToCopyBackward0] + 140201394496128 -> 140201394495936 + 140193039418048 [label="encoder.layer.10.experts.gate.weight + (1, 768)" fillcolor=lightblue] + 140193039418048 -> 140201394496128 + 140201394496128 [label=AccumulateGrad] + 140201394495552 -> 140193036852240 + 140201394495552 [label=MmBackward0] + 140201394496080 -> 140201394495552 + 140201394496080 [label=MeanBackward1] + 140193036850272 -> 140201394496080 + 140201394495984 -> 140201394495552 + 140201394495984 [label=TBackward0] + 140201394495936 -> 140201394495984 + 140193036803808 -> 140193036803616 + 140193036803808 [label=UnsqueezeBackward0] + 140193036849408 -> 140193036803808 + 140193036849408 [label=SelectBackward0] + 140193036849216 -> 140193036849408 + 140193036849216 [label=NativeDropoutBackward0] + 140193036849696 -> 140193036849216 + 140193036849696 [label=ViewBackward0] + 140193036851664 -> 140193036849696 + 140193036851664 [label=AddmmBackward0] + 140193036849744 -> 140193036851664 + 140193036849744 [label=ToCopyBackward0] + 140201394496032 -> 140193036849744 + 140193039403504 [label="encoder.layer.10.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140193039403504 -> 140201394496032 + 140201394496032 [label=AccumulateGrad] + 140193036849264 -> 140193036851664 + 140193036849264 [label=ViewBackward0] + 140201394496272 -> 140193036849264 + 140201394496272 [label=GeluBackward0] + 140201394496368 -> 140201394496272 + 140201394496368 [label=ViewBackward0] + 140201394496464 -> 140201394496368 + 140201394496464 [label=AddmmBackward0] + 140201394496560 -> 140201394496464 + 140201394496560 [label=ToCopyBackward0] + 140201394496752 -> 140201394496560 + 140193039403744 [label="encoder.layer.10.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140193039403744 -> 140201394496752 + 140201394496752 [label=AccumulateGrad] + 140201394496512 -> 140201394496464 + 140201394496512 [label=ViewBackward0] + 140201394496800 -> 140201394496512 + 140201394496800 [label=ToCopyBackward0] + 140201394496896 -> 140201394496800 + 140201394496896 [label=IndexBackward0] + 140193036803424 -> 140201394496896 + 140201394496176 -> 140201394496464 + 140201394496176 [label=TBackward0] + 140201394496992 -> 140201394496176 + 140201394496992 [label=ToCopyBackward0] + 140201394496704 -> 140201394496992 + 140193039404064 [label="encoder.layer.10.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039404064 -> 140201394496704 + 140201394496704 [label=AccumulateGrad] + 140201394495648 -> 140193036851664 + 140201394495648 [label=TBackward0] + 140201394496416 -> 140201394495648 + 140201394496416 [label=ToCopyBackward0] + 140201394496656 -> 140201394496416 + 140193039403824 [label="encoder.layer.10.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039403824 -> 140201394496656 + 140201394496656 [label=AccumulateGrad] + 140193036803760 -> 140193036803664 + 140193036803760 [label=CatBackward0] + 140193036803952 -> 140193036803760 + 140193036803952 [label=SliceBackward0] + 140193036850464 -> 140193036803952 + 140193036850464 [label=SliceBackward0] + 140201394496608 -> 140193036850464 + 140201394496608 [label=SliceBackward0] + 140193036849456 -> 140201394496608 + 140193036849504 -> 140193036803760 + 140193036849504 [label=UnsqueezeBackward0] + 140193036849840 -> 140193036849504 + 140193036849840 [label=SelectBackward0] + 140193036849216 -> 140193036849840 + 140193036803520 -> 140193036803664 + 140193036803520 [label=CatBackward0] + 140193036849312 -> 140193036803520 + 140193036849312 [label=SliceBackward0] + 140201394496320 -> 140193036849312 + 140201394496320 [label=SliceBackward0] + 140201394496848 -> 140201394496320 + 140201394496848 [label=SliceBackward0] + 140193036849456 -> 140201394496848 + 140201394495600 -> 140193036803520 + 140201394495600 [label=UnsqueezeBackward0] + 140201394497184 -> 140201394495600 + 140201394497184 [label=SelectBackward0] + 140193036849216 -> 140201394497184 + 140193036803856 -> 140193036803664 + 140193036803856 [label=CatBackward0] + 140201394497088 -> 140193036803856 + 140201394497088 [label=SliceBackward0] + 140201394497232 -> 140201394497088 + 140201394497232 [label=SliceBackward0] + 140201394497328 -> 140201394497232 + 140201394497328 [label=SliceBackward0] + 140193036849456 -> 140201394497328 + 140201394497040 -> 140193036803856 + 140201394497040 [label=UnsqueezeBackward0] + 140201394497424 -> 140201394497040 + 140201394497424 [label=SelectBackward0] + 140193036849216 -> 140201394497424 + 140193036803424 -> 140193036803472 + 140193036803232 -> 140193036803136 + 140193039418368 [label="encoder.layer.10.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039418368 -> 140193036803232 + 140193036803232 [label=AccumulateGrad] + 140193036803376 -> 140193036803136 + 140193039418128 [label="encoder.layer.10.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039418128 -> 140193036803376 + 140193036803376 [label=AccumulateGrad] + 140193036803088 -> 140193036800208 + 140193036803088 [label=IndexBackward0] + 140193036803712 -> 140193036803088 + 140193036803712 [label=NativeLayerNormBackward0] + 140193036803328 -> 140193036803712 + 140193036803328 [label=AddBackward0] + 140201394497472 -> 140193036803328 + 140201394497472 [label=NativeDropoutBackward0] + 140201394497616 -> 140201394497472 + 140201394497616 [label=ViewBackward0] + 140201394497712 -> 140201394497616 + 140201394497712 [label=AddmmBackward0] + 140201394497808 -> 140201394497712 + 140201394497808 [label=ToCopyBackward0] + 140201394498000 -> 140201394497808 + 140193039419728 [label="encoder.layer.10.output.dense.bias + (768)" fillcolor=lightblue] + 140193039419728 -> 140201394498000 + 140201394498000 [label=AccumulateGrad] + 140201394497760 -> 140201394497712 + 140201394497760 [label=ViewBackward0] + 140201394498048 -> 140201394497760 + 140201394498048 [label=GeluBackward0] + 140201394498144 -> 140201394498048 + 140201394498144 [label=ViewBackward0] + 140201394498240 -> 140201394498144 + 140201394498240 [label=AddmmBackward0] + 140201394498336 -> 140201394498240 + 140201394498336 [label=ToCopyBackward0] + 140201394498528 -> 140201394498336 + 140193039419968 [label="encoder.layer.10.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039419968 -> 140201394498528 + 140201394498528 [label=AccumulateGrad] + 140201394498288 -> 140201394498240 + 140201394498288 [label=ViewBackward0] + 140201394498576 -> 140201394498288 + 140201394498576 [label=ToCopyBackward0] + 140201394496224 -> 140201394498576 + 140201394496224 [label=SliceBackward0] + 140201394498720 -> 140201394496224 + 140201394498720 [label=SliceBackward0] + 140201394498816 -> 140201394498720 + 140201394498816 [label=SliceBackward0] + 140193036893872 -> 140201394498816 + 140201394497952 -> 140201394498240 + 140201394497952 [label=TBackward0] + 140201394498480 -> 140201394497952 + 140201394498480 [label=ToCopyBackward0] + 140201394498912 -> 140201394498480 + 140193039420288 [label="encoder.layer.10.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039420288 -> 140201394498912 + 140201394498912 [label=AccumulateGrad] + 140201394497520 -> 140201394497712 + 140201394497520 [label=TBackward0] + 140201394498192 -> 140201394497520 + 140201394498192 [label=ToCopyBackward0] + 140201394498672 -> 140201394498192 + 140193039420048 [label="encoder.layer.10.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039420048 -> 140201394498672 + 140201394498672 [label=AccumulateGrad] + 140201394496224 -> 140193036803328 + 140201394497280 -> 140193036803712 + 140193039419808 [label="encoder.layer.10.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039419808 -> 140201394497280 + 140201394497280 [label=AccumulateGrad] + 140201394496944 -> 140193036803712 + 140193039419488 [label="encoder.layer.10.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039419488 -> 140201394496944 + 140201394496944 [label=AccumulateGrad] + 140193036802080 -> 140193036802656 + 140193036802080 [label=TBackward0] + 140193036802848 -> 140193036802080 + 140193036802848 [label=ToCopyBackward0] + 140193036803568 -> 140193036802848 + 140193039418288 [label="encoder.layer.11.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140193039418288 -> 140193036803568 + 140193036803568 [label=AccumulateGrad] + 140193036801984 -> 140193036802128 + 140193036801984 [label=UnsafeViewBackward0] + 140193036802512 -> 140193036801984 + 140193036802512 [label=CloneBackward0] + 140193036802704 -> 140193036802512 + 140193036802704 [label=ExpandBackward0] + 140193036802992 -> 140193036802704 + 140193036802992 [label=TransposeBackward0] + 140193036803040 -> 140193036802992 + 140193036803040 [label=PermuteBackward0] + 140193036802320 -> 140193036803040 + 140193036802320 [label=ViewBackward0] + 140201394497664 -> 140193036802320 + 140201394497664 [label=ViewBackward0] + 140201394497904 -> 140201394497664 + 140201394497904 [label=AddmmBackward0] + 140201394498432 -> 140201394497904 + 140201394498432 [label=ToCopyBackward0] + 140201394498624 -> 140201394498432 + 140193039418848 [label="encoder.layer.11.attention.self.key.bias + (768)" fillcolor=lightblue] + 140193039418848 -> 140201394498624 + 140201394498624 [label=AccumulateGrad] + 140201394498384 -> 140201394497904 + 140201394498384 [label=ViewBackward0] + 140201394498960 -> 140201394498384 + 140201394498960 [label=ToCopyBackward0] + 140193036800208 -> 140201394498960 + 140201394497568 -> 140201394497904 + 140201394497568 [label=TBackward0] + 140201394498096 -> 140201394497568 + 140201394498096 [label=ToCopyBackward0] + 140201394499104 -> 140201394498096 + 140193039418528 [label="encoder.layer.11.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140193039418528 -> 140201394499104 + 140201394499104 [label=AccumulateGrad] + 140193036801264 -> 140193036801024 + 140193036801264 [label=UnsafeViewBackward0] + 140193036801408 -> 140193036801264 + 140193036801408 [label=CloneBackward0] + 140193036801600 -> 140193036801408 + 140193036801600 [label=ExpandBackward0] + 140193036801792 -> 140193036801600 + 140193036801792 [label=PermuteBackward0] + 140193036801360 -> 140193036801792 + 140193036801360 [label=ViewBackward0] + 140193036802608 -> 140193036801360 + 140193036802608 [label=ViewBackward0] + 140193036803280 -> 140193036802608 + 140193036803280 [label=AddmmBackward0] + 140193036801216 -> 140193036803280 + 140193036801216 [label=ToCopyBackward0] + 140201394498864 -> 140193036801216 + 140193039417888 [label="encoder.layer.11.attention.self.value.bias + (768)" fillcolor=lightblue] + 140193039417888 -> 140201394498864 + 140201394498864 [label=AccumulateGrad] + 140201394497136 -> 140193036803280 + 140201394497136 [label=ViewBackward0] + 140201394499200 -> 140201394497136 + 140201394499200 [label=ToCopyBackward0] + 140193036800208 -> 140201394499200 + 140201394497376 -> 140193036803280 + 140201394497376 [label=TBackward0] + 140201394498768 -> 140201394497376 + 140201394498768 [label=ToCopyBackward0] + 140201394499248 -> 140201394498768 + 140193039417808 [label="encoder.layer.11.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140193039417808 -> 140201394499248 + 140201394499248 [label=AccumulateGrad] + 140193036800304 -> 140193036800496 + 140193036800304 [label=TBackward0] + 140193036800976 -> 140193036800304 + 140193036800976 [label=ToCopyBackward0] + 140193036801168 -> 140193036800976 + 140193039417568 [label="encoder.layer.11.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140193039417568 -> 140193036801168 + 140193036801168 [label=AccumulateGrad] + 140193036800208 -> 140193578430272 + 140193578430416 -> 140193578430176 + 140193039403264 [label="encoder.layer.11.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039403264 -> 140193578430416 + 140193578430416 [label=AccumulateGrad] + 140193578429696 -> 140193578430176 + 140193039404944 [label="encoder.layer.11.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039404944 -> 140193578429696 + 140193578429696 [label=AccumulateGrad] + 140193578428928 -> 140193578429216 + 140193578428928 [label=TBackward0] + 140193578429936 -> 140193578428928 + 140193578429936 [label=ToCopyBackward0] + 140193578430032 -> 140193578429936 + 140193039388080 [label="encoder.layer.11.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039388080 -> 140193578430032 + 140193578430032 [label=AccumulateGrad] + 140193578428592 -> 140193578428880 + 140193578428592 [label=TBackward0] + 140193578429360 -> 140193578428592 + 140193578429360 [label=ToCopyBackward0] + 140193578429408 -> 140193578429360 + 140193039387760 [label="encoder.layer.11.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039387760 -> 140193578429408 + 140193578429408 [label=AccumulateGrad] + 140193578428496 -> 140193578428256 + 140193578428496 [label=UnsqueezeBackward0] + 140193578428640 -> 140193578428496 + 140193578428640 [label=NativeDropoutBackward0] + 140193578429072 -> 140193578428640 + 140193578429072 [label=ViewBackward0] + 140193578430224 -> 140193578429072 + 140193578430224 [label=AddmmBackward0] + 140193578429264 -> 140193578430224 + 140193578429264 [label=ToCopyBackward0] + 140193578430368 -> 140193578429264 + 140193039387840 [label="encoder.layer.11.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140193039387840 -> 140193578430368 + 140193578430368 [label=AccumulateGrad] + 140193578429648 -> 140193578430224 + 140193578429648 [label=ViewBackward0] + 140193578429600 -> 140193578429648 + 140193578429600 [label=GeluBackward0] + 140193036800256 -> 140193578429600 + 140193036800256 [label=ViewBackward0] + 140193036800688 -> 140193036800256 + 140193036800688 [label=AddmmBackward0] + 140193036801312 -> 140193036800688 + 140193036801312 [label=ToCopyBackward0] + 140193036801696 -> 140193036801312 + 140193039388240 [label="encoder.layer.11.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140193039388240 -> 140193036801696 + 140193036801696 [label=AccumulateGrad] + 140193036801072 -> 140193036800688 + 140193036801072 [label=ViewBackward0] + 140193036801888 -> 140193036801072 + 140193036801888 [label=ToCopyBackward0] + 140193578429840 -> 140193036801888 + 140193036800160 -> 140193036800688 + 140193036800160 [label=TBackward0] + 140193036800880 -> 140193036800160 + 140193036800880 [label=ToCopyBackward0] + 140193036801504 -> 140193036800880 + 140193039387520 [label="encoder.layer.11.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039387520 -> 140193036801504 + 140193036801504 [label=AccumulateGrad] + 140193578428448 -> 140193578430224 + 140193578428448 [label=TBackward0] + 140193578430320 -> 140193578428448 + 140193578430320 [label=ToCopyBackward0] + 140193036802800 -> 140193578430320 + 140193039387280 [label="encoder.layer.11.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039387280 -> 140193036802800 + 140193036802800 [label=AccumulateGrad] + 140193578426720 -> 140193578427056 + 140193578426720 [label=UnsqueezeBackward0] + 140193578426768 -> 140193578426720 + 140193578426768 [label=UnsqueezeBackward0] + 140193578428832 -> 140193578426768 + 140193578428832 [label=MulBackward0] + 140193578430128 -> 140193578428832 + 140193578430128 [label=IndexBackward0] + 140193578426576 -> 140193578430128 + 140193578426576 [label=SoftmaxBackward0] + 140193036800112 -> 140193578426576 + 140193036800112 [label=CatBackward0] + 140193036800448 -> 140193036800112 + 140193036800448 [label=MmBackward0] + 140201394499296 -> 140193036800448 + 140201394499296 [label=MeanBackward1] + 140193578428688 -> 140201394499296 + 140201394499344 -> 140193036800448 + 140201394499344 [label=TBackward0] + 140201394499152 -> 140201394499344 + 140201394499152 [label=ToCopyBackward0] + 140201394499536 -> 140201394499152 + 140193039401424 [label="encoder.layer.11.experts.gate.weight + (1, 768)" fillcolor=lightblue] + 140193039401424 -> 140201394499536 + 140201394499536 [label=AccumulateGrad] + 140201394499008 -> 140193036800112 + 140201394499008 [label=MmBackward0] + 140201394499488 -> 140201394499008 + 140201394499488 [label=MeanBackward1] + 140193578428640 -> 140201394499488 + 140201394499392 -> 140201394499008 + 140201394499392 [label=TBackward0] + 140201394499152 -> 140201394499392 + 140193578427776 -> 140193578427584 + 140193578427776 [label=UnsqueezeBackward0] + 140193578427008 -> 140193578427776 + 140193578427008 [label=SelectBackward0] + 140193578427872 -> 140193578427008 + 140193578427872 [label=NativeDropoutBackward0] + 140193578428544 -> 140193578427872 + 140193578428544 [label=ViewBackward0] + 140193578429552 -> 140193578428544 + 140193578429552 [label=AddmmBackward0] + 140193036802416 -> 140193578429552 + 140193036802416 [label=ToCopyBackward0] + 140201394499440 -> 140193036802416 + 140193039387360 [label="encoder.layer.11.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140193039387360 -> 140201394499440 + 140201394499440 [label=AccumulateGrad] + 140193036800544 -> 140193578429552 + 140193036800544 [label=ViewBackward0] + 140201394548896 -> 140193036800544 + 140201394548896 [label=GeluBackward0] + 140201394548992 -> 140201394548896 + 140201394548992 [label=ViewBackward0] + 140201394549088 -> 140201394548992 + 140201394549088 [label=AddmmBackward0] + 140201394549184 -> 140201394549088 + 140201394549184 [label=ToCopyBackward0] + 140201394549376 -> 140201394549184 + 140193039387600 [label="encoder.layer.11.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140193039387600 -> 140201394549376 + 140201394549376 [label=AccumulateGrad] + 140201394549136 -> 140201394549088 + 140201394549136 [label=ViewBackward0] + 140201394549424 -> 140201394549136 + 140201394549424 [label=ToCopyBackward0] + 140201394549520 -> 140201394549424 + 140201394549520 [label=IndexBackward0] + 140193578427392 -> 140201394549520 + 140201394548848 -> 140201394549088 + 140201394548848 [label=TBackward0] + 140201394549616 -> 140201394548848 + 140201394549616 [label=ToCopyBackward0] + 140201394549328 -> 140201394549616 + 140193039387040 [label="encoder.layer.11.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140193039387040 -> 140201394549328 + 140201394549328 [label=AccumulateGrad] + 140201394499056 -> 140193578429552 + 140201394499056 [label=TBackward0] + 140201394549040 -> 140201394499056 + 140201394549040 [label=ToCopyBackward0] + 140201394549280 -> 140201394549040 + 140193039386800 [label="encoder.layer.11.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140193039386800 -> 140201394549280 + 140201394549280 [label=AccumulateGrad] + 140193578427728 -> 140193578427632 + 140193578427728 [label=CatBackward0] + 140193578426528 -> 140193578427728 + 140193578426528 [label=SliceBackward0] + 140201394497856 -> 140193578426528 + 140201394497856 [label=SliceBackward0] + 140193578428400 -> 140201394497856 + 140193578428400 [label=SliceBackward0] + 140193578427152 -> 140193578428400 + 140193578427104 -> 140193578427728 + 140193578427104 [label=UnsqueezeBackward0] + 140193578428112 -> 140193578427104 + 140193578428112 [label=SelectBackward0] + 140193578427872 -> 140193578428112 + 140193578427488 -> 140193578427632 + 140193578427488 [label=CatBackward0] + 140193578427920 -> 140193578427488 + 140193578427920 [label=SliceBackward0] + 140201394548944 -> 140193578427920 + 140201394548944 [label=SliceBackward0] + 140201394549472 -> 140201394548944 + 140201394549472 [label=SliceBackward0] + 140193578427152 -> 140201394549472 + 140201394548800 -> 140193578427488 + 140201394548800 [label=UnsqueezeBackward0] + 140201394549808 -> 140201394548800 + 140201394549808 [label=SelectBackward0] + 140193578427872 -> 140201394549808 + 140193578427824 -> 140193578427632 + 140193578427824 [label=CatBackward0] + 140201394549712 -> 140193578427824 + 140201394549712 [label=SliceBackward0] + 140201394549856 -> 140201394549712 + 140201394549856 [label=SliceBackward0] + 140201394549952 -> 140201394549856 + 140201394549952 [label=SliceBackward0] + 140193578427152 -> 140201394549952 + 140201394549664 -> 140193578427824 + 140201394549664 [label=UnsqueezeBackward0] + 140201394550048 -> 140201394549664 + 140201394550048 [label=SelectBackward0] + 140193578427872 -> 140201394550048 + 140193578427392 -> 140193578427440 + 140193578427200 -> 140193578428064 + 140193039401344 [label="encoder.layer.11.expert_ln.weight + (768)" fillcolor=lightblue] + 140193039401344 -> 140193578427200 + 140193578427200 [label=AccumulateGrad] + 140193578427968 -> 140193578428064 + 140193039401104 [label="encoder.layer.11.expert_ln.bias + (768)" fillcolor=lightblue] + 140193039401104 -> 140193578427968 + 140193578427968 [label=AccumulateGrad] + 140193039092080 -> 140193570151248 + 140193039092080 [label=IndexBackward0] + 140193578427536 -> 140193039092080 + 140193578427536 [label=IndexBackward0] + 140193578428064 -> 140193578427536 + 140193039092320 -> 140193039136752 + 140193039092320 [label=AddBackward0] + 140193578427680 -> 140193039092320 + 140193578427680 [label=IndexBackward0] + 140193578428208 -> 140193578427680 + 140193578428208 [label=NativeLayerNormBackward0] + 140201394550000 -> 140193578428208 + 140201394550000 [label=AddBackward0] + 140201394550192 -> 140201394550000 + 140201394550192 [label=NativeDropoutBackward0] + 140201394550336 -> 140201394550192 + 140201394550336 [label=ViewBackward0] + 140201394550432 -> 140201394550336 + 140201394550432 [label=AddmmBackward0] + 140201394550528 -> 140201394550432 + 140201394550528 [label=ToCopyBackward0] + 140201394550720 -> 140201394550528 + 140193039403104 [label="encoder.layer.11.output.dense.bias + (768)" fillcolor=lightblue] + 140193039403104 -> 140201394550720 + 140201394550720 [label=AccumulateGrad] + 140201394550480 -> 140201394550432 + 140201394550480 [label=ViewBackward0] + 140201394550768 -> 140201394550480 + 140201394550768 [label=GeluBackward0] + 140201394550864 -> 140201394550768 + 140201394550864 [label=ViewBackward0] + 140201394550960 -> 140201394550864 + 140201394550960 [label=AddmmBackward0] + 140201394551056 -> 140201394550960 + 140201394551056 [label=ToCopyBackward0] + 140201394551248 -> 140201394551056 + 140193039403344 [label="encoder.layer.11.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140193039403344 -> 140201394551248 + 140201394551248 [label=AccumulateGrad] + 140201394551008 -> 140201394550960 + 140201394551008 [label=ViewBackward0] + 140201394551296 -> 140201394551008 + 140201394551296 [label=ToCopyBackward0] + 140201394550144 -> 140201394551296 + 140201394550144 [label=SliceBackward0] + 140201394551440 -> 140201394550144 + 140201394551440 [label=SliceBackward0] + 140201394551536 -> 140201394551440 + 140201394551536 [label=SliceBackward0] + 140193578430176 -> 140201394551536 + 140201394550672 -> 140201394550960 + 140201394550672 [label=TBackward0] + 140201394551200 -> 140201394550672 + 140201394551200 [label=ToCopyBackward0] + 140201394551632 -> 140201394551200 + 140193039403584 [label="encoder.layer.11.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140193039403584 -> 140201394551632 + 140201394551632 [label=AccumulateGrad] + 140201394550240 -> 140201394550432 + 140201394550240 [label=TBackward0] + 140201394550912 -> 140201394550240 + 140201394550912 [label=ToCopyBackward0] + 140201394551392 -> 140201394550912 + 140193039403024 [label="encoder.layer.11.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140193039403024 -> 140201394551392 + 140201394551392 [label=AccumulateGrad] + 140201394550144 -> 140201394550000 + 140201394549760 -> 140193578428208 + 140193039402784 [label="encoder.layer.11.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140193039402784 -> 140201394549760 + 140201394549760 [label=AccumulateGrad] + 140201394549568 -> 140193578428208 + 140193039402864 [label="encoder.layer.11.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140193039402864 -> 140201394549568 + 140201394549568 [label=AccumulateGrad] + 140193578427296 -> 140193039092320 + 140193578427296 [label=IndexBackward0] + 140201394550096 -> 140193578427296 + 140201394550096 [label=IndexBackward0] + 140201394550384 -> 140201394550096 + 140201394550384 [label=IndexBackward0] + 140193578428208 -> 140201394550384 + 140193039136752 -> 140193037219536 +} diff --git a/Post_Route_Universal_PromptMoE_RawProb_backward_graph.pdf b/Post_Route_Universal_PromptMoE_RawProb_backward_graph.pdf new file mode 100644 index 0000000..ab1ae6d Binary files /dev/null and b/Post_Route_Universal_PromptMoE_RawProb_backward_graph.pdf differ diff --git a/Pre_PromptMoE_RawProb_backward_graph b/Pre_PromptMoE_RawProb_backward_graph new file mode 100644 index 0000000..3a8d029 --- /dev/null +++ b/Pre_PromptMoE_RawProb_backward_graph @@ -0,0 +1,5294 @@ +digraph { + graph [size="739.65,739.65"] + node [align=left fontname=monospace fontsize=10 height=0.2 ranksep=0.1 shape=box style=filled] + 140202223089520 [label=" + (1, 46, 768)" fillcolor=darkolivegreen1] + 140202228657312 [label=CatBackward0] + 140202228615488 -> 140202228657312 + 140202228615488 [label=NativeLayerNormBackward0] + 140202228614096 -> 140202228615488 + 140202228614096 [label=AddBackward0] + 140202223538720 -> 140202228614096 + 140202223538720 [label=NativeDropoutBackward0] + 140202223538912 -> 140202223538720 + 140202223538912 [label=ViewBackward0] + 140202223539008 -> 140202223538912 + 140202223539008 [label=AddmmBackward0] + 140202223539104 -> 140202223539008 + 140202223539104 [label=ToCopyBackward0] + 140202223539296 -> 140202223539104 + 140202228893712 [label="encoder.layer.11.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202228893712 -> 140202223539296 + 140202223539296 [label=AccumulateGrad] + 140202223538864 -> 140202223539008 + 140202223538864 [label=ViewBackward0] + 140202223539152 -> 140202223538864 + 140202223539152 [label=GeluBackward0] + 140202223539248 -> 140202223539152 + 140202223539248 [label=ViewBackward0] + 140202223539680 -> 140202223539248 + 140202223539680 [label=AddmmBackward0] + 140202223539584 -> 140202223539680 + 140202223539584 [label=ToCopyBackward0] + 140202223538528 -> 140202223539584 + 140202228893952 [label="encoder.layer.11.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202228893952 -> 140202223538528 + 140202223538528 [label=AccumulateGrad] + 140202223539440 -> 140202223539680 + 140202223539440 [label=ViewBackward0] + 140202223538288 -> 140202223539440 + 140202223538288 [label=ToCopyBackward0] + 140202223538480 -> 140202223538288 + 140202223538480 [label=SliceBackward0] + 140202223538336 -> 140202223538480 + 140202223538336 [label=SliceBackward0] + 140202223539776 -> 140202223538336 + 140202223539776 [label=SliceBackward0] + 140202223539872 -> 140202223539776 + 140202223539872 [label=SliceBackward0] + 140202223539968 -> 140202223539872 + 140202223539968 [label=SliceBackward0] + 140202223540064 -> 140202223539968 + 140202223540064 [label=NativeLayerNormBackward0] + 140202223540160 -> 140202223540064 + 140202223540160 [label=AddBackward0] + 140202223540352 -> 140202223540160 + 140202223540352 [label=NativeDropoutBackward0] + 140202223540304 -> 140202223540352 + 140202223540304 [label=ViewBackward0] + 140202223540400 -> 140202223540304 + 140202223540400 [label=AddmmBackward0] + 140202223540496 -> 140202223540400 + 140202223540496 [label=ToCopyBackward0] + 140202223540688 -> 140202223540496 + 140202228904080 [label="encoder.layer.11.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228904080 -> 140202223540688 + 140202223540688 [label=AccumulateGrad] + 140202223540640 -> 140202223540400 + 140202223540640 [label=ViewBackward0] + 140202223540928 -> 140202223540640 + 140202223540928 [label=ViewBackward0] + 140202223541024 -> 140202223540928 + 140202223541024 [label=CloneBackward0] + 140202223541120 -> 140202223541024 + 140202223541120 [label=PermuteBackward0] + 140202223541216 -> 140202223541120 + 140202223541216 [label=UnsafeViewBackward0] + 140202223541312 -> 140202223541216 + 140202223541312 [label=BmmBackward0] + 140202223541408 -> 140202223541312 + 140202223541408 [label=ReshapeAliasBackward0] + 140202223541504 -> 140202223541408 + 140202223541504 [label=ExpandBackward0] + 140202223541600 -> 140202223541504 + 140202223541600 [label=ToCopyBackward0] + 140202223541792 -> 140202223541600 + 140202223541792 [label=NativeDropoutBackward0] + 140202223541984 -> 140202223541792 + 140202223541984 [label=SoftmaxBackward0] + 140202223542080 -> 140202223541984 + 140202223542080 [label=AddBackward0] + 140202223541264 -> 140202223542080 + 140202223541264 [label=DivBackward0] + 140202223575296 -> 140202223541264 + 140202223575296 [label=UnsafeViewBackward0] + 140202223575392 -> 140202223575296 + 140202223575392 [label=BmmBackward0] + 140202223575584 -> 140202223575392 + 140202223575584 [label=ReshapeAliasBackward0] + 140202223575968 -> 140202223575584 + 140202223575968 [label=ExpandBackward0] + 140202223576160 -> 140202223575968 + 140202223576160 [label=PermuteBackward0] + 140202223576208 -> 140202223576160 + 140202223576208 [label=ViewBackward0] + 140202223576448 -> 140202223576208 + 140202223576448 [label=ViewBackward0] + 140202223576640 -> 140202223576448 + 140202223576640 [label=AddmmBackward0] + 140202223576688 -> 140202223576640 + 140202223576688 [label=ToCopyBackward0] + 140202223577120 -> 140202223576688 + 140202228906560 [label="encoder.layer.11.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228906560 -> 140202223577120 + 140202223577120 [label=AccumulateGrad] + 140202223576544 -> 140202223576640 + 140202223576544 [label=ViewBackward0] + 140202223577024 -> 140202223576544 + 140202223577024 [label=ToCopyBackward0] + 140202223540112 -> 140202223577024 + 140202223540112 [label=CatBackward0] + 140202223577408 -> 140202223540112 + 140202223577408 [label=NativeLayerNormBackward0] + 140202223577504 -> 140202223577408 + 140202223577504 [label=AddBackward0] + 140202223577792 -> 140202223577504 + 140202223577792 [label=SumBackward1] + 140202223578128 -> 140202223577792 + 140202223578128 [label=MulBackward0] + 140202223578368 -> 140202223578128 + 140202223578368 [label=PermuteBackward0] + 140202223578464 -> 140202223578368 + 140202223578464 [label=CatBackward0] + 140202223578656 -> 140202223578464 + 140202223578656 [label=UnsqueezeBackward0] + 140202223578944 -> 140202223578656 + 140202223578944 [label=NativeDropoutBackward0] + 140202223578752 -> 140202223578944 + 140202223578752 [label=ViewBackward0] + 140202223079536 -> 140202223578752 + 140202223079536 [label=AddmmBackward0] + 140202223079776 -> 140202223079536 + 140202223079776 [label=ToCopyBackward0] + 140202223080064 -> 140202223079776 + 140202228905360 [label="encoder.layer.10.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140202228905360 -> 140202223080064 + 140202223080064 [label=AccumulateGrad] + 140202223079872 -> 140202223079536 + 140202223079872 [label=ViewBackward0] + 140202223080352 -> 140202223079872 + 140202223080352 [label=GeluBackward0] + 140202223080400 -> 140202223080352 + 140202223080400 [label=ViewBackward0] + 140202223080640 -> 140202223080400 + 140202223080640 [label=AddmmBackward0] + 140202223080832 -> 140202223080640 + 140202223080832 [label=ToCopyBackward0] + 140202223081120 -> 140202223080832 + 140202228905280 [label="encoder.layer.10.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140202228905280 -> 140202223081120 + 140202223081120 [label=AccumulateGrad] + 140202223080544 -> 140202223080640 + 140202223080544 [label=ViewBackward0] + 140202223081024 -> 140202223080544 + 140202223081024 [label=ToCopyBackward0] + 140202223577888 -> 140202223081024 + 140202223577888 [label=SliceBackward0] + 140202223081360 -> 140202223577888 + 140202223081360 [label=SliceBackward0] + 140202223081600 -> 140202223081360 + 140202223081600 [label=NativeLayerNormBackward0] + 140202223081792 -> 140202223081600 + 140202223081792 [label=AddBackward0] + 140202223082080 -> 140202223081792 + 140202223082080 [label=NativeDropoutBackward0] + 140202223082176 -> 140202223082080 + 140202223082176 [label=ViewBackward0] + 140202223082368 -> 140202223082176 + 140202223082368 [label=AddmmBackward0] + 140202223082464 -> 140202223082368 + 140202223082464 [label=ToCopyBackward0] + 140202223082848 -> 140202223082464 + 140202228924880 [label="encoder.layer.10.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228924880 -> 140202223082848 + 140202223082848 [label=AccumulateGrad] + 140202223082560 -> 140202223082368 + 140202223082560 [label=ViewBackward0] + 140202223083040 -> 140202223082560 + 140202223083040 [label=ViewBackward0] + 140202223083232 -> 140202223083040 + 140202223083232 [label=CloneBackward0] + 140202223083280 -> 140202223083232 + 140202223083280 [label=PermuteBackward0] + 140202223083424 -> 140202223083280 + 140202223083424 [label=UnsafeViewBackward0] + 140202223082800 -> 140202223083424 + 140202223082800 [label=BmmBackward0] + 140202223108400 -> 140202223082800 + 140202223108400 [label=ReshapeAliasBackward0] + 140202223108544 -> 140202223108400 + 140202223108544 [label=ExpandBackward0] + 140202223108736 -> 140202223108544 + 140202223108736 [label=ToCopyBackward0] + 140202223108928 -> 140202223108736 + 140202223108928 [label=NativeDropoutBackward0] + 140202223109024 -> 140202223108928 + 140202223109024 [label=SoftmaxBackward0] + 140202223109216 -> 140202223109024 + 140202223109216 [label=AddBackward0] + 140202223109408 -> 140202223109216 + 140202223109408 [label=DivBackward0] + 140202223109504 -> 140202223109408 + 140202223109504 [label=UnsafeViewBackward0] + 140202223109696 -> 140202223109504 + 140202223109696 [label=BmmBackward0] + 140202223109888 -> 140202223109696 + 140202223109888 [label=ReshapeAliasBackward0] + 140202223110272 -> 140202223109888 + 140202223110272 [label=ExpandBackward0] + 140202223110320 -> 140202223110272 + 140202223110320 [label=PermuteBackward0] + 140202223110560 -> 140202223110320 + 140202223110560 [label=ViewBackward0] + 140202223110752 -> 140202223110560 + 140202223110752 [label=ViewBackward0] + 140202223110800 -> 140202223110752 + 140202223110800 [label=AddmmBackward0] + 140202223111040 -> 140202223110800 + 140202223111040 [label=ToCopyBackward0] + 140202223111280 -> 140202223111040 + 140202228925600 [label="encoder.layer.10.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140202228925600 -> 140202223111280 + 140202223111280 [label=AccumulateGrad] + 140202223110848 -> 140202223110800 + 140202223110848 [label=ViewBackward0] + 140202223111328 -> 140202223110848 + 140202223111328 [label=ToCopyBackward0] + 140202223081888 -> 140202223111328 + 140202223081888 [label=SliceBackward0] + 140202223111712 -> 140202223081888 + 140202223111712 [label=SliceBackward0] + 140202223111760 -> 140202223111712 + 140202223111760 [label=SliceBackward0] + 140202223112000 -> 140202223111760 + 140202223112000 [label=NativeLayerNormBackward0] + 140202223112096 -> 140202223112000 + 140202223112096 [label=AddBackward0] + 140202223137120 -> 140202223112096 + 140202223137120 [label=NativeDropoutBackward0] + 140202223137216 -> 140202223137120 + 140202223137216 [label=ViewBackward0] + 140202223137408 -> 140202223137216 + 140202223137408 [label=AddmmBackward0] + 140202223137504 -> 140202223137408 + 140202223137504 [label=ToCopyBackward0] + 140202223137888 -> 140202223137504 + 140202228926080 [label="encoder.layer.10.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228926080 -> 140202223137888 + 140202223137888 [label=AccumulateGrad] + 140202223137600 -> 140202223137408 + 140202223137600 [label=ViewBackward0] + 140202223138080 -> 140202223137600 + 140202223138080 [label=ViewBackward0] + 140202223138272 -> 140202223138080 + 140202223138272 [label=CloneBackward0] + 140202223138320 -> 140202223138272 + 140202223138320 [label=PermuteBackward0] + 140202223138560 -> 140202223138320 + 140202223138560 [label=UnsafeViewBackward0] + 140202223138752 -> 140202223138560 + 140202223138752 [label=BmmBackward0] + 140202223138800 -> 140202223138752 + 140202223138800 [label=ReshapeAliasBackward0] + 140202223138944 -> 140202223138800 + 140202223138944 [label=ExpandBackward0] + 140202223139136 -> 140202223138944 + 140202223139136 [label=ToCopyBackward0] + 140202223139328 -> 140202223139136 + 140202223139328 [label=NativeDropoutBackward0] + 140202223139424 -> 140202223139328 + 140202223139424 [label=SoftmaxBackward0] + 140202223139616 -> 140202223139424 + 140202223139616 [label=AddBackward0] + 140202223139808 -> 140202223139616 + 140202223139808 [label=DivBackward0] + 140202223139904 -> 140202223139808 + 140202223139904 [label=UnsafeViewBackward0] + 140202223140096 -> 140202223139904 + 140202223140096 [label=BmmBackward0] + 140202223140288 -> 140202223140096 + 140202223140288 [label=ReshapeAliasBackward0] + 140202223140672 -> 140202223140288 + 140202223140672 [label=ExpandBackward0] + 140202223140720 -> 140202223140672 + 140202223140720 [label=PermuteBackward0] + 140202223140768 -> 140202223140720 + 140202223140768 [label=ViewBackward0] + 140202223169888 -> 140202223140768 + 140202223169888 [label=ViewBackward0] + 140202223169936 -> 140202223169888 + 140202223169936 [label=AddmmBackward0] + 140202223170176 -> 140202223169936 + 140202223170176 [label=ToCopyBackward0] + 140202223170416 -> 140202223170176 + 140202228926800 [label="encoder.layer.10.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228926800 -> 140202223170416 + 140202223170416 [label=AccumulateGrad] + 140202223169984 -> 140202223169936 + 140202223169984 [label=ViewBackward0] + 140202223170464 -> 140202223169984 + 140202223170464 [label=ToCopyBackward0] + 140202223136928 -> 140202223170464 + 140202223136928 [label=CatBackward0] + 140202223170848 -> 140202223136928 + 140202223170848 [label=NativeLayerNormBackward0] + 140202223170944 -> 140202223170848 + 140202223170944 [label=AddBackward0] + 140202223171232 -> 140202223170944 + 140202223171232 [label=NativeDropoutBackward0] + 140202223171616 -> 140202223171232 + 140202223171616 [label=ViewBackward0] + 140202223171808 -> 140202223171616 + 140202223171808 [label=AddmmBackward0] + 140202223171856 -> 140202223171808 + 140202223171856 [label=ToCopyBackward0] + 140202223172288 -> 140202223171856 + 140202228927280 [label="encoder.layer.9.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202228927280 -> 140202223172288 + 140202223172288 [label=AccumulateGrad] + 140202223171712 -> 140202223171808 + 140202223171712 [label=ViewBackward0] + 140202223172192 -> 140202223171712 + 140202223172192 [label=GeluBackward0] + 140202223172384 -> 140202223172192 + 140202223172384 [label=ViewBackward0] + 140202223172480 -> 140202223172384 + 140202223172480 [label=AddmmBackward0] + 140202223172672 -> 140202223172480 + 140202223172672 [label=ToCopyBackward0] + 140202223172960 -> 140202223172672 + 140202228927520 [label="encoder.layer.9.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202228927520 -> 140202223172960 + 140202223172960 [label=AccumulateGrad] + 140202223172768 -> 140202223172480 + 140202223172768 [label=ViewBackward0] + 140202223173248 -> 140202223172768 + 140202223173248 [label=ToCopyBackward0] + 140202223171328 -> 140202223173248 + 140202223171328 [label=SliceBackward0] + 140202223173344 -> 140202223171328 + 140202223173344 [label=SliceBackward0] + 140202223173440 -> 140202223173344 + 140202223173440 [label=SliceBackward0] + 140202223172864 -> 140202223173440 + 140202223172864 [label=SliceBackward0] + 140202223194368 -> 140202223172864 + 140202223194368 [label=SliceBackward0] + 140202223194464 -> 140202223194368 + 140202223194464 [label=NativeLayerNormBackward0] + 140202223194656 -> 140202223194464 + 140202223194656 [label=AddBackward0] + 140202223194944 -> 140202223194656 + 140202223194944 [label=NativeDropoutBackward0] + 140202223195280 -> 140202223194944 + 140202223195280 [label=ViewBackward0] + 140202223195520 -> 140202223195280 + 140202223195520 [label=AddmmBackward0] + 140202223195712 -> 140202223195520 + 140202223195712 [label=ToCopyBackward0] + 140202223196000 -> 140202223195712 + 140202228933472 [label="encoder.layer.9.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228933472 -> 140202223196000 + 140202223196000 [label=AccumulateGrad] + 140202223195424 -> 140202223195520 + 140202223195424 [label=ViewBackward0] + 140202223195904 -> 140202223195424 + 140202223195904 [label=ViewBackward0] + 140202223196096 -> 140202223195904 + 140202223196096 [label=CloneBackward0] + 140202223196288 -> 140202223196096 + 140202223196288 [label=PermuteBackward0] + 140202223196384 -> 140202223196288 + 140202223196384 [label=UnsafeViewBackward0] + 140202223196576 -> 140202223196384 + 140202223196576 [label=BmmBackward0] + 140202223196768 -> 140202223196576 + 140202223196768 [label=ReshapeAliasBackward0] + 140202223197152 -> 140202223196768 + 140202223197152 [label=ExpandBackward0] + 140202223197200 -> 140202223197152 + 140202223197200 [label=ToCopyBackward0] + 140202223197440 -> 140202223197200 + 140202223197440 [label=NativeDropoutBackward0] + 140202223197632 -> 140202223197440 + 140202223197632 [label=SoftmaxBackward0] + 140202223197680 -> 140202223197632 + 140202223197680 [label=AddBackward0] + 140202223197920 -> 140202223197680 + 140202223197920 [label=DivBackward0] + 140202223198112 -> 140202223197920 + 140202223198112 [label=UnsafeViewBackward0] + 140202223198016 -> 140202223198112 + 140202223198016 [label=BmmBackward0] + 140202223227136 -> 140202223198016 + 140202223227136 [label=ReshapeAliasBackward0] + 140202223227232 -> 140202223227136 + 140202223227232 [label=ExpandBackward0] + 140202223227424 -> 140202223227232 + 140202223227424 [label=PermuteBackward0] + 140202223227520 -> 140202223227424 + 140202223227520 [label=ViewBackward0] + 140202223227712 -> 140202223227520 + 140202223227712 [label=ViewBackward0] + 140202223227904 -> 140202223227712 + 140202223227904 [label=AddmmBackward0] + 140202223228000 -> 140202223227904 + 140202223228000 [label=ToCopyBackward0] + 140202223228384 -> 140202223228000 + 140202228936032 [label="encoder.layer.9.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228936032 -> 140202223228384 + 140202223228384 [label=AccumulateGrad] + 140202223228096 -> 140202223227904 + 140202223228096 [label=ViewBackward0] + 140202223228576 -> 140202223228096 + 140202223228576 [label=ToCopyBackward0] + 140202223195040 -> 140202223228576 + 140202223195040 [label=CatBackward0] + 140202223228672 -> 140202223195040 + 140202223228672 [label=NativeLayerNormBackward0] + 140202223229056 -> 140202223228672 + 140202223229056 [label=AddBackward0] + 140202223229296 -> 140202223229056 + 140202223229296 [label=SumBackward1] + 140202223229440 -> 140202223229296 + 140202223229440 [label=MulBackward0] + 140202223229632 -> 140202223229440 + 140202223229632 [label=PermuteBackward0] + 140202223230016 -> 140202223229632 + 140202223230016 [label=CatBackward0] + 140202223230208 -> 140202223230016 + 140202223230208 [label=UnsqueezeBackward0] + 140202223230496 -> 140202223230208 + 140202223230496 [label=NativeDropoutBackward0] + 140202223230688 -> 140202223230496 + 140202223230688 [label=ViewBackward0] + 140202223230736 -> 140202223230688 + 140202223230736 [label=AddmmBackward0] + 140202223230880 -> 140202223230736 + 140202223230880 [label=ToCopyBackward0] + 140202223247664 -> 140202223230880 + 140202228934832 [label="encoder.layer.8.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140202228934832 -> 140202223247664 + 140202223247664 [label=AccumulateGrad] + 140202223230784 -> 140202223230736 + 140202223230784 [label=ViewBackward0] + 140202223247712 -> 140202223230784 + 140202223247712 [label=GeluBackward0] + 140202223247808 -> 140202223247712 + 140202223247808 [label=ViewBackward0] + 140202223248000 -> 140202223247808 + 140202223248000 [label=AddmmBackward0] + 140202223248192 -> 140202223248000 + 140202223248192 [label=ToCopyBackward0] + 140202223248480 -> 140202223248192 + 140202228935152 [label="encoder.layer.8.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140202228935152 -> 140202223248480 + 140202223248480 [label=AccumulateGrad] + 140202223248144 -> 140202223248000 + 140202223248144 [label=ViewBackward0] + 140202223248624 -> 140202223248144 + 140202223248624 [label=ToCopyBackward0] + 140202223229152 -> 140202223248624 + 140202223229152 [label=SliceBackward0] + 140202223248768 -> 140202223229152 + 140202223248768 [label=SliceBackward0] + 140202223248960 -> 140202223248768 + 140202223248960 [label=NativeLayerNormBackward0] + 140202223249152 -> 140202223248960 + 140202223249152 [label=AddBackward0] + 140202223249440 -> 140202223249152 + 140202223249440 [label=NativeDropoutBackward0] + 140202223249824 -> 140202223249440 + 140202223249824 [label=ViewBackward0] + 140202223250016 -> 140202223249824 + 140202223250016 [label=AddmmBackward0] + 140202223250064 -> 140202223250016 + 140202223250064 [label=ToCopyBackward0] + 140202223250496 -> 140202223250064 + 140202228950656 [label="encoder.layer.8.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228950656 -> 140202223250496 + 140202223250496 [label=AccumulateGrad] + 140202223249920 -> 140202223250016 + 140202223249920 [label=ViewBackward0] + 140202223250400 -> 140202223249920 + 140202223250400 [label=ViewBackward0] + 140202223250592 -> 140202223250400 + 140202223250592 [label=CloneBackward0] + 140202223250688 -> 140202223250592 + 140202223250688 [label=PermuteBackward0] + 140202223250976 -> 140202223250688 + 140202223250976 [label=UnsafeViewBackward0] + 140202223251264 -> 140202223250976 + 140202223251264 [label=BmmBackward0] + 140202223251360 -> 140202223251264 + 140202223251360 [label=ReshapeAliasBackward0] + 140202223284384 -> 140202223251360 + 140202223284384 [label=ExpandBackward0] + 140202223284480 -> 140202223284384 + 140202223284480 [label=ToCopyBackward0] + 140202223284672 -> 140202223284480 + 140202223284672 [label=NativeDropoutBackward0] + 140202223284864 -> 140202223284672 + 140202223284864 [label=SoftmaxBackward0] + 140202223284960 -> 140202223284864 + 140202223284960 [label=AddBackward0] + 140202223285152 -> 140202223284960 + 140202223285152 [label=DivBackward0] + 140202223285344 -> 140202223285152 + 140202223285344 [label=UnsafeViewBackward0] + 140202223285440 -> 140202223285344 + 140202223285440 [label=BmmBackward0] + 140202223285632 -> 140202223285440 + 140202223285632 [label=ReshapeAliasBackward0] + 140202223286016 -> 140202223285632 + 140202223286016 [label=ExpandBackward0] + 140202223286208 -> 140202223286016 + 140202223286208 [label=PermuteBackward0] + 140202223286256 -> 140202223286208 + 140202223286256 [label=ViewBackward0] + 140202223286496 -> 140202223286256 + 140202223286496 [label=ViewBackward0] + 140202223286688 -> 140202223286496 + 140202223286688 [label=AddmmBackward0] + 140202223286736 -> 140202223286688 + 140202223286736 [label=ToCopyBackward0] + 140202223287168 -> 140202223286736 + 140202228951376 [label="encoder.layer.8.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140202228951376 -> 140202223287168 + 140202223287168 [label=AccumulateGrad] + 140202223286592 -> 140202223286688 + 140202223286592 [label=ViewBackward0] + 140202223287072 -> 140202223286592 + 140202223287072 [label=ToCopyBackward0] + 140202223249536 -> 140202223287072 + 140202223249536 [label=SliceBackward0] + 140202223287456 -> 140202223249536 + 140202223287456 [label=SliceBackward0] + 140202223287648 -> 140202223287456 + 140202223287648 [label=SliceBackward0] + 140202223287696 -> 140202223287648 + 140202223287696 [label=NativeLayerNormBackward0] + 140202223287936 -> 140202223287696 + 140202223287936 [label=AddBackward0] + 140202223288176 -> 140202223287936 + 140202223288176 [label=NativeDropoutBackward0] + 140202223288224 -> 140202223288176 + 140202223288224 [label=ViewBackward0] + 140202223313152 -> 140202223288224 + 140202223313152 [label=AddmmBackward0] + 140202223313344 -> 140202223313152 + 140202223313344 [label=ToCopyBackward0] + 140202223313632 -> 140202223313344 + 140202228951856 [label="encoder.layer.8.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228951856 -> 140202223313632 + 140202223313632 [label=AccumulateGrad] + 140202223313296 -> 140202223313152 + 140202223313296 [label=ViewBackward0] + 140202223313776 -> 140202223313296 + 140202223313776 [label=ViewBackward0] + 140202223314016 -> 140202223313776 + 140202223314016 [label=CloneBackward0] + 140202223314208 -> 140202223314016 + 140202223314208 [label=PermuteBackward0] + 140202223314256 -> 140202223314208 + 140202223314256 [label=UnsafeViewBackward0] + 140202223314496 -> 140202223314256 + 140202223314496 [label=BmmBackward0] + 140202223314688 -> 140202223314496 + 140202223314688 [label=ReshapeAliasBackward0] + 140202223314784 -> 140202223314688 + 140202223314784 [label=ExpandBackward0] + 140202223314880 -> 140202223314784 + 140202223314880 [label=ToCopyBackward0] + 140202223315072 -> 140202223314880 + 140202223315072 [label=NativeDropoutBackward0] + 140202223315264 -> 140202223315072 + 140202223315264 [label=SoftmaxBackward0] + 140202223315360 -> 140202223315264 + 140202223315360 [label=AddBackward0] + 140202223315552 -> 140202223315360 + 140202223315552 [label=DivBackward0] + 140202223315744 -> 140202223315552 + 140202223315744 [label=UnsafeViewBackward0] + 140202223315840 -> 140202223315744 + 140202223315840 [label=BmmBackward0] + 140202223316032 -> 140202223315840 + 140202223316032 [label=ReshapeAliasBackward0] + 140202223316416 -> 140202223316032 + 140202223316416 [label=ExpandBackward0] + 140202223316608 -> 140202223316416 + 140202223316608 [label=PermuteBackward0] + 140202223316656 -> 140202223316608 + 140202223316656 [label=ViewBackward0] + 140202223316896 -> 140202223316656 + 140202223316896 [label=ViewBackward0] + 140202223316800 -> 140202223316896 + 140202223316800 [label=AddmmBackward0] + 140202222817488 -> 140202223316800 + 140202222817488 [label=ToCopyBackward0] + 140202222817920 -> 140202222817488 + 140202228952576 [label="encoder.layer.8.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228952576 -> 140202222817920 + 140202222817920 [label=AccumulateGrad] + 140202222817344 -> 140202223316800 + 140202222817344 [label=ViewBackward0] + 140202222817824 -> 140202222817344 + 140202222817824 [label=ToCopyBackward0] + 140202223288032 -> 140202222817824 + 140202223288032 [label=CatBackward0] + 140202222818208 -> 140202223288032 + 140202222818208 [label=NativeLayerNormBackward0] + 140202222818304 -> 140202222818208 + 140202222818304 [label=AddBackward0] + 140202222818592 -> 140202222818304 + 140202222818592 [label=NativeDropoutBackward0] + 140202222818928 -> 140202222818592 + 140202222818928 [label=ViewBackward0] + 140202222819168 -> 140202222818928 + 140202222819168 [label=AddmmBackward0] + 140202222819360 -> 140202222819168 + 140202222819360 [label=ToCopyBackward0] + 140202222819648 -> 140202222819360 + 140202228952976 [label="encoder.layer.7.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202228952976 -> 140202222819648 + 140202222819648 [label=AccumulateGrad] + 140202222819072 -> 140202222819168 + 140202222819072 [label=ViewBackward0] + 140202222819552 -> 140202222819072 + 140202222819552 [label=GeluBackward0] + 140202222819744 -> 140202222819552 + 140202222819744 [label=ViewBackward0] + 140202222819936 -> 140202222819744 + 140202222819936 [label=AddmmBackward0] + 140202222820032 -> 140202222819936 + 140202222820032 [label=ToCopyBackward0] + 140202222820416 -> 140202222820032 + 140202228965680 [label="encoder.layer.7.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202228965680 -> 140202222820416 + 140202222820416 [label=AccumulateGrad] + 140202222820128 -> 140202222819936 + 140202222820128 [label=ViewBackward0] + 140202222820608 -> 140202222820128 + 140202222820608 [label=ToCopyBackward0] + 140202222818688 -> 140202222820608 + 140202222818688 [label=SliceBackward0] + 140202222820704 -> 140202222818688 + 140202222820704 [label=SliceBackward0] + 140202222820896 -> 140202222820704 + 140202222820896 [label=SliceBackward0] + 140202222820992 -> 140202222820896 + 140202222820992 [label=SliceBackward0] + 140202222821184 -> 140202222820992 + 140202222821184 [label=SliceBackward0] + 140202222820224 -> 140202222821184 + 140202222820224 [label=NativeLayerNormBackward0] + 140202222841968 -> 140202222820224 + 140202222841968 [label=AddBackward0] + 140202222842400 -> 140202222841968 + 140202222842400 [label=NativeDropoutBackward0] + 140202222842784 -> 140202222842400 + 140202222842784 [label=ViewBackward0] + 140202222842832 -> 140202222842784 + 140202222842832 [label=AddmmBackward0] + 140202222843072 -> 140202222842832 + 140202222843072 [label=ToCopyBackward0] + 140202222843312 -> 140202222843072 + 140202228967040 [label="encoder.layer.7.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228967040 -> 140202222843312 + 140202222843312 [label=AccumulateGrad] + 140202222842880 -> 140202222842832 + 140202222842880 [label=ViewBackward0] + 140202222843360 -> 140202222842880 + 140202222843360 [label=ViewBackward0] + 140202222843456 -> 140202222843360 + 140202222843456 [label=CloneBackward0] + 140202222843648 -> 140202222843456 + 140202222843648 [label=PermuteBackward0] + 140202222843840 -> 140202222843648 + 140202222843840 [label=UnsafeViewBackward0] + 140202222843936 -> 140202222843840 + 140202222843936 [label=BmmBackward0] + 140202222844128 -> 140202222843936 + 140202222844128 [label=ReshapeAliasBackward0] + 140202222844512 -> 140202222844128 + 140202222844512 [label=ExpandBackward0] + 140202222844704 -> 140202222844512 + 140202222844704 [label=ToCopyBackward0] + 140202222844752 -> 140202222844704 + 140202222844752 [label=NativeDropoutBackward0] + 140202222844992 -> 140202222844752 + 140202222844992 [label=SoftmaxBackward0] + 140202222845184 -> 140202222844992 + 140202222845184 [label=AddBackward0] + 140202222845232 -> 140202222845184 + 140202222845232 [label=DivBackward0] + 140202222845472 -> 140202222845232 + 140202222845472 [label=UnsafeViewBackward0] + 140202222845664 -> 140202222845472 + 140202222845664 [label=BmmBackward0] + 140202222845712 -> 140202222845664 + 140202222845712 [label=ReshapeAliasBackward0] + 140202222845856 -> 140202222845712 + 140202222845856 [label=ExpandBackward0] + 140202222870688 -> 140202222845856 + 140202222870688 [label=PermuteBackward0] + 140202222870880 -> 140202222870688 + 140202222870880 [label=ViewBackward0] + 140202222870976 -> 140202222870880 + 140202222870976 [label=ViewBackward0] + 140202222871168 -> 140202222870976 + 140202222871168 [label=AddmmBackward0] + 140202222871360 -> 140202222871168 + 140202222871360 [label=ToCopyBackward0] + 140202222871648 -> 140202222871360 + 140202228982304 [label="encoder.layer.7.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228982304 -> 140202222871648 + 140202222871648 [label=AccumulateGrad] + 140202222871312 -> 140202222871168 + 140202222871312 [label=ViewBackward0] + 140202222871792 -> 140202222871312 + 140202222871792 [label=ToCopyBackward0] + 140202222842352 -> 140202222871792 + 140202222842352 [label=CatBackward0] + 140202222871936 -> 140202222842352 + 140202222871936 [label=NativeLayerNormBackward0] + 140202222872272 -> 140202222871936 + 140202222872272 [label=AddBackward0] + 140202222872704 -> 140202222872272 + 140202222872704 [label=SumBackward1] + 140202222872800 -> 140202222872704 + 140202222872800 [label=MulBackward0] + 140202222872896 -> 140202222872800 + 140202222872896 [label=PermuteBackward0] + 140202222873232 -> 140202222872896 + 140202222873232 [label=CatBackward0] + 140202222873472 -> 140202222873232 + 140202222873472 [label=UnsqueezeBackward0] + 140202222873712 -> 140202222873472 + 140202222873712 [label=NativeDropoutBackward0] + 140202222873952 -> 140202222873712 + 140202222873952 [label=ViewBackward0] + 140202222874144 -> 140202222873952 + 140202222874144 [label=AddmmBackward0] + 140202222874192 -> 140202222874144 + 140202222874192 [label=ToCopyBackward0] + 140202222874528 -> 140202222874192 + 140202228968800 [label="encoder.layer.6.experts.experts.0.dense2.bias + (768)" fillcolor=lightblue] + 140202228968800 -> 140202222874528 + 140202222874528 [label=AccumulateGrad] + 140202222874048 -> 140202222874144 + 140202222874048 [label=ViewBackward0] + 140202222874432 -> 140202222874048 + 140202222874432 [label=GeluBackward0] + 140202222903456 -> 140202222874432 + 140202222903456 [label=ViewBackward0] + 140202222903552 -> 140202222903456 + 140202222903552 [label=AddmmBackward0] + 140202222903744 -> 140202222903552 + 140202222903744 [label=ToCopyBackward0] + 140202222904032 -> 140202222903744 + 140202228968720 [label="encoder.layer.6.experts.experts.0.dense1.bias + (3072)" fillcolor=lightblue] + 140202228968720 -> 140202222904032 + 140202222904032 [label=AccumulateGrad] + 140202222903840 -> 140202222903552 + 140202222903840 [label=ViewBackward0] + 140202222904320 -> 140202222903840 + 140202222904320 [label=ToCopyBackward0] + 140202222872416 -> 140202222904320 + 140202222872416 [label=SliceBackward0] + 140202222904416 -> 140202222872416 + 140202222904416 [label=SliceBackward0] + 140202222904512 -> 140202222904416 + 140202222904512 [label=NativeLayerNormBackward0] + 140202222904896 -> 140202222904512 + 140202222904896 [label=AddBackward0] + 140202222905184 -> 140202222904896 + 140202222905184 [label=NativeDropoutBackward0] + 140202222905280 -> 140202222905184 + 140202222905280 [label=ViewBackward0] + 140202222905328 -> 140202222905280 + 140202222905328 [label=AddmmBackward0] + 140202222905568 -> 140202222905328 + 140202222905568 [label=ToCopyBackward0] + 140202222905808 -> 140202222905568 + 140202228984224 [label="encoder.layer.6.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228984224 -> 140202222905808 + 140202222905808 [label=AccumulateGrad] + 140202222905664 -> 140202222905328 + 140202222905664 [label=ViewBackward0] + 140202222906144 -> 140202222905664 + 140202222906144 [label=ViewBackward0] + 140202222906336 -> 140202222906144 + 140202222906336 [label=CloneBackward0] + 140202222906432 -> 140202222906336 + 140202222906432 [label=PermuteBackward0] + 140202222906624 -> 140202222906432 + 140202222906624 [label=UnsafeViewBackward0] + 140202222906816 -> 140202222906624 + 140202222906816 [label=BmmBackward0] + 140202222906912 -> 140202222906816 + 140202222906912 [label=ReshapeAliasBackward0] + 140202222907008 -> 140202222906912 + 140202222907008 [label=ExpandBackward0] + 140202222907200 -> 140202222907008 + 140202222907200 [label=ToCopyBackward0] + 140202222907248 -> 140202222907200 + 140202222907248 [label=NativeDropoutBackward0] + 140202222932128 -> 140202222907248 + 140202222932128 [label=SoftmaxBackward0] + 140202222932320 -> 140202222932128 + 140202222932320 [label=AddBackward0] + 140202222932368 -> 140202222932320 + 140202222932368 [label=DivBackward0] + 140202222932608 -> 140202222932368 + 140202222932608 [label=UnsafeViewBackward0] + 140202222932800 -> 140202222932608 + 140202222932800 [label=BmmBackward0] + 140202222932848 -> 140202222932800 + 140202222932848 [label=ReshapeAliasBackward0] + 140202222933376 -> 140202222932848 + 140202222933376 [label=ExpandBackward0] + 140202222933472 -> 140202222933376 + 140202222933472 [label=PermuteBackward0] + 140202222933664 -> 140202222933472 + 140202222933664 [label=ViewBackward0] + 140202222933856 -> 140202222933664 + 140202222933856 [label=ViewBackward0] + 140202222933952 -> 140202222933856 + 140202222933952 [label=AddmmBackward0] + 140202222934144 -> 140202222933952 + 140202222934144 [label=ToCopyBackward0] + 140202222934432 -> 140202222934144 + 140202228984944 [label="encoder.layer.6.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140202228984944 -> 140202222934432 + 140202222934432 [label=AccumulateGrad] + 140202222933808 -> 140202222933952 + 140202222933808 [label=ViewBackward0] + 140202222934288 -> 140202222933808 + 140202222934288 [label=ToCopyBackward0] + 140202222904848 -> 140202222934288 + 140202222904848 [label=SliceBackward0] + 140202222934816 -> 140202222904848 + 140202222934816 [label=SliceBackward0] + 140202222934912 -> 140202222934816 + 140202222934912 [label=SliceBackward0] + 140202222935104 -> 140202222934912 + 140202222935104 [label=NativeLayerNormBackward0] + 140202222935296 -> 140202222935104 + 140202222935296 [label=AddBackward0] + 140202222935584 -> 140202222935296 + 140202222935584 [label=NativeDropoutBackward0] + 140202222935680 -> 140202222935584 + 140202222935680 [label=ViewBackward0] + 140202222935728 -> 140202222935680 + 140202222935728 [label=AddmmBackward0] + 140202222935968 -> 140202222935728 + 140202222935968 [label=ToCopyBackward0] + 140202222960848 -> 140202222935968 + 140202228985424 [label="encoder.layer.6.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228985424 -> 140202222960848 + 140202222960848 [label=AccumulateGrad] + 140202222935488 -> 140202222935728 + 140202222935488 [label=ViewBackward0] + 140202222961184 -> 140202222935488 + 140202222961184 [label=ViewBackward0] + 140202222961376 -> 140202222961184 + 140202222961376 [label=CloneBackward0] + 140202222961472 -> 140202222961376 + 140202222961472 [label=PermuteBackward0] + 140202222961664 -> 140202222961472 + 140202222961664 [label=UnsafeViewBackward0] + 140202222961856 -> 140202222961664 + 140202222961856 [label=BmmBackward0] + 140202222961952 -> 140202222961856 + 140202222961952 [label=ReshapeAliasBackward0] + 140202222962048 -> 140202222961952 + 140202222962048 [label=ExpandBackward0] + 140202222962240 -> 140202222962048 + 140202222962240 [label=ToCopyBackward0] + 140202222962288 -> 140202222962240 + 140202222962288 [label=NativeDropoutBackward0] + 140202222962528 -> 140202222962288 + 140202222962528 [label=SoftmaxBackward0] + 140202222962720 -> 140202222962528 + 140202222962720 [label=AddBackward0] + 140202222962768 -> 140202222962720 + 140202222962768 [label=DivBackward0] + 140202222963008 -> 140202222962768 + 140202222963008 [label=UnsafeViewBackward0] + 140202222963200 -> 140202222963008 + 140202222963200 [label=BmmBackward0] + 140202222963248 -> 140202222963200 + 140202222963248 [label=ReshapeAliasBackward0] + 140202222963776 -> 140202222963248 + 140202222963776 [label=ExpandBackward0] + 140202222963872 -> 140202222963776 + 140202222963872 [label=PermuteBackward0] + 140202222964064 -> 140202222963872 + 140202222964064 [label=ViewBackward0] + 140202222964256 -> 140202222964064 + 140202222964256 [label=ViewBackward0] + 140202222964352 -> 140202222964256 + 140202222964352 [label=AddmmBackward0] + 140202222964544 -> 140202222964352 + 140202222964544 [label=ToCopyBackward0] + 140202222964640 -> 140202222964544 + 140202228986240 [label="encoder.layer.6.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228986240 -> 140202222964640 + 140202222964640 [label=AccumulateGrad] + 140202222964208 -> 140202222964352 + 140202222964208 [label=ViewBackward0] + 140202222988064 -> 140202222964208 + 140202222988064 [label=ToCopyBackward0] + 140202222935248 -> 140202222988064 + 140202222935248 [label=CatBackward0] + 140202222988736 -> 140202222935248 + 140202222988736 [label=NativeLayerNormBackward0] + 140202222985280 -> 140202222988736 + 140202222985280 [label=AddBackward0] + 140202222985472 -> 140202222985280 + 140202222985472 [label=NativeDropoutBackward0] + 140202222985856 -> 140202222985472 + 140202222985856 [label=ViewBackward0] + 140202222986048 -> 140202222985856 + 140202222986048 [label=AddmmBackward0] + 140202222986240 -> 140202222986048 + 140202222986240 [label=ToCopyBackward0] + 140202222987680 -> 140202222986240 + 140202228986720 [label="encoder.layer.5.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202228986720 -> 140202222987680 + 140202222987680 [label=AccumulateGrad] + 140202222985952 -> 140202222986048 + 140202222985952 [label=ViewBackward0] + 140202222986432 -> 140202222985952 + 140202222986432 [label=GeluBackward0] + 140202222986624 -> 140202222986432 + 140202222986624 [label=ViewBackward0] + 140202222986672 -> 140202222986624 + 140202222986672 [label=AddmmBackward0] + 140202222986912 -> 140202222986672 + 140202222986912 [label=ToCopyBackward0] + 140202222989072 -> 140202222986912 + 140202228986960 [label="encoder.layer.5.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202228986960 -> 140202222989072 + 140202222989072 [label=AccumulateGrad] + 140202222986816 -> 140202222986672 + 140202222986816 [label=ViewBackward0] + 140202222988832 -> 140202222986816 + 140202222988832 [label=ToCopyBackward0] + 140202222985568 -> 140202222988832 + 140202222985568 [label=SliceBackward0] + 140202222987632 -> 140202222985568 + 140202222987632 [label=SliceBackward0] + 140202222989216 -> 140202222987632 + 140202222989216 [label=SliceBackward0] + 140202222987872 -> 140202222989216 + 140202222987872 [label=SliceBackward0] + 140202222987968 -> 140202222987872 + 140202222987968 [label=SliceBackward0] + 140202222988352 -> 140202222987968 + 140202222988352 [label=NativeLayerNormBackward0] + 140202222987584 -> 140202222988352 + 140202222987584 [label=AddBackward0] + 140202224191520 -> 140202222987584 + 140202224191520 [label=NativeDropoutBackward0] + 140202224191280 -> 140202224191520 + 140202224191280 [label=ViewBackward0] + 140202224191184 -> 140202224191280 + 140202224191184 [label=AddmmBackward0] + 140202224191088 -> 140202224191184 + 140202224191088 [label=ToCopyBackward0] + 140202224190896 -> 140202224191088 + 140202228988880 [label="encoder.layer.5.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202228988880 -> 140202224190896 + 140202224190896 [label=AccumulateGrad] + 140202224191232 -> 140202224191184 + 140202224191232 [label=ViewBackward0] + 140202224190944 -> 140202224191232 + 140202224190944 [label=ViewBackward0] + 140202224190848 -> 140202224190944 + 140202224190848 [label=CloneBackward0] + 140202224190752 -> 140202224190848 + 140202224190752 [label=PermuteBackward0] + 140202224190656 -> 140202224190752 + 140202224190656 [label=UnsafeViewBackward0] + 140202224190560 -> 140202224190656 + 140202224190560 [label=BmmBackward0] + 140202224190464 -> 140202224190560 + 140202224190464 [label=ReshapeAliasBackward0] + 140202224190224 -> 140202224190464 + 140202224190224 [label=ExpandBackward0] + 140202224190128 -> 140202224190224 + 140202224190128 [label=ToCopyBackward0] + 140202224190032 -> 140202224190128 + 140202224190032 [label=NativeDropoutBackward0] + 140202224189936 -> 140202224190032 + 140202224189936 [label=SoftmaxBackward0] + 140202224189840 -> 140202224189936 + 140202224189840 [label=AddBackward0] + 140202224189744 -> 140202224189840 + 140202224189744 [label=DivBackward0] + 140202224189648 -> 140202224189744 + 140202224189648 [label=UnsafeViewBackward0] + 140202224189552 -> 140202224189648 + 140202224189552 [label=BmmBackward0] + 140202224189504 -> 140202224189552 + 140202224189504 [label=ReshapeAliasBackward0] + 140202224191808 -> 140202224189504 + 140202224191808 [label=ExpandBackward0] + 140202224191904 -> 140202224191808 + 140202224191904 [label=PermuteBackward0] + 140202224192000 -> 140202224191904 + 140202224192000 [label=ViewBackward0] + 140202224192096 -> 140202224192000 + 140202224192096 [label=ViewBackward0] + 140202224192192 -> 140202224192096 + 140202224192192 [label=AddmmBackward0] + 140202224192288 -> 140202224192192 + 140202224192288 [label=ToCopyBackward0] + 140202224192480 -> 140202224192288 + 140202228989600 [label="encoder.layer.5.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228989600 -> 140202224192480 + 140202224192480 [label=AccumulateGrad] + 140202224192240 -> 140202224192192 + 140202224192240 [label=ViewBackward0] + 140202224192576 -> 140202224192240 + 140202224192576 [label=ToCopyBackward0] + 140202224191472 -> 140202224192576 + 140202224191472 [label=CatBackward0] + 140202224192720 -> 140202224191472 + 140202224192720 [label=NativeLayerNormBackward0] + 140202224192864 -> 140202224192720 + 140202224192864 [label=AddBackward0] + 140202224193056 -> 140202224192864 + 140202224193056 [label=NativeDropoutBackward0] + 140202224193200 -> 140202224193056 + 140202224193200 [label=ViewBackward0] + 140202224193296 -> 140202224193200 + 140202224193296 [label=AddmmBackward0] + 140202224193392 -> 140202224193296 + 140202224193392 [label=ToCopyBackward0] + 140202224193488 -> 140202224193392 + 140202229010656 [label="encoder.layer.4.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202229010656 -> 140202224193488 + 140202224193488 [label=AccumulateGrad] + 140202224193344 -> 140202224193296 + 140202224193344 [label=ViewBackward0] + 140210811924640 -> 140202224193344 + 140210811924640 [label=GeluBackward0] + 140210811924736 -> 140210811924640 + 140210811924736 [label=ViewBackward0] + 140210811924832 -> 140210811924736 + 140210811924832 [label=AddmmBackward0] + 140210811924928 -> 140210811924832 + 140210811924928 [label=ToCopyBackward0] + 140210811925120 -> 140210811924928 + 140202229010896 [label="encoder.layer.4.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202229010896 -> 140210811925120 + 140210811925120 [label=AccumulateGrad] + 140210811924880 -> 140210811924832 + 140210811924880 [label=ViewBackward0] + 140210811925168 -> 140210811924880 + 140210811925168 [label=ToCopyBackward0] + 140202224193008 -> 140210811925168 + 140202224193008 [label=SliceBackward0] + 140210811925312 -> 140202224193008 + 140210811925312 [label=SliceBackward0] + 140210811925408 -> 140210811925312 + 140210811925408 [label=NativeLayerNormBackward0] + 140210811925504 -> 140210811925408 + 140210811925504 [label=AddBackward0] + 140210811925696 -> 140210811925504 + 140210811925696 [label=NativeDropoutBackward0] + 140210811925840 -> 140210811925696 + 140210811925840 [label=ViewBackward0] + 140210811925936 -> 140210811925840 + 140210811925936 [label=AddmmBackward0] + 140210811926032 -> 140210811925936 + 140210811926032 [label=ToCopyBackward0] + 140210811926224 -> 140210811926032 + 140202229012816 [label="encoder.layer.4.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229012816 -> 140210811926224 + 140210811926224 [label=AccumulateGrad] + 140210811925984 -> 140210811925936 + 140210811925984 [label=ViewBackward0] + 140210811926272 -> 140210811925984 + 140210811926272 [label=ViewBackward0] + 140210811926368 -> 140210811926272 + 140210811926368 [label=CloneBackward0] + 140210811926464 -> 140210811926368 + 140210811926464 [label=PermuteBackward0] + 140210811926560 -> 140210811926464 + 140210811926560 [label=UnsafeViewBackward0] + 140210811926656 -> 140210811926560 + 140210811926656 [label=BmmBackward0] + 140210811926752 -> 140210811926656 + 140210811926752 [label=ReshapeAliasBackward0] + 140210811926896 -> 140210811926752 + 140210811926896 [label=ExpandBackward0] + 140210811926992 -> 140210811926896 + 140210811926992 [label=ToCopyBackward0] + 140210811927088 -> 140210811926992 + 140210811927088 [label=NativeDropoutBackward0] + 140210811927184 -> 140210811927088 + 140210811927184 [label=SoftmaxBackward0] + 140210811927280 -> 140210811927184 + 140210811927280 [label=AddBackward0] + 140210811927376 -> 140210811927280 + 140210811927376 [label=DivBackward0] + 140210811927472 -> 140210811927376 + 140210811927472 [label=UnsafeViewBackward0] + 140210811927568 -> 140210811927472 + 140210811927568 [label=BmmBackward0] + 140210811927664 -> 140210811927568 + 140210811927664 [label=ReshapeAliasBackward0] + 140210811927808 -> 140210811927664 + 140210811927808 [label=ExpandBackward0] + 140210811927904 -> 140210811927808 + 140210811927904 [label=PermuteBackward0] + 140210811928000 -> 140210811927904 + 140210811928000 [label=ViewBackward0] + 140210811928096 -> 140210811928000 + 140210811928096 [label=ViewBackward0] + 140210811928192 -> 140210811928096 + 140210811928192 [label=AddmmBackward0] + 140210811928288 -> 140210811928192 + 140210811928288 [label=ToCopyBackward0] + 140210811928480 -> 140210811928288 + 140202229013536 [label="encoder.layer.4.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140202229013536 -> 140210811928480 + 140210811928480 [label=AccumulateGrad] + 140210811928240 -> 140210811928192 + 140210811928240 [label=ViewBackward0] + 140210811928384 -> 140210811928240 + 140210811928384 [label=ToCopyBackward0] + 140210811925648 -> 140210811928384 + 140210811925648 [label=SliceBackward0] + 140210811941024 -> 140210811925648 + 140210811941024 [label=SliceBackward0] + 140210811941120 -> 140210811941024 + 140210811941120 [label=SliceBackward0] + 140210811941216 -> 140210811941120 + 140210811941216 [label=NativeLayerNormBackward0] + 140210811941312 -> 140210811941216 + 140210811941312 [label=AddBackward0] + 140210811941504 -> 140210811941312 + 140210811941504 [label=NativeDropoutBackward0] + 140210811941648 -> 140210811941504 + 140210811941648 [label=ViewBackward0] + 140210811941744 -> 140210811941648 + 140210811941744 [label=AddmmBackward0] + 140210811941840 -> 140210811941744 + 140210811941840 [label=ToCopyBackward0] + 140210811942032 -> 140210811941840 + 140202229014016 [label="encoder.layer.4.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229014016 -> 140210811942032 + 140210811942032 [label=AccumulateGrad] + 140210811941792 -> 140210811941744 + 140210811941792 [label=ViewBackward0] + 140210811942080 -> 140210811941792 + 140210811942080 [label=ViewBackward0] + 140210811942176 -> 140210811942080 + 140210811942176 [label=CloneBackward0] + 140210811942272 -> 140210811942176 + 140210811942272 [label=PermuteBackward0] + 140210811942368 -> 140210811942272 + 140210811942368 [label=UnsafeViewBackward0] + 140210811942464 -> 140210811942368 + 140210811942464 [label=BmmBackward0] + 140210811942560 -> 140210811942464 + 140210811942560 [label=ReshapeAliasBackward0] + 140210811942704 -> 140210811942560 + 140210811942704 [label=ExpandBackward0] + 140210811942800 -> 140210811942704 + 140210811942800 [label=ToCopyBackward0] + 140210811942896 -> 140210811942800 + 140210811942896 [label=NativeDropoutBackward0] + 140210811942992 -> 140210811942896 + 140210811942992 [label=SoftmaxBackward0] + 140210811943088 -> 140210811942992 + 140210811943088 [label=AddBackward0] + 140210811943184 -> 140210811943088 + 140210811943184 [label=DivBackward0] + 140210811943280 -> 140210811943184 + 140210811943280 [label=UnsafeViewBackward0] + 140210811943376 -> 140210811943280 + 140210811943376 [label=BmmBackward0] + 140210811943472 -> 140210811943376 + 140210811943472 [label=ReshapeAliasBackward0] + 140210811943616 -> 140210811943472 + 140210811943616 [label=ExpandBackward0] + 140210811943712 -> 140210811943616 + 140210811943712 [label=PermuteBackward0] + 140210811943808 -> 140210811943712 + 140210811943808 [label=ViewBackward0] + 140210811943904 -> 140210811943808 + 140210811943904 [label=ViewBackward0] + 140210811944000 -> 140210811943904 + 140210811944000 [label=AddmmBackward0] + 140210811944096 -> 140210811944000 + 140210811944096 [label=ToCopyBackward0] + 140210811944288 -> 140210811944096 + 140202229023024 [label="encoder.layer.4.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202229023024 -> 140210811944288 + 140210811944288 [label=AccumulateGrad] + 140210811944048 -> 140210811944000 + 140210811944048 [label=ViewBackward0] + 140210811944336 -> 140210811944048 + 140210811944336 [label=ToCopyBackward0] + 140210811941456 -> 140210811944336 + 140210811941456 [label=CatBackward0] + 140210811944480 -> 140210811941456 + 140210811944480 [label=NativeLayerNormBackward0] + 140210811944624 -> 140210811944480 + 140210811944624 [label=AddBackward0] + 140210811944816 -> 140210811944624 + 140210811944816 [label=NativeDropoutBackward0] + 140210811944912 -> 140210811944816 + 140210811944912 [label=ViewBackward0] + 140210811957408 -> 140210811944912 + 140210811957408 [label=AddmmBackward0] + 140210811957504 -> 140210811957408 + 140210811957504 [label=ToCopyBackward0] + 140210811957696 -> 140210811957504 + 140202229023504 [label="encoder.layer.3.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202229023504 -> 140210811957696 + 140210811957696 [label=AccumulateGrad] + 140210811957456 -> 140210811957408 + 140210811957456 [label=ViewBackward0] + 140210811957744 -> 140210811957456 + 140210811957744 [label=GeluBackward0] + 140210811957840 -> 140210811957744 + 140210811957840 [label=ViewBackward0] + 140210811957936 -> 140210811957840 + 140210811957936 [label=AddmmBackward0] + 140210811958032 -> 140210811957936 + 140210811958032 [label=ToCopyBackward0] + 140210811958224 -> 140210811958032 + 140202229023744 [label="encoder.layer.3.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202229023744 -> 140210811958224 + 140210811958224 [label=AccumulateGrad] + 140210811957984 -> 140210811957936 + 140210811957984 [label=ViewBackward0] + 140210811958272 -> 140210811957984 + 140210811958272 [label=ToCopyBackward0] + 140210811944768 -> 140210811958272 + 140210811944768 [label=SliceBackward0] + 140210811958416 -> 140210811944768 + 140210811958416 [label=SliceBackward0] + 140210811958512 -> 140210811958416 + 140210811958512 [label=SliceBackward0] + 140210811958608 -> 140210811958512 + 140210811958608 [label=SliceBackward0] + 140210811958704 -> 140210811958608 + 140210811958704 [label=SliceBackward0] + 140210811958800 -> 140210811958704 + 140210811958800 [label=NativeLayerNormBackward0] + 140210811958896 -> 140210811958800 + 140210811958896 [label=AddBackward0] + 140210811959088 -> 140210811958896 + 140210811959088 [label=NativeDropoutBackward0] + 140210811959232 -> 140210811959088 + 140210811959232 [label=ViewBackward0] + 140210811959328 -> 140210811959232 + 140210811959328 [label=AddmmBackward0] + 140210811959424 -> 140210811959328 + 140210811959424 [label=ToCopyBackward0] + 140210811959616 -> 140210811959424 + 140202229025664 [label="encoder.layer.3.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229025664 -> 140210811959616 + 140210811959616 [label=AccumulateGrad] + 140210811959376 -> 140210811959328 + 140210811959376 [label=ViewBackward0] + 140210811959664 -> 140210811959376 + 140210811959664 [label=ViewBackward0] + 140210811959760 -> 140210811959664 + 140210811959760 [label=CloneBackward0] + 140210811959856 -> 140210811959760 + 140210811959856 [label=PermuteBackward0] + 140210811959952 -> 140210811959856 + 140210811959952 [label=UnsafeViewBackward0] + 140210811960048 -> 140210811959952 + 140210811960048 [label=BmmBackward0] + 140210811960144 -> 140210811960048 + 140210811960144 [label=ReshapeAliasBackward0] + 140210811960288 -> 140210811960144 + 140210811960288 [label=ExpandBackward0] + 140210811960384 -> 140210811960288 + 140210811960384 [label=ToCopyBackward0] + 140210811960480 -> 140210811960384 + 140210811960480 [label=NativeDropoutBackward0] + 140210811960576 -> 140210811960480 + 140210811960576 [label=SoftmaxBackward0] + 140210811960672 -> 140210811960576 + 140210811960672 [label=AddBackward0] + 140210811960768 -> 140210811960672 + 140210811960768 [label=DivBackward0] + 140210811960864 -> 140210811960768 + 140210811960864 [label=UnsafeViewBackward0] + 140210811960960 -> 140210811960864 + 140210811960960 [label=BmmBackward0] + 140210811961056 -> 140210811960960 + 140210811961056 [label=ReshapeAliasBackward0] + 140210811961200 -> 140210811961056 + 140210811961200 [label=ExpandBackward0] + 140210811961296 -> 140210811961200 + 140210811961296 [label=PermuteBackward0] + 140210811961104 -> 140210811961296 + 140210811961104 [label=ViewBackward0] + 140210811973840 -> 140210811961104 + 140210811973840 [label=ViewBackward0] + 140210811973936 -> 140210811973840 + 140210811973936 [label=AddmmBackward0] + 140210811974032 -> 140210811973936 + 140210811974032 [label=ToCopyBackward0] + 140210811974224 -> 140210811974032 + 140202229026384 [label="encoder.layer.3.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202229026384 -> 140210811974224 + 140210811974224 [label=AccumulateGrad] + 140210811973984 -> 140210811973936 + 140210811973984 [label=ViewBackward0] + 140210811974272 -> 140210811973984 + 140210811974272 [label=ToCopyBackward0] + 140210811959040 -> 140210811974272 + 140210811959040 [label=CatBackward0] + 140210811974416 -> 140210811959040 + 140210811974416 [label=NativeLayerNormBackward0] + 140210811974560 -> 140210811974416 + 140210811974560 [label=AddBackward0] + 140210811974752 -> 140210811974560 + 140210811974752 [label=NativeDropoutBackward0] + 140210811974896 -> 140210811974752 + 140210811974896 [label=ViewBackward0] + 140210811974992 -> 140210811974896 + 140210811974992 [label=AddmmBackward0] + 140210811975088 -> 140210811974992 + 140210811975088 [label=ToCopyBackward0] + 140210811975280 -> 140210811975088 + 140202229039248 [label="encoder.layer.2.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202229039248 -> 140210811975280 + 140210811975280 [label=AccumulateGrad] + 140210811975040 -> 140210811974992 + 140210811975040 [label=ViewBackward0] + 140210811975328 -> 140210811975040 + 140210811975328 [label=GeluBackward0] + 140210811975424 -> 140210811975328 + 140210811975424 [label=ViewBackward0] + 140210811975520 -> 140210811975424 + 140210811975520 [label=AddmmBackward0] + 140210811975616 -> 140210811975520 + 140210811975616 [label=ToCopyBackward0] + 140210811975808 -> 140210811975616 + 140202229039488 [label="encoder.layer.2.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202229039488 -> 140210811975808 + 140210811975808 [label=AccumulateGrad] + 140210811975568 -> 140210811975520 + 140210811975568 [label=ViewBackward0] + 140210811975856 -> 140210811975568 + 140210811975856 [label=ToCopyBackward0] + 140210811974704 -> 140210811975856 + 140210811974704 [label=SliceBackward0] + 140210811976000 -> 140210811974704 + 140210811976000 [label=SliceBackward0] + 140210811976096 -> 140210811976000 + 140210811976096 [label=NativeLayerNormBackward0] + 140210811976192 -> 140210811976096 + 140210811976192 [label=AddBackward0] + 140210811976384 -> 140210811976192 + 140210811976384 [label=NativeDropoutBackward0] + 140210811976528 -> 140210811976384 + 140210811976528 [label=ViewBackward0] + 140210811976624 -> 140210811976528 + 140210811976624 [label=AddmmBackward0] + 140210811976720 -> 140210811976624 + 140210811976720 [label=ToCopyBackward0] + 140210811976912 -> 140210811976720 + 140202229041408 [label="encoder.layer.2.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229041408 -> 140210811976912 + 140210811976912 [label=AccumulateGrad] + 140210811976672 -> 140210811976624 + 140210811976672 [label=ViewBackward0] + 140210811976960 -> 140210811976672 + 140210811976960 [label=ViewBackward0] + 140210811977056 -> 140210811976960 + 140210811977056 [label=CloneBackward0] + 140210811977152 -> 140210811977056 + 140210811977152 [label=PermuteBackward0] + 140210811977248 -> 140210811977152 + 140210811977248 [label=UnsafeViewBackward0] + 140210811977344 -> 140210811977248 + 140210811977344 [label=BmmBackward0] + 140210811977440 -> 140210811977344 + 140210811977440 [label=ReshapeAliasBackward0] + 140210811977584 -> 140210811977440 + 140210811977584 [label=ExpandBackward0] + 140210811977680 -> 140210811977584 + 140210811977680 [label=ToCopyBackward0] + 140210811977488 -> 140210811977680 + 140210811977488 [label=NativeDropoutBackward0] + 140210811994320 -> 140210811977488 + 140210811994320 [label=SoftmaxBackward0] + 140210811994416 -> 140210811994320 + 140210811994416 [label=AddBackward0] + 140210811994512 -> 140210811994416 + 140210811994512 [label=DivBackward0] + 140210811994608 -> 140210811994512 + 140210811994608 [label=UnsafeViewBackward0] + 140210811994704 -> 140210811994608 + 140210811994704 [label=BmmBackward0] + 140210811994800 -> 140210811994704 + 140210811994800 [label=ReshapeAliasBackward0] + 140210811994944 -> 140210811994800 + 140210811994944 [label=ExpandBackward0] + 140210811995040 -> 140210811994944 + 140210811995040 [label=PermuteBackward0] + 140210811995136 -> 140210811995040 + 140210811995136 [label=ViewBackward0] + 140210811995232 -> 140210811995136 + 140210811995232 [label=ViewBackward0] + 140210811995328 -> 140210811995232 + 140210811995328 [label=AddmmBackward0] + 140210811995424 -> 140210811995328 + 140210811995424 [label=ToCopyBackward0] + 140210811995616 -> 140210811995424 + 140202229042128 [label="encoder.layer.2.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140202229042128 -> 140210811995616 + 140210811995616 [label=AccumulateGrad] + 140210811995376 -> 140210811995328 + 140210811995376 [label=ViewBackward0] + 140210811995664 -> 140210811995376 + 140210811995664 [label=ToCopyBackward0] + 140210811976336 -> 140210811995664 + 140210811976336 [label=SliceBackward0] + 140210811995808 -> 140210811976336 + 140210811995808 [label=SliceBackward0] + 140210811995904 -> 140210811995808 + 140210811995904 [label=SliceBackward0] + 140210811996000 -> 140210811995904 + 140210811996000 [label=NativeLayerNormBackward0] + 140210811996096 -> 140210811996000 + 140210811996096 [label=AddBackward0] + 140210811996288 -> 140210811996096 + 140210811996288 [label=NativeDropoutBackward0] + 140210811996432 -> 140210811996288 + 140210811996432 [label=ViewBackward0] + 140210811996528 -> 140210811996432 + 140210811996528 [label=AddmmBackward0] + 140210811996624 -> 140210811996528 + 140210811996624 [label=ToCopyBackward0] + 140210811996816 -> 140210811996624 + 140202229042608 [label="encoder.layer.2.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229042608 -> 140210811996816 + 140210811996816 [label=AccumulateGrad] + 140210811996576 -> 140210811996528 + 140210811996576 [label=ViewBackward0] + 140210811996864 -> 140210811996576 + 140210811996864 [label=ViewBackward0] + 140210811996960 -> 140210811996864 + 140210811996960 [label=CloneBackward0] + 140210811997056 -> 140210811996960 + 140210811997056 [label=PermuteBackward0] + 140210811997152 -> 140210811997056 + 140210811997152 [label=UnsafeViewBackward0] + 140210811997248 -> 140210811997152 + 140210811997248 [label=BmmBackward0] + 140210811997344 -> 140210811997248 + 140210811997344 [label=ReshapeAliasBackward0] + 140210811997488 -> 140210811997344 + 140210811997488 [label=ExpandBackward0] + 140210811997584 -> 140210811997488 + 140210811997584 [label=ToCopyBackward0] + 140210811997680 -> 140210811997584 + 140210811997680 [label=NativeDropoutBackward0] + 140210811997776 -> 140210811997680 + 140210811997776 [label=SoftmaxBackward0] + 140210811997872 -> 140210811997776 + 140210811997872 [label=AddBackward0] + 140210811997968 -> 140210811997872 + 140210811997968 [label=DivBackward0] + 140210811998064 -> 140210811997968 + 140210811998064 [label=UnsafeViewBackward0] + 140210811998160 -> 140210811998064 + 140210811998160 [label=BmmBackward0] + 140210811997392 -> 140210811998160 + 140210811997392 [label=ReshapeAliasBackward0] + 140210812006656 -> 140210811997392 + 140210812006656 [label=ExpandBackward0] + 140210812006752 -> 140210812006656 + 140210812006752 [label=PermuteBackward0] + 140210812006848 -> 140210812006752 + 140210812006848 [label=ViewBackward0] + 140210812006944 -> 140210812006848 + 140210812006944 [label=ViewBackward0] + 140210812007040 -> 140210812006944 + 140210812007040 [label=AddmmBackward0] + 140210812007136 -> 140210812007040 + 140210812007136 [label=ToCopyBackward0] + 140210812007328 -> 140210812007136 + 140202229047520 [label="encoder.layer.2.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202229047520 -> 140210812007328 + 140210812007328 [label=AccumulateGrad] + 140210812007088 -> 140210812007040 + 140210812007088 [label=ViewBackward0] + 140210812007376 -> 140210812007088 + 140210812007376 [label=ToCopyBackward0] + 140210811996240 -> 140210812007376 + 140210811996240 [label=CatBackward0] + 140210812007520 -> 140210811996240 + 140210812007520 [label=NativeLayerNormBackward0] + 140210812007664 -> 140210812007520 + 140210812007664 [label=AddBackward0] + 140210812007856 -> 140210812007664 + 140210812007856 [label=NativeDropoutBackward0] + 140210812008000 -> 140210812007856 + 140210812008000 [label=ViewBackward0] + 140210812008096 -> 140210812008000 + 140210812008096 [label=AddmmBackward0] + 140210812008192 -> 140210812008096 + 140210812008192 [label=ToCopyBackward0] + 140210812008384 -> 140210812008192 + 140202229048000 [label="encoder.layer.1.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202229048000 -> 140210812008384 + 140210812008384 [label=AccumulateGrad] + 140210812008144 -> 140210812008096 + 140210812008144 [label=ViewBackward0] + 140210812008432 -> 140210812008144 + 140210812008432 [label=GeluBackward0] + 140210812008528 -> 140210812008432 + 140210812008528 [label=ViewBackward0] + 140210812008624 -> 140210812008528 + 140210812008624 [label=AddmmBackward0] + 140210812008720 -> 140210812008624 + 140210812008720 [label=ToCopyBackward0] + 140210812008912 -> 140210812008720 + 140202229048240 [label="encoder.layer.1.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202229048240 -> 140210812008912 + 140210812008912 [label=AccumulateGrad] + 140210812008672 -> 140210812008624 + 140210812008672 [label=ViewBackward0] + 140210812008960 -> 140210812008672 + 140210812008960 [label=ToCopyBackward0] + 140210812007808 -> 140210812008960 + 140210812007808 [label=SliceBackward0] + 140210812009104 -> 140210812007808 + 140210812009104 [label=SliceBackward0] + 140210812009200 -> 140210812009104 + 140210812009200 [label=SliceBackward0] + 140210812009296 -> 140210812009200 + 140210812009296 [label=SliceBackward0] + 140210812009392 -> 140210812009296 + 140210812009392 [label=SliceBackward0] + 140210812009488 -> 140210812009392 + 140210812009488 [label=NativeLayerNormBackward0] + 140210812009584 -> 140210812009488 + 140210812009584 [label=AddBackward0] + 140210812009776 -> 140210812009584 + 140210812009776 [label=NativeDropoutBackward0] + 140210812009920 -> 140210812009776 + 140210812009920 [label=ViewBackward0] + 140210812010016 -> 140210812009920 + 140210812010016 [label=AddmmBackward0] + 140210812010112 -> 140210812010016 + 140210812010112 [label=ToCopyBackward0] + 140210812010304 -> 140210812010112 + 140202229050160 [label="encoder.layer.1.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229050160 -> 140210812010304 + 140210812010304 [label=AccumulateGrad] + 140210812010064 -> 140210812010016 + 140210812010064 [label=ViewBackward0] + 140210812010352 -> 140210812010064 + 140210812010352 [label=ViewBackward0] + 140210812010448 -> 140210812010352 + 140210812010448 [label=CloneBackward0] + 140210812010256 -> 140210812010448 + 140210812010256 [label=PermuteBackward0] + 140210812022992 -> 140210812010256 + 140210812022992 [label=UnsafeViewBackward0] + 140210812023088 -> 140210812022992 + 140210812023088 [label=BmmBackward0] + 140210812023184 -> 140210812023088 + 140210812023184 [label=ReshapeAliasBackward0] + 140210812023328 -> 140210812023184 + 140210812023328 [label=ExpandBackward0] + 140210812023424 -> 140210812023328 + 140210812023424 [label=ToCopyBackward0] + 140210812023520 -> 140210812023424 + 140210812023520 [label=NativeDropoutBackward0] + 140210812023616 -> 140210812023520 + 140210812023616 [label=SoftmaxBackward0] + 140210812023712 -> 140210812023616 + 140210812023712 [label=AddBackward0] + 140210812023808 -> 140210812023712 + 140210812023808 [label=DivBackward0] + 140210812023904 -> 140210812023808 + 140210812023904 [label=UnsafeViewBackward0] + 140210812024000 -> 140210812023904 + 140210812024000 [label=BmmBackward0] + 140210812024096 -> 140210812024000 + 140210812024096 [label=ReshapeAliasBackward0] + 140210812024240 -> 140210812024096 + 140210812024240 [label=ExpandBackward0] + 140210812024336 -> 140210812024240 + 140210812024336 [label=PermuteBackward0] + 140210812024432 -> 140210812024336 + 140210812024432 [label=ViewBackward0] + 140210812024528 -> 140210812024432 + 140210812024528 [label=ViewBackward0] + 140210812024624 -> 140210812024528 + 140210812024624 [label=AddmmBackward0] + 140210812024720 -> 140210812024624 + 140210812024720 [label=ToCopyBackward0] + 140210812024912 -> 140210812024720 + 140202229050880 [label="encoder.layer.1.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202229050880 -> 140210812024912 + 140210812024912 [label=AccumulateGrad] + 140210812024672 -> 140210812024624 + 140210812024672 [label=ViewBackward0] + 140210812024960 -> 140210812024672 + 140210812024960 [label=ToCopyBackward0] + 140210812009728 -> 140210812024960 + 140210812009728 [label=CatBackward0] + 140210812025104 -> 140210812009728 + 140210812025104 [label=NativeLayerNormBackward0] + 140210812025248 -> 140210812025104 + 140210812025248 [label=AddBackward0] + 140210812025440 -> 140210812025248 + 140210812025440 [label=NativeDropoutBackward0] + 140210812025584 -> 140210812025440 + 140210812025584 [label=ViewBackward0] + 140210812025680 -> 140210812025584 + 140210812025680 [label=AddmmBackward0] + 140210812025776 -> 140210812025680 + 140210812025776 [label=ToCopyBackward0] + 140210812025968 -> 140210812025776 + 140202229067840 [label="encoder.layer.0.experts.dense2.bias + (768)" fillcolor=lightblue] + 140202229067840 -> 140210812025968 + 140210812025968 [label=AccumulateGrad] + 140210812025728 -> 140210812025680 + 140210812025728 [label=ViewBackward0] + 140210812026016 -> 140210812025728 + 140210812026016 [label=GeluBackward0] + 140210812026112 -> 140210812026016 + 140210812026112 [label=ViewBackward0] + 140210812026208 -> 140210812026112 + 140210812026208 [label=AddmmBackward0] + 140210812026304 -> 140210812026208 + 140210812026304 [label=ToCopyBackward0] + 140210812026496 -> 140210812026304 + 140202229068080 [label="encoder.layer.0.experts.dense1.bias + (3072)" fillcolor=lightblue] + 140202229068080 -> 140210812026496 + 140210812026496 [label=AccumulateGrad] + 140210812026256 -> 140210812026208 + 140210812026256 [label=ViewBackward0] + 140210812026544 -> 140210812026256 + 140210812026544 [label=ToCopyBackward0] + 140210812025392 -> 140210812026544 + 140210812025392 [label=SliceBackward0] + 140210812026688 -> 140210812025392 + 140210812026688 [label=SliceBackward0] + 140210812026784 -> 140210812026688 + 140210812026784 [label=NativeLayerNormBackward0] + 140210812026832 -> 140210812026784 + 140210812026832 [label=AddBackward0] + 140210812039424 -> 140210812026832 + 140210812039424 [label=NativeDropoutBackward0] + 140210812039568 -> 140210812039424 + 140210812039568 [label=ViewBackward0] + 140210812039664 -> 140210812039568 + 140210812039664 [label=AddmmBackward0] + 140210812039760 -> 140210812039664 + 140210812039760 [label=ToCopyBackward0] + 140210812039952 -> 140210812039760 + 140202229070000 [label="encoder.layer.0.crossattention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229070000 -> 140210812039952 + 140210812039952 [label=AccumulateGrad] + 140210812039712 -> 140210812039664 + 140210812039712 [label=ViewBackward0] + 140210812040000 -> 140210812039712 + 140210812040000 [label=ViewBackward0] + 140210812040096 -> 140210812040000 + 140210812040096 [label=CloneBackward0] + 140210812040192 -> 140210812040096 + 140210812040192 [label=PermuteBackward0] + 140210812040288 -> 140210812040192 + 140210812040288 [label=UnsafeViewBackward0] + 140210812040384 -> 140210812040288 + 140210812040384 [label=BmmBackward0] + 140210812040480 -> 140210812040384 + 140210812040480 [label=ReshapeAliasBackward0] + 140210812040624 -> 140210812040480 + 140210812040624 [label=ExpandBackward0] + 140210812040720 -> 140210812040624 + 140210812040720 [label=ToCopyBackward0] + 140210812040816 -> 140210812040720 + 140210812040816 [label=NativeDropoutBackward0] + 140210812040912 -> 140210812040816 + 140210812040912 [label=SoftmaxBackward0] + 140210812041008 -> 140210812040912 + 140210812041008 [label=AddBackward0] + 140210812041104 -> 140210812041008 + 140210812041104 [label=DivBackward0] + 140210812041200 -> 140210812041104 + 140210812041200 [label=UnsafeViewBackward0] + 140210812041296 -> 140210812041200 + 140210812041296 [label=BmmBackward0] + 140210812041392 -> 140210812041296 + 140210812041392 [label=ReshapeAliasBackward0] + 140210812041536 -> 140210812041392 + 140210812041536 [label=ExpandBackward0] + 140210812041632 -> 140210812041536 + 140210812041632 [label=PermuteBackward0] + 140210812041728 -> 140210812041632 + 140210812041728 [label=ViewBackward0] + 140210812041824 -> 140210812041728 + 140210812041824 [label=ViewBackward0] + 140210812041920 -> 140210812041824 + 140210812041920 [label=AddmmBackward0] + 140210812042016 -> 140210812041920 + 140210812042016 [label=ToCopyBackward0] + 140210812042208 -> 140210812042016 + 140202229070720 [label="encoder.layer.0.crossattention.self.query.bias + (768)" fillcolor=lightblue] + 140202229070720 -> 140210812042208 + 140210812042208 [label=AccumulateGrad] + 140210812041968 -> 140210812041920 + 140210812041968 [label=ViewBackward0] + 140210812042256 -> 140210812041968 + 140210812042256 [label=ToCopyBackward0] + 140210812039376 -> 140210812042256 + 140210812039376 [label=SliceBackward0] + 140210812042400 -> 140210812039376 + 140210812042400 [label=SliceBackward0] + 140210812042496 -> 140210812042400 + 140210812042496 [label=SliceBackward0] + 140210812042592 -> 140210812042496 + 140210812042592 [label=NativeLayerNormBackward0] + 140210812042688 -> 140210812042592 + 140210812042688 [label=AddBackward0] + 140210812042880 -> 140210812042688 + 140210812042880 [label=NativeDropoutBackward0] + 140210812043024 -> 140210812042880 + 140210812043024 [label=ViewBackward0] + 140210812043120 -> 140210812043024 + 140210812043120 [label=AddmmBackward0] + 140210812043216 -> 140210812043120 + 140210812043216 [label=ToCopyBackward0] + 140210812051664 -> 140210812043216 + 140202229071200 [label="encoder.layer.0.attention.output.dense.bias + (768)" fillcolor=lightblue] + 140202229071200 -> 140210812051664 + 140210812051664 [label=AccumulateGrad] + 140210812043168 -> 140210812043120 + 140210812043168 [label=ViewBackward0] + 140210812051712 -> 140210812043168 + 140210812051712 [label=ViewBackward0] + 140210812051808 -> 140210812051712 + 140210812051808 [label=CloneBackward0] + 140210812051904 -> 140210812051808 + 140210812051904 [label=PermuteBackward0] + 140210812052000 -> 140210812051904 + 140210812052000 [label=UnsafeViewBackward0] + 140210812052096 -> 140210812052000 + 140210812052096 [label=BmmBackward0] + 140210812052192 -> 140210812052096 + 140210812052192 [label=ReshapeAliasBackward0] + 140210812052336 -> 140210812052192 + 140210812052336 [label=ExpandBackward0] + 140210812052432 -> 140210812052336 + 140210812052432 [label=ToCopyBackward0] + 140210812052528 -> 140210812052432 + 140210812052528 [label=NativeDropoutBackward0] + 140210812052624 -> 140210812052528 + 140210812052624 [label=SoftmaxBackward0] + 140210812052720 -> 140210812052624 + 140210812052720 [label=AddBackward0] + 140210812052816 -> 140210812052720 + 140210812052816 [label=DivBackward0] + 140210812052912 -> 140210812052816 + 140210812052912 [label=UnsafeViewBackward0] + 140210812053008 -> 140210812052912 + 140210812053008 [label=BmmBackward0] + 140210812053104 -> 140210812053008 + 140210812053104 [label=ReshapeAliasBackward0] + 140210812053248 -> 140210812053104 + 140210812053248 [label=ExpandBackward0] + 140210812053344 -> 140210812053248 + 140210812053344 [label=PermuteBackward0] + 140210812053440 -> 140210812053344 + 140210812053440 [label=ViewBackward0] + 140210812053536 -> 140210812053440 + 140210812053536 [label=ViewBackward0] + 140210812053632 -> 140210812053536 + 140210812053632 [label=AddmmBackward0] + 140210812053728 -> 140210812053632 + 140210812053728 [label=ToCopyBackward0] + 140210812053920 -> 140210812053728 + 140202228734688 [label="encoder.layer.0.attention.self.query.bias + (768)" fillcolor=lightblue] + 140202228734688 -> 140210812053920 + 140210812053920 [label=AccumulateGrad] + 140210812053680 -> 140210812053632 + 140210812053680 [label=ViewBackward0] + 140210812053968 -> 140210812053680 + 140210812053968 [label=ToCopyBackward0] + 140210812042832 -> 140210812053968 + 140210812042832 [label=NativeDropoutBackward0] + 140210812054112 -> 140210812042832 + 140210812054112 [label=NativeLayerNormBackward0] + 140210812054208 -> 140210812054112 + 140210812054208 [label=CatBackward0] + 140210812054400 -> 140210812054208 + 140210812054400 [label=ExpandBackward0] + 140210812054544 -> 140210812054400 + 140202228561216 [label=" + (1, 32, 768)" fillcolor=lightblue] + 140202228561216 -> 140210812054544 + 140210812054544 [label=AccumulateGrad] + 140210812054352 -> 140210812054208 + 140210812054352 [label=AddBackward0] + 140210812054592 -> 140210812054352 + 140210812054592 [label=EmbeddingBackward0] + 140210812054736 -> 140210812054592 + 140202228561776 [label="embeddings.word_embeddings.weight + (30523, 768)" fillcolor=lightblue] + 140202228561776 -> 140210812054736 + 140210812054736 [label=AccumulateGrad] + 140210812054640 -> 140210812054352 + 140210812054640 [label=EmbeddingBackward0] + 140210812054784 -> 140210812054640 + 140202228735888 [label="embeddings.position_embeddings.weight + (512, 768)" fillcolor=lightblue] + 140202228735888 -> 140210812054784 + 140210812054784 [label=AccumulateGrad] + 140210812054160 -> 140210812054112 + 140202228560576 [label="embeddings.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228560576 -> 140210812054160 + 140210812054160 [label=AccumulateGrad] + 140210812053824 -> 140210812054112 + 140202228560336 [label="embeddings.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228560336 -> 140210812053824 + 140210812053824 [label=AccumulateGrad] + 140210812053152 -> 140210812053632 + 140210812053152 [label=TBackward0] + 140210812053872 -> 140210812053152 + 140210812053872 [label=ToCopyBackward0] + 140210812054304 -> 140210812053872 + 140202228560096 [label="encoder.layer.0.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228560096 -> 140210812054304 + 140210812054304 [label=AccumulateGrad] + 140210812053056 -> 140210812053008 + 140210812053056 [label=ReshapeAliasBackward0] + 140210812053392 -> 140210812053056 + 140210812053392 [label=ExpandBackward0] + 140210812053584 -> 140210812053392 + 140210812053584 [label=TransposeBackward0] + 140210812054064 -> 140210812053584 + 140210812054064 [label=PermuteBackward0] + 140210812054832 -> 140210812054064 + 140210812054832 [label=ViewBackward0] + 140210812054016 -> 140210812054832 + 140210812054016 [label=ViewBackward0] + 140210812054448 -> 140210812054016 + 140210812054448 [label=AddmmBackward0] + 140210812054928 -> 140210812054448 + 140210812054928 [label=ToCopyBackward0] + 140210812055120 -> 140210812054928 + 140202229071680 [label="encoder.layer.0.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202229071680 -> 140210812055120 + 140210812055120 [label=AccumulateGrad] + 140210812054688 -> 140210812054448 + 140210812054688 [label=ViewBackward0] + 140210812055168 -> 140210812054688 + 140210812055168 [label=ToCopyBackward0] + 140210812042832 -> 140210812055168 + 140210812053200 -> 140210812054448 + 140210812053200 [label=TBackward0] + 140210812055024 -> 140210812053200 + 140210812055024 [label=ToCopyBackward0] + 140210812055312 -> 140210812055024 + 140202228734048 [label="encoder.layer.0.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228734048 -> 140210812055312 + 140210812055312 [label=AccumulateGrad] + 140210812052144 -> 140210812052096 + 140210812052144 [label=ReshapeAliasBackward0] + 140210812052480 -> 140210812052144 + 140210812052480 [label=ExpandBackward0] + 140210812052672 -> 140210812052480 + 140210812052672 [label=PermuteBackward0] + 140210812052864 -> 140210812052672 + 140210812052864 [label=ViewBackward0] + 140210812052240 -> 140210812052864 + 140210812052240 [label=ViewBackward0] + 140210812053488 -> 140210812052240 + 140210812053488 [label=AddmmBackward0] + 140210812054256 -> 140210812053488 + 140210812054256 [label=ToCopyBackward0] + 140210812055264 -> 140210812054256 + 140202229071440 [label="encoder.layer.0.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202229071440 -> 140210812055264 + 140210812055264 [label=AccumulateGrad] + 140210812053776 -> 140210812053488 + 140210812053776 [label=ViewBackward0] + 140210812055072 -> 140210812053776 + 140210812055072 [label=ToCopyBackward0] + 140210812042832 -> 140210812055072 + 140210812052288 -> 140210812053488 + 140210812052288 [label=TBackward0] + 140210812054880 -> 140210812052288 + 140210812054880 [label=ToCopyBackward0] + 140210812055216 -> 140210812054880 + 140202229071760 [label="encoder.layer.0.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202229071760 -> 140210812055216 + 140210812055216 [label=AccumulateGrad] + 140210812042928 -> 140210812043120 + 140210812042928 [label=TBackward0] + 140210812051856 -> 140210812042928 + 140210812051856 [label=ToCopyBackward0] + 140210812052048 -> 140210812051856 + 140202229071520 [label="encoder.layer.0.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229071520 -> 140210812052048 + 140210812052048 [label=AccumulateGrad] + 140210812042832 -> 140210812042688 + 140210812042640 -> 140210812042592 + 140202229071280 [label="encoder.layer.0.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229071280 -> 140210812042640 + 140210812042640 [label=AccumulateGrad] + 140210812042112 -> 140210812042592 + 140202229070960 [label="encoder.layer.0.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229070960 -> 140210812042112 + 140210812042112 [label=AccumulateGrad] + 140210812041440 -> 140210812041920 + 140210812041440 [label=TBackward0] + 140210812042160 -> 140210812041440 + 140210812042160 [label=ToCopyBackward0] + 140210812042544 -> 140210812042160 + 140202229071040 [label="encoder.layer.0.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202229071040 -> 140210812042544 + 140210812042544 [label=AccumulateGrad] + 140210812041344 -> 140210812041296 + 140210812041344 [label=ReshapeAliasBackward0] + 140210812041680 -> 140210812041344 + 140210812041680 [label=ExpandBackward0] + 140210812041872 -> 140210812041680 + 140210812041872 [label=TransposeBackward0] + 140210812042352 -> 140210812041872 + 140210812042352 [label=PermuteBackward0] + 140210812042784 -> 140210812042352 + 140210812042784 [label=ViewBackward0] + 140210812042304 -> 140210812042784 + 140210812042304 [label=ViewBackward0] + 140210812043072 -> 140210812042304 + 140210812043072 [label=AddmmBackward0] + 140210812041488 -> 140210812043072 + 140210812041488 [label=ToCopyBackward0] + 140210812051760 -> 140210812041488 + 140202229070480 [label="encoder.layer.0.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140202229070480 -> 140210812051760 + 140210812051760 [label=AccumulateGrad] + 140210812051568 -> 140210812043072 + 140210812051568 [label=ViewBackward0] + 140210812052576 -> 140210812051568 + 140210812052576 [label=ToCopyBackward0] + 140210812052960 -> 140210812052576 + 140210812052960 [label=NativeLayerNormBackward0] + 140210812054496 -> 140210812052960 + 140202228735248 [label=" + (1408)" fillcolor=lightblue] + 140202228735248 -> 140210812054496 + 140210812054496 [label=AccumulateGrad] + 140210812053296 -> 140210812052960 + 140202228735488 [label=" + (1408)" fillcolor=lightblue] + 140202228735488 -> 140210812053296 + 140210812053296 [label=AccumulateGrad] + 140210812051520 -> 140210812043072 + 140210812051520 [label=TBackward0] + 140210812051616 -> 140210812051520 + 140210812051616 [label=ToCopyBackward0] + 140210812054976 -> 140210812051616 + 140202229070800 [label="encoder.layer.0.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140202229070800 -> 140210812054976 + 140210812054976 [label=AccumulateGrad] + 140210812040432 -> 140210812040384 + 140210812040432 [label=ReshapeAliasBackward0] + 140210812040768 -> 140210812040432 + 140210812040768 [label=ExpandBackward0] + 140210812040960 -> 140210812040768 + 140210812040960 [label=PermuteBackward0] + 140210812041152 -> 140210812040960 + 140210812041152 [label=ViewBackward0] + 140210812040528 -> 140210812041152 + 140210812040528 [label=ViewBackward0] + 140210812041776 -> 140210812040528 + 140210812041776 [label=AddmmBackward0] + 140210812042448 -> 140210812041776 + 140210812042448 [label=ToCopyBackward0] + 140210812042976 -> 140210812042448 + 140202229070240 [label="encoder.layer.0.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140202229070240 -> 140210812042976 + 140210812042976 [label=AccumulateGrad] + 140210812042064 -> 140210812041776 + 140210812042064 [label=ViewBackward0] + 140210812055360 -> 140210812042064 + 140210812055360 [label=ToCopyBackward0] + 140210812052960 -> 140210812055360 + 140210812040576 -> 140210812041776 + 140210812040576 [label=TBackward0] + 140210812051952 -> 140210812040576 + 140210812051952 [label=ToCopyBackward0] + 140210812052768 -> 140210812051952 + 140202229070560 [label="encoder.layer.0.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140202229070560 -> 140210812052768 + 140210812052768 [label=AccumulateGrad] + 140210812039472 -> 140210812039664 + 140210812039472 [label=TBackward0] + 140210812040144 -> 140210812039472 + 140210812040144 [label=ToCopyBackward0] + 140210812040336 -> 140210812040144 + 140202229070320 [label="encoder.layer.0.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229070320 -> 140210812040336 + 140210812040336 [label=AccumulateGrad] + 140210812039376 -> 140210812026832 + 140210812026400 -> 140210812026784 + 140202229070080 [label="encoder.layer.0.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229070080 -> 140210812026400 + 140210812026400 [label=AccumulateGrad] + 140210812039232 -> 140210812026784 + 140202229069760 [label="encoder.layer.0.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229069760 -> 140210812039232 + 140210812039232 [label=AccumulateGrad] + 140210812025920 -> 140210812026208 + 140210812025920 [label=TBackward0] + 140210812026448 -> 140210812025920 + 140210812026448 [label=ToCopyBackward0] + 140210812026736 -> 140210812026448 + 140202229068400 [label="encoder.layer.0.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202229068400 -> 140210812026736 + 140210812026736 [label=AccumulateGrad] + 140210812025488 -> 140210812025680 + 140210812025488 [label=TBackward0] + 140210812026160 -> 140210812025488 + 140210812026160 [label=ToCopyBackward0] + 140210812026640 -> 140210812026160 + 140202229068160 [label="encoder.layer.0.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202229068160 -> 140210812026640 + 140210812026640 [label=AccumulateGrad] + 140210812025392 -> 140210812025248 + 140210812025200 -> 140210812025104 + 140202229067920 [label="encoder.layer.0.expert_ln.weight + (768)" fillcolor=lightblue] + 140202229067920 -> 140210812025200 + 140210812025200 [label=AccumulateGrad] + 140210812025152 -> 140210812025104 + 140202229051120 [label="encoder.layer.0.expert_ln.bias + (768)" fillcolor=lightblue] + 140202229051120 -> 140210812025152 + 140210812025152 [label=AccumulateGrad] + 140210812024864 -> 140210812009728 + 140210812024864 [label=NativeLayerNormBackward0] + 140210812025536 -> 140210812024864 + 140210812025536 [label=AddBackward0] + 140210812026352 -> 140210812025536 + 140210812026352 [label=NativeDropoutBackward0] + 140210812026064 -> 140210812026352 + 140210812026064 [label=ViewBackward0] + 140210812039280 -> 140210812026064 + 140210812039280 [label=AddmmBackward0] + 140210812039808 -> 140210812039280 + 140210812039808 [label=ToCopyBackward0] + 140210812039904 -> 140210812039808 + 140202229069280 [label="encoder.layer.0.output.dense.bias + (768)" fillcolor=lightblue] + 140202229069280 -> 140210812039904 + 140210812039904 [label=AccumulateGrad] + 140210812039616 -> 140210812039280 + 140210812039616 [label=ViewBackward0] + 140210812040048 -> 140210812039616 + 140210812040048 [label=GeluBackward0] + 140210812041056 -> 140210812040048 + 140210812041056 [label=ViewBackward0] + 140210812041584 -> 140210812041056 + 140210812041584 [label=AddmmBackward0] + 140210812042736 -> 140210812041584 + 140210812042736 [label=ToCopyBackward0] + 140210812055456 -> 140210812042736 + 140202229069520 [label="encoder.layer.0.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202229069520 -> 140210812055456 + 140210812055456 [label=AccumulateGrad] + 140210812040672 -> 140210812041584 + 140210812040672 [label=ViewBackward0] + 140210812055408 -> 140210812040672 + 140210812055408 [label=ToCopyBackward0] + 140210812025872 -> 140210812055408 + 140210812025872 [label=SliceBackward0] + 140210812092672 -> 140210812025872 + 140210812092672 [label=SliceBackward0] + 140210812092768 -> 140210812092672 + 140210812092768 [label=SliceBackward0] + 140210812042592 -> 140210812092768 + 140210812055504 -> 140210812041584 + 140210812055504 [label=TBackward0] + 140210812092576 -> 140210812055504 + 140210812092576 [label=ToCopyBackward0] + 140210812092864 -> 140210812092576 + 140202229069840 [label="encoder.layer.0.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202229069840 -> 140210812092864 + 140210812092864 [label=AccumulateGrad] + 140210812039520 -> 140210812039280 + 140210812039520 [label=TBackward0] + 140210812041248 -> 140210812039520 + 140210812041248 [label=ToCopyBackward0] + 140210812052384 -> 140210812041248 + 140202229069600 [label="encoder.layer.0.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202229069600 -> 140210812052384 + 140210812052384 [label=AccumulateGrad] + 140210812025872 -> 140210812025536 + 140210812025344 -> 140210812024864 + 140202229069360 [label="encoder.layer.0.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229069360 -> 140210812025344 + 140210812025344 [label=AccumulateGrad] + 140210812025296 -> 140210812024864 + 140202229069040 [label="encoder.layer.0.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229069040 -> 140210812025296 + 140210812025296 [label=AccumulateGrad] + 140210812024144 -> 140210812024624 + 140210812024144 [label=TBackward0] + 140210812024816 -> 140210812024144 + 140210812024816 [label=ToCopyBackward0] + 140210812025824 -> 140210812024816 + 140202229051200 [label="encoder.layer.1.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202229051200 -> 140210812025824 + 140210812025824 [label=AccumulateGrad] + 140210812024048 -> 140210812024000 + 140210812024048 [label=ReshapeAliasBackward0] + 140210812024384 -> 140210812024048 + 140210812024384 [label=ExpandBackward0] + 140210812024576 -> 140210812024384 + 140210812024576 [label=TransposeBackward0] + 140210812025056 -> 140210812024576 + 140210812025056 [label=PermuteBackward0] + 140210812026592 -> 140210812025056 + 140210812026592 [label=ViewBackward0] + 140210812025008 -> 140210812026592 + 140210812025008 [label=ViewBackward0] + 140210812040240 -> 140210812025008 + 140210812040240 [label=AddmmBackward0] + 140210812040864 -> 140210812040240 + 140210812040864 [label=ToCopyBackward0] + 140210812092528 -> 140210812040864 + 140202229050640 [label="encoder.layer.1.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202229050640 -> 140210812092528 + 140210812092528 [label=AccumulateGrad] + 140210812039328 -> 140210812040240 + 140210812039328 [label=ViewBackward0] + 140210812092912 -> 140210812039328 + 140210812092912 [label=ToCopyBackward0] + 140210812009728 -> 140210812092912 + 140210812092480 -> 140210812040240 + 140210812092480 [label=TBackward0] + 140210812092624 -> 140210812092480 + 140210812092624 [label=ToCopyBackward0] + 140210812093056 -> 140210812092624 + 140202229050960 [label="encoder.layer.1.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202229050960 -> 140210812093056 + 140210812093056 [label=AccumulateGrad] + 140210812023136 -> 140210812023088 + 140210812023136 [label=ReshapeAliasBackward0] + 140210812023472 -> 140210812023136 + 140210812023472 [label=ExpandBackward0] + 140210812023664 -> 140210812023472 + 140210812023664 [label=PermuteBackward0] + 140210812023856 -> 140210812023664 + 140210812023856 [label=ViewBackward0] + 140210812023232 -> 140210812023856 + 140210812023232 [label=ViewBackward0] + 140210812024480 -> 140210812023232 + 140210812024480 [label=AddmmBackward0] + 140210812025632 -> 140210812024480 + 140210812025632 [label=ToCopyBackward0] + 140210812039856 -> 140210812025632 + 140202229050400 [label="encoder.layer.1.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202229050400 -> 140210812039856 + 140210812039856 [label=AccumulateGrad] + 140210812024768 -> 140210812024480 + 140210812024768 [label=ViewBackward0] + 140210812092816 -> 140210812024768 + 140210812092816 [label=ToCopyBackward0] + 140210812009728 -> 140210812092816 + 140210812023280 -> 140210812024480 + 140210812023280 [label=TBackward0] + 140210812092720 -> 140210812023280 + 140210812092720 [label=ToCopyBackward0] + 140210812092960 -> 140210812092720 + 140202229050720 [label="encoder.layer.1.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202229050720 -> 140210812092960 + 140210812092960 [label=AccumulateGrad] + 140210812009824 -> 140210812010016 + 140210812009824 [label=TBackward0] + 140210812010208 -> 140210812009824 + 140210812010208 [label=ToCopyBackward0] + 140210812023040 -> 140210812010208 + 140202229050480 [label="encoder.layer.1.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229050480 -> 140210812023040 + 140210812023040 [label=AccumulateGrad] + 140210812009728 -> 140210812009584 + 140210812009536 -> 140210812009488 + 140202229050240 [label="encoder.layer.1.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229050240 -> 140210812009536 + 140210812009536 [label=AccumulateGrad] + 140210812008816 -> 140210812009488 + 140202229049920 [label="encoder.layer.1.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229049920 -> 140210812008816 + 140210812008816 [label=AccumulateGrad] + 140210812008336 -> 140210812008624 + 140210812008336 [label=TBackward0] + 140210812008864 -> 140210812008336 + 140210812008864 [label=ToCopyBackward0] + 140210812009248 -> 140210812008864 + 140202229048560 [label="encoder.layer.1.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202229048560 -> 140210812009248 + 140210812009248 [label=AccumulateGrad] + 140210812007904 -> 140210812008096 + 140210812007904 [label=TBackward0] + 140210812008576 -> 140210812007904 + 140210812008576 [label=ToCopyBackward0] + 140210812009056 -> 140210812008576 + 140202229048320 [label="encoder.layer.1.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202229048320 -> 140210812009056 + 140210812009056 [label=AccumulateGrad] + 140210812007808 -> 140210812007664 + 140210812007616 -> 140210812007520 + 140202229048080 [label="encoder.layer.1.expert_ln.weight + (768)" fillcolor=lightblue] + 140202229048080 -> 140210812007616 + 140210812007616 [label=AccumulateGrad] + 140210812007568 -> 140210812007520 + 140202229047760 [label="encoder.layer.1.expert_ln.bias + (768)" fillcolor=lightblue] + 140202229047760 -> 140210812007568 + 140210812007568 [label=AccumulateGrad] + 140210812007280 -> 140210811996240 + 140210812007280 [label=NativeLayerNormBackward0] + 140210812007952 -> 140210812007280 + 140210812007952 [label=AddBackward0] + 140210812008768 -> 140210812007952 + 140210812008768 [label=NativeDropoutBackward0] + 140210812008480 -> 140210812008768 + 140210812008480 [label=ViewBackward0] + 140210812009008 -> 140210812008480 + 140210812009008 [label=AddmmBackward0] + 140210812009680 -> 140210812009008 + 140210812009680 [label=ToCopyBackward0] + 140210812010400 -> 140210812009680 + 140202229049440 [label="encoder.layer.1.output.dense.bias + (768)" fillcolor=lightblue] + 140202229049440 -> 140210812010400 + 140210812010400 [label=AccumulateGrad] + 140210812009632 -> 140210812009008 + 140210812009632 [label=ViewBackward0] + 140210812009968 -> 140210812009632 + 140210812009968 [label=GeluBackward0] + 140210812022848 -> 140210812009968 + 140210812022848 [label=ViewBackward0] + 140210812023568 -> 140210812022848 + 140210812023568 [label=AddmmBackward0] + 140210812023952 -> 140210812023568 + 140210812023952 [label=ToCopyBackward0] + 140210812024192 -> 140210812023952 + 140202229049680 [label="encoder.layer.1.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202229049680 -> 140210812024192 + 140210812024192 [label=AccumulateGrad] + 140210812023760 -> 140210812023568 + 140210812023760 [label=ViewBackward0] + 140210812093248 -> 140210812023760 + 140210812093248 [label=ToCopyBackward0] + 140210812008288 -> 140210812093248 + 140210812008288 [label=SliceBackward0] + 140210812093296 -> 140210812008288 + 140210812093296 [label=SliceBackward0] + 140210812093392 -> 140210812093296 + 140210812093392 [label=SliceBackward0] + 140210812009488 -> 140210812093392 + 140210812023376 -> 140210812023568 + 140210812023376 [label=TBackward0] + 140210812093008 -> 140210812023376 + 140210812093008 [label=ToCopyBackward0] + 140210812093488 -> 140210812093008 + 140202229050000 [label="encoder.layer.1.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202229050000 -> 140210812093488 + 140210812093488 [label=AccumulateGrad] + 140210812009440 -> 140210812009008 + 140210812009440 [label=TBackward0] + 140210812010160 -> 140210812009440 + 140210812010160 [label=ToCopyBackward0] + 140210812024288 -> 140210812010160 + 140202229049760 [label="encoder.layer.1.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202229049760 -> 140210812024288 + 140210812024288 [label=AccumulateGrad] + 140210812008288 -> 140210812007952 + 140210812007760 -> 140210812007280 + 140202229049520 [label="encoder.layer.1.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229049520 -> 140210812007760 + 140210812007760 [label=AccumulateGrad] + 140210812007712 -> 140210812007280 + 140202229049200 [label="encoder.layer.1.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229049200 -> 140210812007712 + 140210812007712 [label=AccumulateGrad] + 140210812006560 -> 140210812007040 + 140210812006560 [label=TBackward0] + 140210812007232 -> 140210812006560 + 140210812007232 [label=ToCopyBackward0] + 140210812008240 -> 140210812007232 + 140202229047840 [label="encoder.layer.2.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202229047840 -> 140210812008240 + 140210812008240 [label=AccumulateGrad] + 140210812006464 -> 140210811998160 + 140210812006464 [label=ReshapeAliasBackward0] + 140210812006800 -> 140210812006464 + 140210812006800 [label=ExpandBackward0] + 140210812006992 -> 140210812006800 + 140210812006992 [label=TransposeBackward0] + 140210812007472 -> 140210812006992 + 140210812007472 [label=PermuteBackward0] + 140210812009344 -> 140210812007472 + 140210812009344 [label=ViewBackward0] + 140210812007424 -> 140210812009344 + 140210812007424 [label=ViewBackward0] + 140210812009872 -> 140210812007424 + 140210812009872 [label=AddmmBackward0] + 140210812022944 -> 140210812009872 + 140210812022944 [label=ToCopyBackward0] + 140210812093200 -> 140210812022944 + 140202229047360 [label="encoder.layer.2.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202229047360 -> 140210812093200 + 140210812093200 [label=AccumulateGrad] + 140210812022896 -> 140210812009872 + 140210812022896 [label=ViewBackward0] + 140210812093536 -> 140210812022896 + 140210812093536 [label=ToCopyBackward0] + 140210811996240 -> 140210812093536 + 140210812093104 -> 140210812009872 + 140210812093104 [label=TBackward0] + 140210812093152 -> 140210812093104 + 140210812093152 [label=ToCopyBackward0] + 140210812093680 -> 140210812093152 + 140202229047600 [label="encoder.layer.2.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202229047600 -> 140210812093680 + 140210812093680 [label=AccumulateGrad] + 140210811997296 -> 140210811997248 + 140210811997296 [label=ReshapeAliasBackward0] + 140210811997632 -> 140210811997296 + 140210811997632 [label=ExpandBackward0] + 140210811997824 -> 140210811997632 + 140210811997824 [label=PermuteBackward0] + 140210811998016 -> 140210811997824 + 140210811998016 [label=ViewBackward0] + 140210811998112 -> 140210811998016 + 140210811998112 [label=ViewBackward0] + 140210812006896 -> 140210811998112 + 140210812006896 [label=AddmmBackward0] + 140210812008048 -> 140210812006896 + 140210812008048 [label=ToCopyBackward0] + 140210812006608 -> 140210812008048 + 140202229042848 [label="encoder.layer.2.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202229042848 -> 140210812006608 + 140210812006608 [label=AccumulateGrad] + 140210812007184 -> 140210812006896 + 140210812007184 [label=ViewBackward0] + 140210812093440 -> 140210812007184 + 140210812093440 [label=ToCopyBackward0] + 140210811996240 -> 140210812093440 + 140210812006512 -> 140210812006896 + 140210812006512 [label=TBackward0] + 140210812093344 -> 140210812006512 + 140210812093344 [label=ToCopyBackward0] + 140210812093584 -> 140210812093344 + 140202229043088 [label="encoder.layer.2.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202229043088 -> 140210812093584 + 140210812093584 [label=AccumulateGrad] + 140210811996336 -> 140210811996528 + 140210811996336 [label=TBackward0] + 140210811997008 -> 140210811996336 + 140210811997008 [label=ToCopyBackward0] + 140210811997200 -> 140210811997008 + 140202229042928 [label="encoder.layer.2.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229042928 -> 140210811997200 + 140210811997200 [label=AccumulateGrad] + 140210811996240 -> 140210811996096 + 140210811996048 -> 140210811996000 + 140202229042688 [label="encoder.layer.2.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229042688 -> 140210811996048 + 140210811996048 [label=AccumulateGrad] + 140210811995520 -> 140210811996000 + 140202229042368 [label="encoder.layer.2.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229042368 -> 140210811995520 + 140210811995520 [label=AccumulateGrad] + 140210811994848 -> 140210811995328 + 140210811994848 [label=TBackward0] + 140210811995568 -> 140210811994848 + 140210811995568 [label=ToCopyBackward0] + 140210811995952 -> 140210811995568 + 140202229042448 [label="encoder.layer.2.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202229042448 -> 140210811995952 + 140210811995952 [label=AccumulateGrad] + 140210811994752 -> 140210811994704 + 140210811994752 [label=ReshapeAliasBackward0] + 140210811995088 -> 140210811994752 + 140210811995088 [label=ExpandBackward0] + 140210811995280 -> 140210811995088 + 140210811995280 [label=TransposeBackward0] + 140210811995760 -> 140210811995280 + 140210811995760 [label=PermuteBackward0] + 140210811996192 -> 140210811995760 + 140210811996192 [label=ViewBackward0] + 140210811995712 -> 140210811996192 + 140210811995712 [label=ViewBackward0] + 140210811996480 -> 140210811995712 + 140210811996480 [label=AddmmBackward0] + 140210811996720 -> 140210811996480 + 140210811996720 [label=ToCopyBackward0] + 140210811996912 -> 140210811996720 + 140202229041888 [label="encoder.layer.2.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140202229041888 -> 140210811996912 + 140210811996912 [label=AccumulateGrad] + 140210811996672 -> 140210811996480 + 140210811996672 [label=ViewBackward0] + 140210811997728 -> 140210811996672 + 140210811997728 [label=ToCopyBackward0] + 140210812052960 -> 140210811997728 + 140210811994896 -> 140210811996480 + 140210811994896 [label=TBackward0] + 140210811997536 -> 140210811994896 + 140210811997536 [label=ToCopyBackward0] + 140210811996768 -> 140210811997536 + 140202229042208 [label="encoder.layer.2.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140202229042208 -> 140210811996768 + 140210811996768 [label=AccumulateGrad] + 140210811977392 -> 140210811977344 + 140210811977392 [label=ReshapeAliasBackward0] + 140210811977632 -> 140210811977392 + 140210811977632 [label=ExpandBackward0] + 140210811994368 -> 140210811977632 + 140210811994368 [label=PermuteBackward0] + 140210811994560 -> 140210811994368 + 140210811994560 [label=ViewBackward0] + 140210811994176 -> 140210811994560 + 140210811994176 [label=ViewBackward0] + 140210811995184 -> 140210811994176 + 140210811995184 [label=AddmmBackward0] + 140210811995856 -> 140210811995184 + 140210811995856 [label=ToCopyBackward0] + 140210811997440 -> 140210811995856 + 140202229041648 [label="encoder.layer.2.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140202229041648 -> 140210811997440 + 140210811997440 [label=AccumulateGrad] + 140210811995472 -> 140210811995184 + 140210811995472 [label=ViewBackward0] + 140210811996384 -> 140210811995472 + 140210811996384 [label=ToCopyBackward0] + 140210812052960 -> 140210811996384 + 140210811994224 -> 140210811995184 + 140210811994224 [label=TBackward0] + 140210812009152 -> 140210811994224 + 140210812009152 [label=ToCopyBackward0] + 140210811997104 -> 140210812009152 + 140202229041968 [label="encoder.layer.2.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140202229041968 -> 140210811997104 + 140210811997104 [label=AccumulateGrad] + 140210811976432 -> 140210811976624 + 140210811976432 [label=TBackward0] + 140210811977104 -> 140210811976432 + 140210811977104 [label=ToCopyBackward0] + 140210811977296 -> 140210811977104 + 140202229041728 [label="encoder.layer.2.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229041728 -> 140210811977296 + 140210811977296 [label=AccumulateGrad] + 140210811976336 -> 140210811976192 + 140210811976144 -> 140210811976096 + 140202229041488 [label="encoder.layer.2.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229041488 -> 140210811976144 + 140210811976144 [label=AccumulateGrad] + 140210811975712 -> 140210811976096 + 140202229041168 [label="encoder.layer.2.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229041168 -> 140210811975712 + 140210811975712 [label=AccumulateGrad] + 140210811975232 -> 140210811975520 + 140210811975232 [label=TBackward0] + 140210811975760 -> 140210811975232 + 140210811975760 [label=ToCopyBackward0] + 140210811976240 -> 140210811975760 + 140202229039808 [label="encoder.layer.2.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202229039808 -> 140210811976240 + 140210811976240 [label=AccumulateGrad] + 140210811974800 -> 140210811974992 + 140210811974800 [label=TBackward0] + 140210811975472 -> 140210811974800 + 140210811975472 [label=ToCopyBackward0] + 140210811975952 -> 140210811975472 + 140202229039568 [label="encoder.layer.2.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202229039568 -> 140210811975952 + 140210811975952 [label=AccumulateGrad] + 140210811974704 -> 140210811974560 + 140210811974512 -> 140210811974416 + 140202229039328 [label="encoder.layer.2.expert_ln.weight + (768)" fillcolor=lightblue] + 140202229039328 -> 140210811974512 + 140210811974512 [label=AccumulateGrad] + 140210811974464 -> 140210811974416 + 140202229026624 [label="encoder.layer.2.expert_ln.bias + (768)" fillcolor=lightblue] + 140202229026624 -> 140210811974464 + 140210811974464 [label=AccumulateGrad] + 140210811974176 -> 140210811959040 + 140210811974176 [label=NativeLayerNormBackward0] + 140210811974848 -> 140210811974176 + 140210811974848 [label=AddBackward0] + 140210811975664 -> 140210811974848 + 140210811975664 [label=NativeDropoutBackward0] + 140210811975376 -> 140210811975664 + 140210811975376 [label=ViewBackward0] + 140210811975904 -> 140210811975376 + 140210811975904 [label=AddmmBackward0] + 140210811976768 -> 140210811975904 + 140210811976768 [label=ToCopyBackward0] + 140210811976864 -> 140210811976768 + 140202229040688 [label="encoder.layer.2.output.dense.bias + (768)" fillcolor=lightblue] + 140202229040688 -> 140210811976864 + 140210811976864 [label=AccumulateGrad] + 140210811976576 -> 140210811975904 + 140210811976576 [label=ViewBackward0] + 140210812006704 -> 140210811976576 + 140210812006704 [label=GeluBackward0] + 140210811977200 -> 140210812006704 + 140210811977200 [label=ViewBackward0] + 140210811994656 -> 140210811977200 + 140210811994656 [label=AddmmBackward0] + 140210811996144 -> 140210811994656 + 140210811996144 [label=ToCopyBackward0] + 140210812093728 -> 140210811996144 + 140202229040928 [label="encoder.layer.2.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202229040928 -> 140210812093728 + 140210812093728 [label=AccumulateGrad] + 140210811994992 -> 140210811994656 + 140210811994992 [label=ViewBackward0] + 140210812093824 -> 140210811994992 + 140210812093824 [label=ToCopyBackward0] + 140210811975184 -> 140210812093824 + 140210811975184 [label=SliceBackward0] + 140210812093968 -> 140210811975184 + 140210812093968 [label=SliceBackward0] + 140210812094064 -> 140210812093968 + 140210812094064 [label=SliceBackward0] + 140210811996000 -> 140210812094064 + 140210811994272 -> 140210811994656 + 140210811994272 [label=TBackward0] + 140210812093632 -> 140210811994272 + 140210812093632 [label=ToCopyBackward0] + 140210812094160 -> 140210812093632 + 140202229041248 [label="encoder.layer.2.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202229041248 -> 140210812094160 + 140210812094160 [label=AccumulateGrad] + 140210811976480 -> 140210811975904 + 140210811976480 [label=TBackward0] + 140210811977536 -> 140210811976480 + 140210811977536 [label=ToCopyBackward0] + 140210811997920 -> 140210811977536 + 140202229041008 [label="encoder.layer.2.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202229041008 -> 140210811997920 + 140210811997920 [label=AccumulateGrad] + 140210811975184 -> 140210811974848 + 140210811974656 -> 140210811974176 + 140202229040768 [label="encoder.layer.2.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229040768 -> 140210811974656 + 140210811974656 [label=AccumulateGrad] + 140210811974608 -> 140210811974176 + 140202229040448 [label="encoder.layer.2.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229040448 -> 140210811974608 + 140210811974608 [label=AccumulateGrad] + 140210811973696 -> 140210811973936 + 140210811973696 [label=TBackward0] + 140210811974128 -> 140210811973696 + 140210811974128 [label=ToCopyBackward0] + 140210811975136 -> 140210811974128 + 140202229026704 [label="encoder.layer.3.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202229026704 -> 140210811975136 + 140210811975136 [label=AccumulateGrad] + 140210811961008 -> 140210811960960 + 140210811961008 [label=ReshapeAliasBackward0] + 140210811961248 -> 140210811961008 + 140210811961248 [label=ExpandBackward0] + 140210811973888 -> 140210811961248 + 140210811973888 [label=TransposeBackward0] + 140210811974368 -> 140210811973888 + 140210811974368 [label=PermuteBackward0] + 140210811976288 -> 140210811974368 + 140210811976288 [label=ViewBackward0] + 140210811974320 -> 140210811976288 + 140210811974320 [label=ViewBackward0] + 140210811977008 -> 140210811974320 + 140210811977008 [label=AddmmBackward0] + 140210811994464 -> 140210811977008 + 140210811994464 [label=ToCopyBackward0] + 140210812093776 -> 140210811994464 + 140202229026144 [label="encoder.layer.3.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202229026144 -> 140210812093776 + 140210812093776 [label=AccumulateGrad] + 140210811973744 -> 140210811977008 + 140210811973744 [label=ViewBackward0] + 140210812094208 -> 140210811973744 + 140210812094208 [label=ToCopyBackward0] + 140210811959040 -> 140210812094208 + 140210812093872 -> 140210811977008 + 140210812093872 [label=TBackward0] + 140210812093920 -> 140210812093872 + 140210812093920 [label=ToCopyBackward0] + 140210812094352 -> 140210812093920 + 140202229026464 [label="encoder.layer.3.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202229026464 -> 140210812094352 + 140210812094352 [label=AccumulateGrad] + 140210811960096 -> 140210811960048 + 140210811960096 [label=ReshapeAliasBackward0] + 140210811960432 -> 140210811960096 + 140210811960432 [label=ExpandBackward0] + 140210811960624 -> 140210811960432 + 140210811960624 [label=PermuteBackward0] + 140210811960816 -> 140210811960624 + 140210811960816 [label=ViewBackward0] + 140210811960192 -> 140210811960816 + 140210811960192 [label=ViewBackward0] + 140210811961152 -> 140210811960192 + 140210811961152 [label=AddmmBackward0] + 140210811974944 -> 140210811961152 + 140210811974944 [label=ToCopyBackward0] + 140210811976816 -> 140210811974944 + 140202229025904 [label="encoder.layer.3.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202229025904 -> 140210811976816 + 140210811976816 [label=AccumulateGrad] + 140210811974080 -> 140210811961152 + 140210811974080 [label=ViewBackward0] + 140210812094112 -> 140210811974080 + 140210812094112 [label=ToCopyBackward0] + 140210811959040 -> 140210812094112 + 140210811973792 -> 140210811961152 + 140210811973792 [label=TBackward0] + 140210812094016 -> 140210811973792 + 140210812094016 [label=ToCopyBackward0] + 140210812094256 -> 140210812094016 + 140202229026224 [label="encoder.layer.3.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202229026224 -> 140210812094256 + 140210812094256 [label=AccumulateGrad] + 140210811959136 -> 140210811959328 + 140210811959136 [label=TBackward0] + 140210811959808 -> 140210811959136 + 140210811959808 [label=ToCopyBackward0] + 140210811960000 -> 140210811959808 + 140202229025984 [label="encoder.layer.3.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229025984 -> 140210811960000 + 140210811960000 [label=AccumulateGrad] + 140210811959040 -> 140210811958896 + 140210811958848 -> 140210811958800 + 140202229025744 [label="encoder.layer.3.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229025744 -> 140210811958848 + 140210811958848 [label=AccumulateGrad] + 140210811958128 -> 140210811958800 + 140202229025424 [label="encoder.layer.3.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229025424 -> 140210811958128 + 140210811958128 [label=AccumulateGrad] + 140210811957648 -> 140210811957936 + 140210811957648 [label=TBackward0] + 140210811958176 -> 140210811957648 + 140210811958176 [label=ToCopyBackward0] + 140210811958560 -> 140210811958176 + 140202229024064 [label="encoder.layer.3.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202229024064 -> 140210811958560 + 140210811958560 [label=AccumulateGrad] + 140210811957312 -> 140210811957408 + 140210811957312 [label=TBackward0] + 140210811957888 -> 140210811957312 + 140210811957888 [label=ToCopyBackward0] + 140210811958368 -> 140210811957888 + 140202229023824 [label="encoder.layer.3.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202229023824 -> 140210811958368 + 140210811958368 [label=AccumulateGrad] + 140210811944768 -> 140210811944624 + 140210811944576 -> 140210811944480 + 140202229023584 [label="encoder.layer.3.expert_ln.weight + (768)" fillcolor=lightblue] + 140202229023584 -> 140210811944576 + 140210811944576 [label=AccumulateGrad] + 140210811944528 -> 140210811944480 + 140202229023264 [label="encoder.layer.3.expert_ln.bias + (768)" fillcolor=lightblue] + 140202229023264 -> 140210811944528 + 140210811944528 [label=AccumulateGrad] + 140210811944240 -> 140210811941456 + 140210811944240 [label=NativeLayerNormBackward0] + 140210811944864 -> 140210811944240 + 140210811944864 [label=AddBackward0] + 140210811958080 -> 140210811944864 + 140210811958080 [label=NativeDropoutBackward0] + 140210811957792 -> 140210811958080 + 140210811957792 [label=ViewBackward0] + 140210811958320 -> 140210811957792 + 140210811958320 [label=AddmmBackward0] + 140210811958992 -> 140210811958320 + 140210811958992 [label=ToCopyBackward0] + 140210811959520 -> 140210811958992 + 140202229024944 [label="encoder.layer.3.output.dense.bias + (768)" fillcolor=lightblue] + 140202229024944 -> 140210811959520 + 140210811959520 [label=AccumulateGrad] + 140210811958944 -> 140210811958320 + 140210811958944 [label=ViewBackward0] + 140210811959904 -> 140210811958944 + 140210811959904 [label=GeluBackward0] + 140210811959568 -> 140210811959904 + 140210811959568 [label=ViewBackward0] + 140210811960528 -> 140210811959568 + 140210811960528 [label=AddmmBackward0] + 140210811960912 -> 140210811960528 + 140210811960912 [label=ToCopyBackward0] + 140210811976048 -> 140210811960912 + 140202229025184 [label="encoder.layer.3.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202229025184 -> 140210811976048 + 140210811976048 [label=AccumulateGrad] + 140210811960720 -> 140210811960528 + 140210811960720 [label=ViewBackward0] + 140210812094544 -> 140210811960720 + 140210812094544 [label=ToCopyBackward0] + 140210811957600 -> 140210812094544 + 140210811957600 [label=SliceBackward0] + 140210812094592 -> 140210811957600 + 140210812094592 [label=SliceBackward0] + 140210812094688 -> 140210812094592 + 140210812094688 [label=SliceBackward0] + 140210811958800 -> 140210812094688 + 140210811959472 -> 140210811960528 + 140210811959472 [label=TBackward0] + 140210812094304 -> 140210811959472 + 140210812094304 [label=ToCopyBackward0] + 140210812094784 -> 140210812094304 + 140202229025504 [label="encoder.layer.3.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202229025504 -> 140210812094784 + 140210812094784 [label=AccumulateGrad] + 140210811958752 -> 140210811958320 + 140210811958752 [label=TBackward0] + 140210811959712 -> 140210811958752 + 140210811959712 [label=ToCopyBackward0] + 140210811960240 -> 140210811959712 + 140202229025264 [label="encoder.layer.3.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202229025264 -> 140210811960240 + 140210811960240 [label=AccumulateGrad] + 140210811957600 -> 140210811944864 + 140210811944720 -> 140210811944240 + 140202229025024 [label="encoder.layer.3.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229025024 -> 140210811944720 + 140210811944720 [label=AccumulateGrad] + 140210811944672 -> 140210811944240 + 140202229024704 [label="encoder.layer.3.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229024704 -> 140210811944672 + 140210811944672 [label=AccumulateGrad] + 140210811943520 -> 140210811944000 + 140210811943520 [label=TBackward0] + 140210811944192 -> 140210811943520 + 140210811944192 [label=ToCopyBackward0] + 140210811944384 -> 140210811944192 + 140202229023344 [label="encoder.layer.4.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202229023344 -> 140210811944384 + 140210811944384 [label=AccumulateGrad] + 140210811943424 -> 140210811943376 + 140210811943424 [label=ReshapeAliasBackward0] + 140210811943760 -> 140210811943424 + 140210811943760 [label=ExpandBackward0] + 140210811943952 -> 140210811943760 + 140210811943952 [label=TransposeBackward0] + 140210811944432 -> 140210811943952 + 140210811944432 [label=PermuteBackward0] + 140210811943568 -> 140210811944432 + 140210811943568 [label=ViewBackward0] + 140210811957360 -> 140210811943568 + 140210811957360 [label=ViewBackward0] + 140210811959280 -> 140210811957360 + 140210811959280 [label=AddmmBackward0] + 140210811960336 -> 140210811959280 + 140210811960336 [label=ToCopyBackward0] + 140210812094496 -> 140210811960336 + 140202229022784 [label="encoder.layer.4.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202229022784 -> 140210812094496 + 140210812094496 [label=AccumulateGrad] + 140210811957552 -> 140210811959280 + 140210811957552 [label=ViewBackward0] + 140210812094832 -> 140210811957552 + 140210812094832 [label=ToCopyBackward0] + 140210811941456 -> 140210812094832 + 140210812094400 -> 140210811959280 + 140210812094400 [label=TBackward0] + 140210812094448 -> 140210812094400 + 140210812094448 [label=ToCopyBackward0] + 140210812094976 -> 140210812094448 + 140202229023104 [label="encoder.layer.4.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202229023104 -> 140210812094976 + 140210812094976 [label=AccumulateGrad] + 140210811942512 -> 140210811942464 + 140210811942512 [label=ReshapeAliasBackward0] + 140210811942848 -> 140210811942512 + 140210811942848 [label=ExpandBackward0] + 140210811943040 -> 140210811942848 + 140210811943040 [label=PermuteBackward0] + 140210811943232 -> 140210811943040 + 140210811943232 [label=ViewBackward0] + 140210811942608 -> 140210811943232 + 140210811942608 [label=ViewBackward0] + 140210811943856 -> 140210811942608 + 140210811943856 [label=AddmmBackward0] + 140210811944144 -> 140210811943856 + 140210811944144 [label=ToCopyBackward0] + 140210811959184 -> 140210811944144 + 140202229014256 [label="encoder.layer.4.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202229014256 -> 140210811959184 + 140210811959184 [label=AccumulateGrad] + 140210811942656 -> 140210811943856 + 140210811942656 [label=ViewBackward0] + 140210812094736 -> 140210811942656 + 140210812094736 [label=ToCopyBackward0] + 140210811941456 -> 140210812094736 + 140210811958656 -> 140210811943856 + 140210811958656 [label=TBackward0] + 140210812094640 -> 140210811958656 + 140210812094640 [label=ToCopyBackward0] + 140210812094880 -> 140210812094640 + 140202229022864 [label="encoder.layer.4.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202229022864 -> 140210812094880 + 140210812094880 [label=AccumulateGrad] + 140210811941552 -> 140210811941744 + 140210811941552 [label=TBackward0] + 140210811942224 -> 140210811941552 + 140210811942224 [label=ToCopyBackward0] + 140210811942416 -> 140210811942224 + 140202229014336 [label="encoder.layer.4.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229014336 -> 140210811942416 + 140210811942416 [label=AccumulateGrad] + 140210811941456 -> 140210811941312 + 140210811941264 -> 140210811941216 + 140202229014096 [label="encoder.layer.4.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229014096 -> 140210811941264 + 140210811941264 [label=AccumulateGrad] + 140210811940976 -> 140210811941216 + 140202229013776 [label="encoder.layer.4.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229013776 -> 140210811940976 + 140210811940976 [label=AccumulateGrad] + 140210811927712 -> 140210811928192 + 140210811927712 [label=TBackward0] + 140210811928432 -> 140210811927712 + 140210811928432 [label=ToCopyBackward0] + 140210811941168 -> 140210811928432 + 140202229013856 [label="encoder.layer.4.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202229013856 -> 140210811941168 + 140210811941168 [label=AccumulateGrad] + 140210811927616 -> 140210811927568 + 140210811927616 [label=ReshapeAliasBackward0] + 140210811927952 -> 140210811927616 + 140210811927952 [label=ExpandBackward0] + 140210811928144 -> 140210811927952 + 140210811928144 [label=TransposeBackward0] + 140210811928528 -> 140210811928144 + 140210811928528 [label=PermuteBackward0] + 140210811927760 -> 140210811928528 + 140210811927760 [label=ViewBackward0] + 140210811940928 -> 140210811927760 + 140210811940928 [label=ViewBackward0] + 140210811941696 -> 140210811940928 + 140210811941696 [label=AddmmBackward0] + 140210811941936 -> 140210811941696 + 140210811941936 [label=ToCopyBackward0] + 140210811942128 -> 140210811941936 + 140202229013296 [label="encoder.layer.4.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140202229013296 -> 140210811942128 + 140210811942128 [label=AccumulateGrad] + 140210811941888 -> 140210811941696 + 140210811941888 [label=ViewBackward0] + 140210811942944 -> 140210811941888 + 140210811942944 [label=ToCopyBackward0] + 140210812052960 -> 140210811942944 + 140210811941072 -> 140210811941696 + 140210811941072 [label=TBackward0] + 140210811942752 -> 140210811941072 + 140210811942752 [label=ToCopyBackward0] + 140210811943664 -> 140210811942752 + 140202229013616 [label="encoder.layer.4.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140202229013616 -> 140210811943664 + 140210811943664 [label=AccumulateGrad] + 140210811926704 -> 140210811926656 + 140210811926704 [label=ReshapeAliasBackward0] + 140210811927040 -> 140210811926704 + 140210811927040 [label=ExpandBackward0] + 140210811927232 -> 140210811927040 + 140210811927232 [label=PermuteBackward0] + 140210811927424 -> 140210811927232 + 140210811927424 [label=ViewBackward0] + 140210811926800 -> 140210811927424 + 140210811926800 [label=ViewBackward0] + 140210811928048 -> 140210811926800 + 140210811928048 [label=AddmmBackward0] + 140210811958464 -> 140210811928048 + 140210811958464 [label=ToCopyBackward0] + 140210811942320 -> 140210811958464 + 140202229013056 [label="encoder.layer.4.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140202229013056 -> 140210811942320 + 140210811942320 [label=AccumulateGrad] + 140210811928336 -> 140210811928048 + 140210811928336 [label=ViewBackward0] + 140210811943328 -> 140210811928336 + 140210811943328 [label=ToCopyBackward0] + 140210812052960 -> 140210811943328 + 140210811926848 -> 140210811928048 + 140210811926848 [label=TBackward0] + 140210811941360 -> 140210811926848 + 140210811941360 [label=ToCopyBackward0] + 140210811941600 -> 140210811941360 + 140202229013376 [label="encoder.layer.4.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140202229013376 -> 140210811941600 + 140210811941600 [label=AccumulateGrad] + 140210811925744 -> 140210811925936 + 140210811925744 [label=TBackward0] + 140210811926416 -> 140210811925744 + 140210811926416 [label=ToCopyBackward0] + 140210811926608 -> 140210811926416 + 140202229013136 [label="encoder.layer.4.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202229013136 -> 140210811926608 + 140210811926608 [label=AccumulateGrad] + 140210811925648 -> 140210811925504 + 140210811925456 -> 140210811925408 + 140202229012896 [label="encoder.layer.4.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229012896 -> 140210811925456 + 140210811925456 [label=AccumulateGrad] + 140210811925024 -> 140210811925408 + 140202229012576 [label="encoder.layer.4.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229012576 -> 140210811925024 + 140210811925024 [label=AccumulateGrad] + 140210811924592 -> 140210811924832 + 140210811924592 [label=TBackward0] + 140210811925072 -> 140210811924592 + 140210811925072 [label=ToCopyBackward0] + 140210811925552 -> 140210811925072 + 140202229011216 [label="encoder.layer.4.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202229011216 -> 140210811925552 + 140210811925552 [label=AccumulateGrad] + 140202224193104 -> 140202224193296 + 140202224193104 [label=TBackward0] + 140210811924784 -> 140202224193104 + 140210811924784 [label=ToCopyBackward0] + 140210811925264 -> 140210811924784 + 140202229010976 [label="encoder.layer.4.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202229010976 -> 140210811925264 + 140210811925264 [label=AccumulateGrad] + 140202224193008 -> 140202224192864 + 140202224192816 -> 140202224192720 + 140202229010736 [label="encoder.layer.4.expert_ln.weight + (768)" fillcolor=lightblue] + 140202229010736 -> 140202224192816 + 140202224192816 [label=AccumulateGrad] + 140202224192768 -> 140202224192720 + 140202229010496 [label="encoder.layer.4.expert_ln.bias + (768)" fillcolor=lightblue] + 140202229010496 -> 140202224192768 + 140202224192768 [label=AccumulateGrad] + 140202224192432 -> 140202224191472 + 140202224192432 [label=NativeLayerNormBackward0] + 140202224193152 -> 140202224192432 + 140202224193152 [label=AddBackward0] + 140202224193440 -> 140202224193152 + 140202224193440 [label=NativeDropoutBackward0] + 140210811924688 -> 140202224193440 + 140210811924688 [label=ViewBackward0] + 140210811925216 -> 140210811924688 + 140210811925216 [label=AddmmBackward0] + 140210811926080 -> 140210811925216 + 140210811926080 [label=ToCopyBackward0] + 140210811926176 -> 140210811926080 + 140202229012096 [label="encoder.layer.4.output.dense.bias + (768)" fillcolor=lightblue] + 140202229012096 -> 140210811926176 + 140210811926176 [label=AccumulateGrad] + 140210811925888 -> 140210811925216 + 140210811925888 [label=ViewBackward0] + 140210811926320 -> 140210811925888 + 140210811926320 [label=GeluBackward0] + 140210811927328 -> 140210811926320 + 140210811927328 [label=ViewBackward0] + 140210811927856 -> 140210811927328 + 140210811927856 [label=AddmmBackward0] + 140210811926944 -> 140210811927856 + 140210811926944 [label=ToCopyBackward0] + 140210812095024 -> 140210811926944 + 140202229012336 [label="encoder.layer.4.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202229012336 -> 140210812095024 + 140210812095024 [label=AccumulateGrad] + 140210811943136 -> 140210811927856 + 140210811943136 [label=ViewBackward0] + 140210812095120 -> 140210811943136 + 140210812095120 [label=ToCopyBackward0] + 140210811924976 -> 140210812095120 + 140210811924976 [label=SliceBackward0] + 140210812095264 -> 140210811924976 + 140210812095264 [label=SliceBackward0] + 140210812095360 -> 140210812095264 + 140210812095360 [label=SliceBackward0] + 140210811941216 -> 140210812095360 + 140210811941408 -> 140210811927856 + 140210811941408 [label=TBackward0] + 140210812094928 -> 140210811941408 + 140210812094928 [label=ToCopyBackward0] + 140210812095456 -> 140210812094928 + 140202229012656 [label="encoder.layer.4.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202229012656 -> 140210812095456 + 140210812095456 [label=AccumulateGrad] + 140210811925792 -> 140210811925216 + 140210811925792 [label=TBackward0] + 140210811927520 -> 140210811925792 + 140210811927520 [label=ToCopyBackward0] + 140210811941984 -> 140210811927520 + 140202229012416 [label="encoder.layer.4.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202229012416 -> 140210811941984 + 140210811941984 [label=AccumulateGrad] + 140210811924976 -> 140202224193152 + 140202224192960 -> 140202224192432 + 140202229012176 [label="encoder.layer.4.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202229012176 -> 140202224192960 + 140202224192960 [label=AccumulateGrad] + 140202224192912 -> 140202224192432 + 140202229011856 [label="encoder.layer.4.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202229011856 -> 140202224192912 + 140202224192912 [label=AccumulateGrad] + 140202224191712 -> 140202224192192 + 140202224191712 [label=TBackward0] + 140202224192528 -> 140202224191712 + 140202224192528 [label=ToCopyBackward0] + 140202224193248 -> 140202224192528 + 140202228989840 [label="encoder.layer.5.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228989840 -> 140202224193248 + 140202224193248 [label=AccumulateGrad] + 140202224189600 -> 140202224189552 + 140202224189600 [label=ReshapeAliasBackward0] + 140202224191952 -> 140202224189600 + 140202224191952 [label=ExpandBackward0] + 140202224192144 -> 140202224191952 + 140202224192144 [label=TransposeBackward0] + 140202224192672 -> 140202224192144 + 140202224192672 [label=PermuteBackward0] + 140202224192624 -> 140202224192672 + 140202224192624 [label=ViewBackward0] + 140210811924544 -> 140202224192624 + 140210811924544 [label=ViewBackward0] + 140210811926512 -> 140210811924544 + 140210811926512 [label=AddmmBackward0] + 140210811927136 -> 140210811926512 + 140210811927136 [label=ToCopyBackward0] + 140210812095072 -> 140210811927136 + 140202228989360 [label="encoder.layer.5.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202228989360 -> 140210812095072 + 140210812095072 [label=AccumulateGrad] + 140210811925600 -> 140210811926512 + 140210811925600 [label=ViewBackward0] + 140210812095504 -> 140210811925600 + 140210812095504 [label=ToCopyBackward0] + 140202224191472 -> 140210812095504 + 140210812095168 -> 140210811926512 + 140210812095168 [label=TBackward0] + 140210812095216 -> 140210812095168 + 140210812095216 [label=ToCopyBackward0] + 140210812095648 -> 140210812095216 + 140202228989680 [label="encoder.layer.5.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228989680 -> 140210812095648 + 140210812095648 [label=AccumulateGrad] + 140202224190416 -> 140202224190560 + 140202224190416 [label=ReshapeAliasBackward0] + 140202224190176 -> 140202224190416 + 140202224190176 [label=ExpandBackward0] + 140202224189984 -> 140202224190176 + 140202224189984 [label=PermuteBackward0] + 140202224189792 -> 140202224189984 + 140202224189792 [label=ViewBackward0] + 140202224190320 -> 140202224189792 + 140202224190320 [label=ViewBackward0] + 140202224192048 -> 140202224190320 + 140202224192048 [label=AddmmBackward0] + 140202224191760 -> 140202224192048 + 140202224191760 [label=ToCopyBackward0] + 140210811926128 -> 140202224191760 + 140202228989120 [label="encoder.layer.5.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202228989120 -> 140210811926128 + 140210811926128 [label=AccumulateGrad] + 140202224192336 -> 140202224192048 + 140202224192336 [label=ViewBackward0] + 140210812095408 -> 140202224192336 + 140210812095408 [label=ToCopyBackward0] + 140202224191472 -> 140210812095408 + 140202224190368 -> 140202224192048 + 140202224190368 [label=TBackward0] + 140210812095312 -> 140202224190368 + 140210812095312 [label=ToCopyBackward0] + 140210812095552 -> 140210812095312 + 140202228989440 [label="encoder.layer.5.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202228989440 -> 140210812095552 + 140210812095552 [label=AccumulateGrad] + 140202224191376 -> 140202224191184 + 140202224191376 [label=TBackward0] + 140202224190704 -> 140202224191376 + 140202224190704 [label=ToCopyBackward0] + 140202224190512 -> 140202224190704 + 140202228989200 [label="encoder.layer.5.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228989200 -> 140202224190512 + 140202224190512 [label=AccumulateGrad] + 140202224191472 -> 140202222987584 + 140202222987152 -> 140202222988352 + 140202228988960 [label="encoder.layer.5.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228988960 -> 140202222987152 + 140202222987152 [label=AccumulateGrad] + 140202222987104 -> 140202222988352 + 140202228988640 [label="encoder.layer.5.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228988640 -> 140202222987104 + 140202222987104 [label=AccumulateGrad] + 140202222986192 -> 140202222986672 + 140202222986192 [label=TBackward0] + 140202222987200 -> 140202222986192 + 140202222987200 [label=ToCopyBackward0] + 140202222988544 -> 140202222987200 + 140202228987280 [label="encoder.layer.5.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228987280 -> 140202222988544 + 140202222988544 [label=AccumulateGrad] + 140202222985760 -> 140202222986048 + 140202222985760 [label=TBackward0] + 140202222986720 -> 140202222985760 + 140202222986720 [label=ToCopyBackward0] + 140202222987488 -> 140202222986720 + 140202228987040 [label="encoder.layer.5.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228987040 -> 140202222987488 + 140202222987488 [label=AccumulateGrad] + 140202222985568 -> 140202222985280 + 140202222988592 -> 140202222988736 + 140202228986800 [label="encoder.layer.5.expert_ln.weight + (768)" fillcolor=lightblue] + 140202228986800 -> 140202222988592 + 140202222988592 [label=AccumulateGrad] + 140202222987392 -> 140202222988736 + 140202228986480 [label="encoder.layer.5.expert_ln.bias + (768)" fillcolor=lightblue] + 140202228986480 -> 140202222987392 + 140202222987392 [label=AccumulateGrad] + 140202222988160 -> 140202222935248 + 140202222988160 [label=NativeLayerNormBackward0] + 140202222985664 -> 140202222988160 + 140202222985664 [label=AddBackward0] + 140202222987008 -> 140202222985664 + 140202222987008 [label=NativeDropoutBackward0] + 140202222986528 -> 140202222987008 + 140202222986528 [label=ViewBackward0] + 140202222988448 -> 140202222986528 + 140202222988448 [label=AddmmBackward0] + 140202222988256 -> 140202222988448 + 140202222988256 [label=ToCopyBackward0] + 140202224190992 -> 140202222988256 + 140202228988160 [label="encoder.layer.5.output.dense.bias + (768)" fillcolor=lightblue] + 140202228988160 -> 140202224190992 + 140202224190992 [label=AccumulateGrad] + 140202224191616 -> 140202222988448 + 140202224191616 [label=ViewBackward0] + 140202224190608 -> 140202224191616 + 140202224190608 [label=GeluBackward0] + 140202224191040 -> 140202224190608 + 140202224191040 [label=ViewBackward0] + 140202224190080 -> 140202224191040 + 140202224190080 [label=AddmmBackward0] + 140202224189696 -> 140202224190080 + 140202224189696 [label=ToCopyBackward0] + 140210811925360 -> 140202224189696 + 140202228988400 [label="encoder.layer.5.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202228988400 -> 140210811925360 + 140210811925360 [label=AccumulateGrad] + 140202224189888 -> 140202224190080 + 140202224189888 [label=ViewBackward0] + 140210812095840 -> 140202224189888 + 140210812095840 [label=ToCopyBackward0] + 140202222986336 -> 140210812095840 + 140202222986336 [label=SliceBackward0] + 140210812095888 -> 140202222986336 + 140210812095888 [label=SliceBackward0] + 140210812095984 -> 140210812095888 + 140210812095984 [label=SliceBackward0] + 140202222988352 -> 140210812095984 + 140202224191136 -> 140202224190080 + 140202224191136 [label=TBackward0] + 140210812095600 -> 140202224191136 + 140210812095600 [label=ToCopyBackward0] + 140210812096080 -> 140210812095600 + 140202228988720 [label="encoder.layer.5.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202228988720 -> 140210812096080 + 140210812096080 [label=AccumulateGrad] + 140202224191568 -> 140202222988448 + 140202224191568 [label=TBackward0] + 140202224190800 -> 140202224191568 + 140202224190800 [label=ToCopyBackward0] + 140202224191856 -> 140202224190800 + 140202228988480 [label="encoder.layer.5.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202228988480 -> 140202224191856 + 140202224191856 [label=AccumulateGrad] + 140202222986336 -> 140202222985664 + 140202222985376 -> 140202222988160 + 140202228988240 [label="encoder.layer.5.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228988240 -> 140202222985376 + 140202222985376 [label=AccumulateGrad] + 140202222985328 -> 140202222988160 + 140202228987920 [label="encoder.layer.5.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228987920 -> 140202222985328 + 140202222985328 [label=AccumulateGrad] + 140202222963584 -> 140202222964352 + 140202222963584 [label=TBackward0] + 140202222988928 -> 140202222963584 + 140202222988928 [label=ToCopyBackward0] + 140202222986144 -> 140202222988928 + 140202228986560 [label="encoder.layer.6.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228986560 -> 140202222986144 + 140202222986144 [label=AccumulateGrad] + 140202222963392 -> 140202222963200 + 140202222963392 [label=ReshapeAliasBackward0] + 140202222963728 -> 140202222963392 + 140202222963728 [label=ExpandBackward0] + 140202222964160 -> 140202222963728 + 140202222964160 [label=TransposeBackward0] + 140202222964448 -> 140202222964160 + 140202222964448 [label=PermuteBackward0] + 140202222989120 -> 140202222964448 + 140202222989120 [label=ViewBackward0] + 140202222988112 -> 140202222989120 + 140202222988112 [label=ViewBackward0] + 140202222988640 -> 140202222988112 + 140202222988640 [label=AddmmBackward0] + 140202224190272 -> 140202222988640 + 140202224190272 [label=ToCopyBackward0] + 140210812095792 -> 140202224190272 + 140202228986000 [label="encoder.layer.6.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202228986000 -> 140210812095792 + 140210812095792 [label=AccumulateGrad] + 140202224191424 -> 140202222988640 + 140202224191424 [label=ViewBackward0] + 140210812096128 -> 140202224191424 + 140210812096128 [label=ToCopyBackward0] + 140202222935248 -> 140210812096128 + 140210812095696 -> 140202222988640 + 140210812095696 [label=TBackward0] + 140210812095744 -> 140210812095696 + 140210812095744 [label=ToCopyBackward0] + 140210812096272 -> 140210812095744 + 140202228986320 [label="encoder.layer.6.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228986320 -> 140210812096272 + 140210812096272 [label=AccumulateGrad] + 140202222961760 -> 140202222961856 + 140202222961760 [label=ReshapeAliasBackward0] + 140202222962432 -> 140202222961760 + 140202222962432 [label=ExpandBackward0] + 140202222962816 -> 140202222962432 + 140202222962816 [label=PermuteBackward0] + 140202222963104 -> 140202222962816 + 140202222963104 [label=ViewBackward0] + 140202222961808 -> 140202222963104 + 140202222961808 [label=ViewBackward0] + 140202222963968 -> 140202222961808 + 140202222963968 [label=AddmmBackward0] + 140202222963488 -> 140202222963968 + 140202222963488 [label=ToCopyBackward0] + 140202224191328 -> 140202222963488 + 140202228985664 [label="encoder.layer.6.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202228985664 -> 140202224191328 + 140202224191328 [label=AccumulateGrad] + 140202222962144 -> 140202222963968 + 140202222962144 [label=ViewBackward0] + 140210812096032 -> 140202222962144 + 140210812096032 [label=ToCopyBackward0] + 140202222935248 -> 140210812096032 + 140202222985712 -> 140202222963968 + 140202222985712 [label=TBackward0] + 140210812095936 -> 140202222985712 + 140210812095936 [label=ToCopyBackward0] + 140210812096176 -> 140210812095936 + 140202228986080 [label="encoder.layer.6.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202228986080 -> 140210812096176 + 140210812096176 [label=AccumulateGrad] + 140202222960704 -> 140202222935728 + 140202222960704 [label=TBackward0] + 140202222961280 -> 140202222960704 + 140202222961280 [label=ToCopyBackward0] + 140202222961568 -> 140202222961280 + 140202228985744 [label="encoder.layer.6.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228985744 -> 140202222961568 + 140202222961568 [label=AccumulateGrad] + 140202222935248 -> 140202222935296 + 140202222935008 -> 140202222935104 + 140202228985504 [label="encoder.layer.6.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228985504 -> 140202222935008 + 140202222935008 [label=AccumulateGrad] + 140202222934336 -> 140202222935104 + 140202228985184 [label="encoder.layer.6.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228985184 -> 140202222934336 + 140202222934336 [label=AccumulateGrad] + 140202222933184 -> 140202222933952 + 140202222933184 [label=TBackward0] + 140202222934240 -> 140202222933184 + 140202222934240 [label=ToCopyBackward0] + 140202222934768 -> 140202222934240 + 140202228985264 [label="encoder.layer.6.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228985264 -> 140202222934768 + 140202222934768 [label=AccumulateGrad] + 140202222932992 -> 140202222932800 + 140202222932992 [label=ReshapeAliasBackward0] + 140202222933328 -> 140202222932992 + 140202222933328 [label=ExpandBackward0] + 140202222933760 -> 140202222933328 + 140202222933760 [label=TransposeBackward0] + 140202222934528 -> 140202222933760 + 140202222934528 [label=PermuteBackward0] + 140202222935392 -> 140202222934528 + 140202222935392 [label=ViewBackward0] + 140202222934624 -> 140202222935392 + 140202222934624 [label=ViewBackward0] + 140202222935872 -> 140202222934624 + 140202222935872 [label=AddmmBackward0] + 140202222933088 -> 140202222935872 + 140202222933088 [label=ToCopyBackward0] + 140202222961088 -> 140202222933088 + 140202228984704 [label="encoder.layer.6.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140202228984704 -> 140202222961088 + 140202222961088 [label=AccumulateGrad] + 140202222960800 -> 140202222935872 + 140202222960800 [label=ViewBackward0] + 140202222962624 -> 140202222960800 + 140202222962624 [label=ToCopyBackward0] + 140210812052960 -> 140202222962624 + 140202222960896 -> 140202222935872 + 140202222960896 [label=TBackward0] + 140202222962336 -> 140202222960896 + 140202222962336 [label=ToCopyBackward0] + 140202222963680 -> 140202222962336 + 140202228985024 [label="encoder.layer.6.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140202228985024 -> 140202222963680 + 140202222963680 [label=AccumulateGrad] + 140202222906720 -> 140202222906816 + 140202222906720 [label=ReshapeAliasBackward0] + 140202222906768 -> 140202222906720 + 140202222906768 [label=ExpandBackward0] + 140202222907104 -> 140202222906768 + 140202222907104 [label=PermuteBackward0] + 140202222932704 -> 140202222907104 + 140202222932704 [label=ViewBackward0] + 140202222932032 -> 140202222932704 + 140202222932032 [label=ViewBackward0] + 140202222933568 -> 140202222932032 + 140202222933568 [label=AddmmBackward0] + 140202222934720 -> 140202222933568 + 140202222934720 [label=ToCopyBackward0] + 140202222987776 -> 140202222934720 + 140202228984464 [label="encoder.layer.6.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140202228984464 -> 140202222987776 + 140202222987776 [label=AccumulateGrad] + 140202222934048 -> 140202222933568 + 140202222934048 [label=ViewBackward0] + 140202222935776 -> 140202222934048 + 140202222935776 [label=ToCopyBackward0] + 140210812052960 -> 140202222935776 + 140202222932224 -> 140202222933568 + 140202222932224 [label=TBackward0] + 140202222961328 -> 140202222932224 + 140202222961328 [label=ToCopyBackward0] + 140202222960992 -> 140202222961328 + 140202228984784 [label="encoder.layer.6.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140202228984784 -> 140202222960992 + 140202222960992 [label=AccumulateGrad] + 140202222905088 -> 140202222905328 + 140202222905088 [label=TBackward0] + 140202222906240 -> 140202222905088 + 140202222906240 [label=ToCopyBackward0] + 140202222906528 -> 140202222906240 + 140202228984544 [label="encoder.layer.6.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228984544 -> 140202222906528 + 140202222906528 [label=AccumulateGrad] + 140202222904848 -> 140202222904896 + 140202222904704 -> 140202222904512 + 140202228984304 [label="encoder.layer.6.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228984304 -> 140202222904704 + 140202222904704 [label=AccumulateGrad] + 140202222903936 -> 140202222904512 + 140202228983984 [label="encoder.layer.6.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228983984 -> 140202222903936 + 140202222903936 [label=AccumulateGrad] + 140202222903408 -> 140202222903552 + 140202222903408 [label=TBackward0] + 140202222904128 -> 140202222903408 + 140202222904128 [label=ToCopyBackward0] + 140202222904800 -> 140202222904128 + 140202228968480 [label="encoder.layer.6.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228968480 -> 140202222904800 + 140202222904800 [label=AccumulateGrad] + 140202222873664 -> 140202222874144 + 140202222873664 [label=TBackward0] + 140202222874336 -> 140202222873664 + 140202222874336 [label=ToCopyBackward0] + 140202222904368 -> 140202222874336 + 140202228968560 [label="encoder.layer.6.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228968560 -> 140202222904368 + 140202222904368 [label=AccumulateGrad] + 140202222873280 -> 140202222873232 + 140202222873280 [label=UnsqueezeBackward0] + 140202222873856 -> 140202222873280 + 140202222873856 [label=NativeDropoutBackward0] + 140202222874240 -> 140202222873856 + 140202222874240 [label=ViewBackward0] + 140202222905376 -> 140202222874240 + 140202222905376 [label=AddmmBackward0] + 140202222903360 -> 140202222905376 + 140202222903360 [label=ToCopyBackward0] + 140202222905856 -> 140202222903360 + 140202228968240 [label="encoder.layer.6.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140202228968240 -> 140202222905856 + 140202222905856 [label=AccumulateGrad] + 140202222904608 -> 140202222905376 + 140202222904608 [label=ViewBackward0] + 140202222905760 -> 140202222904608 + 140202222905760 [label=GeluBackward0] + 140202222907296 -> 140202222905760 + 140202222907296 [label=ViewBackward0] + 140202222906048 -> 140202222907296 + 140202222906048 [label=AddmmBackward0] + 140202222905472 -> 140202222906048 + 140202222905472 [label=ToCopyBackward0] + 140202222935200 -> 140202222905472 + 140202228969040 [label="encoder.layer.6.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140202228969040 -> 140202222935200 + 140202222935200 [label=AccumulateGrad] + 140202222932512 -> 140202222906048 + 140202222932512 [label=ViewBackward0] + 140202222962912 -> 140202222932512 + 140202222962912 [label=ToCopyBackward0] + 140202222872416 -> 140202222962912 + 140202222932416 -> 140202222906048 + 140202222932416 [label=TBackward0] + 140202222933280 -> 140202222932416 + 140202222933280 [label=ToCopyBackward0] + 140210812096320 -> 140202222933280 + 140202228968320 [label="encoder.layer.6.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228968320 -> 140210812096320 + 140210812096320 [label=AccumulateGrad] + 140202222903648 -> 140202222905376 + 140202222903648 [label=TBackward0] + 140202222905952 -> 140202222903648 + 140202222905952 [label=ToCopyBackward0] + 140202222963296 -> 140202222905952 + 140202228968080 [label="encoder.layer.6.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228968080 -> 140202222963296 + 140202222963296 [label=AccumulateGrad] + 140202222873184 -> 140202222873232 + 140202222873184 [label=UnsqueezeBackward0] + 140202222932896 -> 140202222873184 + 140202222932896 [label=NativeDropoutBackward0] + 140202222873760 -> 140202222932896 + 140202222873760 [label=ViewBackward0] + 140202222906288 -> 140202222873760 + 140202222906288 [label=AddmmBackward0] + 140202222903888 -> 140202222906288 + 140202222903888 [label=ToCopyBackward0] + 140210812096224 -> 140202222903888 + 140202228967760 [label="encoder.layer.6.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140202228967760 -> 140210812096224 + 140210812096224 [label=AccumulateGrad] + 140210812096464 -> 140202222906288 + 140210812096464 [label=ViewBackward0] + 140210811723936 -> 140210812096464 + 140210811723936 [label=GeluBackward0] + 140210811724032 -> 140210811723936 + 140210811724032 [label=ViewBackward0] + 140210811724128 -> 140210811724032 + 140210811724128 [label=AddmmBackward0] + 140210811724224 -> 140210811724128 + 140210811724224 [label=ToCopyBackward0] + 140210811724416 -> 140210811724224 + 140202228968000 [label="encoder.layer.6.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140202228968000 -> 140210811724416 + 140210811724416 [label=AccumulateGrad] + 140210811724176 -> 140210811724128 + 140210811724176 [label=ViewBackward0] + 140210811724464 -> 140210811724176 + 140210811724464 [label=ToCopyBackward0] + 140202222872416 -> 140210811724464 + 140210811723888 -> 140210811724128 + 140210811723888 [label=TBackward0] + 140210811724320 -> 140210811723888 + 140210811724320 [label=ToCopyBackward0] + 140210811724608 -> 140210811724320 + 140202228967840 [label="encoder.layer.6.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228967840 -> 140210811724608 + 140210811724608 [label=AccumulateGrad] + 140210812096368 -> 140202222906288 + 140210812096368 [label=TBackward0] + 140210811724080 -> 140210812096368 + 140210811724080 [label=ToCopyBackward0] + 140210811724560 -> 140210811724080 + 140202228967600 [label="encoder.layer.6.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228967600 -> 140210811724560 + 140210811724560 [label=AccumulateGrad] + 140202222872992 -> 140202222872800 + 140202222872992 [label=UnsqueezeBackward0] + 140202222873568 -> 140202222872992 + 140202222873568 [label=UnsqueezeBackward0] + 140202222904224 -> 140202222873568 + 140202222904224 [label=SumBackward1] + 140210812096416 -> 140202222904224 + 140210812096416 [label=MulBackward0] + 140210811724704 -> 140210812096416 + 140210811724704 [label=UnsqueezeBackward0] + 140210811723984 -> 140210811724704 + 140210811723984 [label=TopkBackward0] + 140210811724512 -> 140210811723984 + 140210811724512 [label=SoftmaxBackward0] + 140210811724800 -> 140210811724512 + 140210811724800 [label=MmBackward0] + 140210811724896 -> 140210811724800 + 140210811724896 [label=ToCopyBackward0] + 140210811725040 -> 140210811724896 + 140210811725040 [label=MeanBackward1] + 140210811725136 -> 140210811725040 + 140210811725136 [label=MulBackward0] + 140202222872416 -> 140210811725136 + 140210811724848 -> 140210811724800 + 140210811724848 [label=TBackward0] + 140210811725232 -> 140210811724848 + 140210811725232 [label=ToCopyBackward0] + 140210811724944 -> 140210811725232 + 140202228981824 [label="encoder.layer.6.experts.gate.weight + (3, 768)" fillcolor=lightblue] + 140202228981824 -> 140210811724944 + 140210811724944 [label=AccumulateGrad] + 140202222872416 -> 140202222872272 + 140202222872128 -> 140202222871936 + 140202228982144 [label="encoder.layer.6.expert_ln.weight + (768)" fillcolor=lightblue] + 140202228982144 -> 140202222872128 + 140202222872128 [label=AccumulateGrad] + 140202222872224 -> 140202222871936 + 140202228981904 [label="encoder.layer.6.expert_ln.bias + (768)" fillcolor=lightblue] + 140202228981904 -> 140202222872224 + 140202222872224 [label=AccumulateGrad] + 140202222871744 -> 140202222842352 + 140202222871744 [label=NativeLayerNormBackward0] + 140202222904992 -> 140202222871744 + 140202222904992 [label=AddBackward0] + 140202222873088 -> 140202222904992 + 140202222873088 [label=NativeDropoutBackward0] + 140210811724656 -> 140202222873088 + 140210811724656 [label=ViewBackward0] + 140210811723840 -> 140210811724656 + 140210811723840 [label=AddmmBackward0] + 140210811725184 -> 140210811723840 + 140210811725184 [label=ToCopyBackward0] + 140210811725376 -> 140210811725184 + 140202228983504 [label="encoder.layer.6.output.dense.bias + (768)" fillcolor=lightblue] + 140202228983504 -> 140210811725376 + 140210811725376 [label=AccumulateGrad] + 140210811725088 -> 140210811723840 + 140210811725088 [label=ViewBackward0] + 140210811725424 -> 140210811725088 + 140210811725424 [label=GeluBackward0] + 140210811725520 -> 140210811725424 + 140210811725520 [label=ViewBackward0] + 140210811725616 -> 140210811725520 + 140210811725616 [label=AddmmBackward0] + 140210811725712 -> 140210811725616 + 140210811725712 [label=ToCopyBackward0] + 140210811725904 -> 140210811725712 + 140202228983744 [label="encoder.layer.6.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202228983744 -> 140210811725904 + 140210811725904 [label=AccumulateGrad] + 140210811725664 -> 140210811725616 + 140210811725664 [label=ViewBackward0] + 140210811725952 -> 140210811725664 + 140210811725952 [label=ToCopyBackward0] + 140202222873376 -> 140210811725952 + 140202222873376 [label=SliceBackward0] + 140210811726096 -> 140202222873376 + 140210811726096 [label=SliceBackward0] + 140210811726192 -> 140210811726096 + 140210811726192 [label=SliceBackward0] + 140202222935104 -> 140210811726192 + 140210811724992 -> 140210811725616 + 140210811724992 [label=TBackward0] + 140210811725856 -> 140210811724992 + 140210811725856 [label=ToCopyBackward0] + 140210811726288 -> 140210811725856 + 140202228984064 [label="encoder.layer.6.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202228984064 -> 140210811726288 + 140210811726288 [label=AccumulateGrad] + 140210811724272 -> 140210811723840 + 140210811724272 [label=TBackward0] + 140210811725568 -> 140210811724272 + 140210811725568 [label=ToCopyBackward0] + 140210811726048 -> 140210811725568 + 140202228983824 [label="encoder.layer.6.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202228983824 -> 140210811726048 + 140210811726048 [label=AccumulateGrad] + 140202222873376 -> 140202222904992 + 140202222872512 -> 140202222871744 + 140202228983584 [label="encoder.layer.6.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228983584 -> 140202222872512 + 140202222872512 [label=AccumulateGrad] + 140202222872320 -> 140202222871744 + 140202228983264 [label="encoder.layer.6.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228983264 -> 140202222872320 + 140202222872320 [label=AccumulateGrad] + 140202222870592 -> 140202222871168 + 140202222870592 [label=TBackward0] + 140202222871456 -> 140202222870592 + 140202222871456 [label=ToCopyBackward0] + 140202222872608 -> 140202222871456 + 140202228982384 [label="encoder.layer.7.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228982384 -> 140202222872608 + 140202222872608 [label=AccumulateGrad] + 140202222845568 -> 140202222845664 + 140202222845568 [label=ReshapeAliasBackward0] + 140202222845760 -> 140202222845568 + 140202222845760 [label=ExpandBackward0] + 140202222871264 -> 140202222845760 + 140202222871264 [label=TransposeBackward0] + 140202222872032 -> 140202222871264 + 140202222872032 [label=PermuteBackward0] + 140202222871840 -> 140202222872032 + 140202222871840 [label=ViewBackward0] + 140202222870784 -> 140202222871840 + 140202222870784 [label=ViewBackward0] + 140210811725280 -> 140202222870784 + 140210811725280 [label=AddmmBackward0] + 140210811725808 -> 140210811725280 + 140210811725808 [label=ToCopyBackward0] + 140210811726000 -> 140210811725808 + 140203184706720 [label="encoder.layer.7.attention.self.key.bias + (768)" fillcolor=lightblue] + 140203184706720 -> 140210811726000 + 140210811726000 [label=AccumulateGrad] + 140210811725760 -> 140210811725280 + 140210811725760 [label=ViewBackward0] + 140210811726336 -> 140210811725760 + 140210811726336 [label=ToCopyBackward0] + 140202222842352 -> 140210811726336 + 140210811724752 -> 140210811725280 + 140210811724752 [label=TBackward0] + 140210811725472 -> 140210811724752 + 140210811725472 [label=ToCopyBackward0] + 140210811726480 -> 140210811725472 + 140202228982624 [label="encoder.layer.7.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228982624 -> 140210811726480 + 140210811726480 [label=AccumulateGrad] + 140202222844224 -> 140202222843936 + 140202222844224 [label=ReshapeAliasBackward0] + 140202222844608 -> 140202222844224 + 140202222844608 [label=ExpandBackward0] + 140202222844896 -> 140202222844608 + 140202222844896 [label=PermuteBackward0] + 140202222845280 -> 140202222844896 + 140202222845280 [label=ViewBackward0] + 140202222844272 -> 140202222845280 + 140202222844272 [label=ViewBackward0] + 140202222844320 -> 140202222844272 + 140202222844320 [label=AddmmBackward0] + 140202222872752 -> 140202222844320 + 140202222872752 [label=ToCopyBackward0] + 140210811726432 -> 140202222872752 + 140202228969200 [label="encoder.layer.7.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202228969200 -> 140210811726432 + 140210811726432 [label=AccumulateGrad] + 140202222871552 -> 140202222844320 + 140202222871552 [label=ViewBackward0] + 140210811726240 -> 140202222871552 + 140210811726240 [label=ToCopyBackward0] + 140202222842352 -> 140210811726240 + 140202222870832 -> 140202222844320 + 140202222870832 [label=TBackward0] + 140210811725328 -> 140202222870832 + 140210811725328 [label=ToCopyBackward0] + 140210811726384 -> 140210811725328 + 140202228969280 [label="encoder.layer.7.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202228969280 -> 140210811726384 + 140210811726384 [label=AccumulateGrad] + 140202222842592 -> 140202222842832 + 140202222842592 [label=TBackward0] + 140202222843744 -> 140202222842592 + 140202222843744 [label=ToCopyBackward0] + 140202222844032 -> 140202222843744 + 140202228968960 [label="encoder.layer.7.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228968960 -> 140202222844032 + 140202222844032 [label=AccumulateGrad] + 140202222842352 -> 140202222841968 + 140202222842112 -> 140202222820224 + 140202228967520 [label="encoder.layer.7.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228967520 -> 140202222842112 + 140202222842112 [label=AccumulateGrad] + 140202222842016 -> 140202222820224 + 140202228967280 [label="encoder.layer.7.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228967280 -> 140202222842016 + 140202222842016 [label=AccumulateGrad] + 140202222819456 -> 140202222819936 + 140202222819456 [label=TBackward0] + 140202222820368 -> 140202222819456 + 140202222820368 [label=ToCopyBackward0] + 140202222821088 -> 140202222820368 + 140202228965600 [label="encoder.layer.7.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228965600 -> 140202222821088 + 140202222821088 [label=AccumulateGrad] + 140202222818880 -> 140202222819168 + 140202222818880 [label=TBackward0] + 140202222819888 -> 140202222818880 + 140202222819888 [label=ToCopyBackward0] + 140202222820800 -> 140202222819888 + 140202228965440 [label="encoder.layer.7.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228965440 -> 140202222820800 + 140202222820800 [label=AccumulateGrad] + 140202222818688 -> 140202222818304 + 140202222818400 -> 140202222818208 + 140202228952736 [label="encoder.layer.7.expert_ln.weight + (768)" fillcolor=lightblue] + 140202228952736 -> 140202222818400 + 140202222818400 [label=AccumulateGrad] + 140202222818112 -> 140202222818208 + 140202228952816 [label="encoder.layer.7.expert_ln.bias + (768)" fillcolor=lightblue] + 140202228952816 -> 140202222818112 + 140202222818112 [label=AccumulateGrad] + 140202222817632 -> 140202223288032 + 140202222817632 [label=NativeLayerNormBackward0] + 140202222818784 -> 140202222817632 + 140202222818784 [label=AddBackward0] + 140202222820320 -> 140202222818784 + 140202222820320 [label=NativeDropoutBackward0] + 140202222819840 -> 140202222820320 + 140202222819840 [label=ViewBackward0] + 140202222820512 -> 140202222819840 + 140202222820512 [label=AddmmBackward0] + 140202222842208 -> 140202222820512 + 140202222842208 [label=ToCopyBackward0] + 140202222843264 -> 140202222842208 + 140202228966880 [label="encoder.layer.7.output.dense.bias + (768)" fillcolor=lightblue] + 140202228966880 -> 140202222843264 + 140202222843264 [label=AccumulateGrad] + 140202222842304 -> 140202222820512 + 140202222842304 [label=ViewBackward0] + 140202222843792 -> 140202222842304 + 140202222843792 [label=GeluBackward0] + 140202222843168 -> 140202222843792 + 140202222843168 [label=ViewBackward0] + 140202222844800 -> 140202222843168 + 140202222844800 [label=AddmmBackward0] + 140202222845376 -> 140202222844800 + 140202222845376 [label=ToCopyBackward0] + 140210811726144 -> 140202222845376 + 140202228967120 [label="encoder.layer.7.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202228967120 -> 140210811726144 + 140210811726144 [label=AccumulateGrad] + 140202222845088 -> 140202222844800 + 140202222845088 [label=ViewBackward0] + 140210811726672 -> 140202222845088 + 140210811726672 [label=ToCopyBackward0] + 140202222819408 -> 140210811726672 + 140202222819408 [label=SliceBackward0] + 140210811726720 -> 140202222819408 + 140210811726720 [label=SliceBackward0] + 140210811726816 -> 140210811726720 + 140210811726816 [label=SliceBackward0] + 140202222820224 -> 140210811726816 + 140202222842976 -> 140202222844800 + 140202222842976 [label=TBackward0] + 140210811726528 -> 140202222842976 + 140210811726528 [label=ToCopyBackward0] + 140210811726912 -> 140210811726528 + 140202228967360 [label="encoder.layer.7.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202228967360 -> 140210811726912 + 140210811726912 [label=AccumulateGrad] + 140202222841920 -> 140202222820512 + 140202222841920 [label=TBackward0] + 140202222843552 -> 140202222841920 + 140202222843552 [label=ToCopyBackward0] + 140202222871072 -> 140202222843552 + 140202228966800 [label="encoder.layer.7.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202228966800 -> 140202222871072 + 140202222871072 [label=AccumulateGrad] + 140202222819408 -> 140202222818784 + 140202222818496 -> 140202222817632 + 140202228966560 [label="encoder.layer.7.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228966560 -> 140202222818496 + 140202222818496 [label=AccumulateGrad] + 140202222818448 -> 140202222817632 + 140202228966640 [label="encoder.layer.7.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228966640 -> 140202222818448 + 140202222818448 [label=AccumulateGrad] + 140202222817440 -> 140202223316800 + 140202222817440 [label=TBackward0] + 140202222817728 -> 140202222817440 + 140202222817728 [label=ToCopyBackward0] + 140202222819264 -> 140202222817728 + 140202228952496 [label="encoder.layer.8.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228952496 -> 140202222819264 + 140202222819264 [label=AccumulateGrad] + 140202223316128 -> 140202223315840 + 140202223316128 [label=ReshapeAliasBackward0] + 140202223316512 -> 140202223316128 + 140202223316512 [label=ExpandBackward0] + 140202223316176 -> 140202223316512 + 140202223316176 [label=TransposeBackward0] + 140202223316224 -> 140202223316176 + 140202223316224 [label=PermuteBackward0] + 140202222821280 -> 140202223316224 + 140202222821280 [label=ViewBackward0] + 140202222817968 -> 140202222821280 + 140202222817968 [label=ViewBackward0] + 140202222817536 -> 140202222817968 + 140202222817536 [label=AddmmBackward0] + 140202222844416 -> 140202222817536 + 140202222844416 [label=ToCopyBackward0] + 140210811726624 -> 140202222844416 + 140202228952336 [label="encoder.layer.8.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202228952336 -> 140210811726624 + 140210811726624 [label=AccumulateGrad] + 140202222842496 -> 140202222817536 + 140202222842496 [label=ViewBackward0] + 140210811726960 -> 140202222842496 + 140210811726960 [label=ToCopyBackward0] + 140202223288032 -> 140210811726960 + 140210811724368 -> 140202222817536 + 140210811724368 [label=TBackward0] + 140210811726576 -> 140210811724368 + 140210811726576 [label=ToCopyBackward0] + 140210811727104 -> 140210811726576 + 140202228952256 [label="encoder.layer.8.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228952256 -> 140210811727104 + 140210811727104 [label=AccumulateGrad] + 140202223314400 -> 140202223314496 + 140202223314400 [label=ReshapeAliasBackward0] + 140202223315168 -> 140202223314400 + 140202223315168 [label=ExpandBackward0] + 140202223315456 -> 140202223315168 + 140202223315456 [label=PermuteBackward0] + 140202223315696 -> 140202223315456 + 140202223315696 [label=ViewBackward0] + 140202223314592 -> 140202223315696 + 140202223314592 [label=ViewBackward0] + 140202223316704 -> 140202223314592 + 140202223316704 [label=AddmmBackward0] + 140202223314736 -> 140202223316704 + 140202223314736 [label=ToCopyBackward0] + 140202222842688 -> 140202223314736 + 140202228952096 [label="encoder.layer.8.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202228952096 -> 140202222842688 + 140202222842688 [label=AccumulateGrad] + 140202222818976 -> 140202223316704 + 140202222818976 [label=ViewBackward0] + 140210811726864 -> 140202222818976 + 140210811726864 [label=ToCopyBackward0] + 140202223288032 -> 140210811726864 + 140202222818016 -> 140202223316704 + 140202222818016 [label=TBackward0] + 140210811726768 -> 140202222818016 + 140210811726768 [label=ToCopyBackward0] + 140210811727008 -> 140210811726768 + 140202228952016 [label="encoder.layer.8.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202228952016 -> 140210811727008 + 140210811727008 [label=AccumulateGrad] + 140202223313056 -> 140202223313152 + 140202223313056 [label=TBackward0] + 140202223313920 -> 140202223313056 + 140202223313920 [label=ToCopyBackward0] + 140202223314304 -> 140202223313920 + 140202228951776 [label="encoder.layer.8.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228951776 -> 140202223314304 + 140202223314304 [label=AccumulateGrad] + 140202223288032 -> 140202223287936 + 140202223287744 -> 140202223287696 + 140202228951536 [label="encoder.layer.8.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228951536 -> 140202223287744 + 140202223287744 [label=AccumulateGrad] + 140202223286976 -> 140202223287696 + 140202228951616 [label="encoder.layer.8.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228951616 -> 140202223286976 + 140202223286976 [label=AccumulateGrad] + 140202223285776 -> 140202223286688 + 140202223285776 [label=TBackward0] + 140202223286880 -> 140202223285776 + 140202223286880 [label=ToCopyBackward0] + 140202223287552 -> 140202223286880 + 140202228951296 [label="encoder.layer.8.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228951296 -> 140202223287552 + 140202223287552 [label=AccumulateGrad] + 140202223285728 -> 140202223285440 + 140202223285728 [label=ReshapeAliasBackward0] + 140202223286112 -> 140202223285728 + 140202223286112 [label=ExpandBackward0] + 140202223286400 -> 140202223286112 + 140202223286400 [label=TransposeBackward0] + 140202223287264 -> 140202223286400 + 140202223287264 [label=PermuteBackward0] + 140202223288128 -> 140202223287264 + 140202223288128 [label=ViewBackward0] + 140202223287216 -> 140202223288128 + 140202223287216 [label=ViewBackward0] + 140202223285824 -> 140202223287216 + 140202223285824 [label=AddmmBackward0] + 140202223313440 -> 140202223285824 + 140202223313440 [label=ToCopyBackward0] + 140202223313824 -> 140202223313440 + 140202228951136 [label="encoder.layer.8.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140202228951136 -> 140202223313824 + 140202223313824 [label=AccumulateGrad] + 140202223313536 -> 140202223285824 + 140202223313536 [label=ViewBackward0] + 140202223315216 -> 140202223313536 + 140202223315216 [label=ToCopyBackward0] + 140210812052960 -> 140202223315216 + 140202223312960 -> 140202223285824 + 140202223312960 [label=TBackward0] + 140202223314976 -> 140202223312960 + 140202223314976 [label=ToCopyBackward0] + 140202223316320 -> 140202223314976 + 140202228951056 [label="encoder.layer.8.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140202228951056 -> 140202223316320 + 140202223316320 [label=AccumulateGrad] + 140202223251168 -> 140202223251264 + 140202223251168 [label=ReshapeAliasBackward0] + 140202223284768 -> 140202223251168 + 140202223284768 [label=ExpandBackward0] + 140202223285056 -> 140202223284768 + 140202223285056 [label=PermuteBackward0] + 140202223285296 -> 140202223285056 + 140202223285296 [label=ViewBackward0] + 140202223284288 -> 140202223285296 + 140202223284288 [label=ViewBackward0] + 140202223286304 -> 140202223284288 + 140202223286304 [label=AddmmBackward0] + 140202223287360 -> 140202223286304 + 140202223287360 [label=ToCopyBackward0] + 140202222820848 -> 140202223287360 + 140202228950896 [label="encoder.layer.8.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140202228950896 -> 140202222820848 + 140202222820848 [label=AccumulateGrad] + 140202223286784 -> 140202223286304 + 140202223286784 [label=ViewBackward0] + 140202223315936 -> 140202223286784 + 140202223315936 [label=ToCopyBackward0] + 140210812052960 -> 140202223315936 + 140202223284336 -> 140202223286304 + 140202223284336 [label=TBackward0] + 140202223313248 -> 140202223284336 + 140202223313248 [label=ToCopyBackward0] + 140202223314112 -> 140202223313248 + 140202228950816 [label="encoder.layer.8.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140202228950816 -> 140202223314112 + 140202223314112 [label=AccumulateGrad] + 140202223249584 -> 140202223250016 + 140202223249584 [label=TBackward0] + 140202223250784 -> 140202223249584 + 140202223250784 [label=ToCopyBackward0] + 140202223251072 -> 140202223250784 + 140202228950576 [label="encoder.layer.8.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228950576 -> 140202223251072 + 140202223251072 [label=AccumulateGrad] + 140202223249536 -> 140202223249152 + 140202223249104 -> 140202223248960 + 140202228950336 [label="encoder.layer.8.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228950336 -> 140202223249104 + 140202223249104 [label=AccumulateGrad] + 140202223248288 -> 140202223248960 + 140202228950416 [label="encoder.layer.8.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228950416 -> 140202223248288 + 140202223248288 [label=AccumulateGrad] + 140202223247520 -> 140202223248000 + 140202223247520 [label=TBackward0] + 140202223248576 -> 140202223247520 + 140202223248576 [label=ToCopyBackward0] + 140202223249344 -> 140202223248576 + 140202228934912 [label="encoder.layer.8.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228934912 -> 140202223249344 + 140202223249344 [label=AccumulateGrad] + 140202223230256 -> 140202223230736 + 140202223230256 [label=TBackward0] + 140202223248096 -> 140202223230256 + 140202223248096 [label=ToCopyBackward0] + 140202223248864 -> 140202223248096 + 140202228934592 [label="encoder.layer.8.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228934592 -> 140202223248864 + 140202223248864 [label=AccumulateGrad] + 140202223229920 -> 140202223230016 + 140202223229920 [label=UnsqueezeBackward0] + 140202223230592 -> 140202223229920 + 140202223230592 [label=NativeDropoutBackward0] + 140202223230304 -> 140202223230592 + 140202223230304 [label=ViewBackward0] + 140202223249632 -> 140202223230304 + 140202223249632 [label=AddmmBackward0] + 140202223247904 -> 140202223249632 + 140202223247904 [label=ToCopyBackward0] + 140202223250112 -> 140202223247904 + 140202228934672 [label="encoder.layer.8.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140202228934672 -> 140202223250112 + 140202223250112 [label=AccumulateGrad] + 140202223249056 -> 140202223249632 + 140202223249056 [label=ViewBackward0] + 140202223250304 -> 140202223249056 + 140202223250304 [label=GeluBackward0] + 140202223251024 -> 140202223250304 + 140202223251024 [label=ViewBackward0] + 140202223250544 -> 140202223251024 + 140202223250544 [label=AddmmBackward0] + 140202223285248 -> 140202223250544 + 140202223285248 [label=ToCopyBackward0] + 140202223287840 -> 140202223285248 + 140202228935072 [label="encoder.layer.8.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140202228935072 -> 140202223287840 + 140202223287840 [label=AccumulateGrad] + 140202223284816 -> 140202223250544 + 140202223284816 [label=ViewBackward0] + 140202223315648 -> 140202223284816 + 140202223315648 [label=ToCopyBackward0] + 140202223229152 -> 140202223315648 + 140202223284576 -> 140202223250544 + 140202223284576 [label=TBackward0] + 140202223285920 -> 140202223284576 + 140202223285920 [label=ToCopyBackward0] + 140210811727152 -> 140202223285920 + 140202228934352 [label="encoder.layer.8.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228934352 -> 140210811727152 + 140210811727152 [label=AccumulateGrad] + 140202223247424 -> 140202223249632 + 140202223247424 [label=TBackward0] + 140202223249728 -> 140202223247424 + 140202223249728 [label=ToCopyBackward0] + 140202223313728 -> 140202223249728 + 140202228934112 [label="encoder.layer.8.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228934112 -> 140202223313728 + 140202223313728 [label=AccumulateGrad] + 140202223229776 -> 140202223230016 + 140202223229776 [label=UnsqueezeBackward0] + 140202223285536 -> 140202223229776 + 140202223285536 [label=NativeDropoutBackward0] + 140202223249248 -> 140202223285536 + 140202223249248 [label=ViewBackward0] + 140202223250208 -> 140202223249248 + 140202223250208 [label=AddmmBackward0] + 140202223247616 -> 140202223250208 + 140202223247616 [label=ToCopyBackward0] + 140210811727392 -> 140202223247616 + 140202228934192 [label="encoder.layer.8.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140202228934192 -> 140210811727392 + 140210811727392 [label=AccumulateGrad] + 140210811727296 -> 140202223250208 + 140210811727296 [label=ViewBackward0] + 140210811727440 -> 140210811727296 + 140210811727440 [label=GeluBackward0] + 140210811727536 -> 140210811727440 + 140210811727536 [label=ViewBackward0] + 140210811727632 -> 140210811727536 + 140210811727632 [label=AddmmBackward0] + 140210811727728 -> 140210811727632 + 140210811727728 [label=ToCopyBackward0] + 140210811727824 -> 140210811727728 + 140202228934432 [label="encoder.layer.8.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140202228934432 -> 140210811727824 + 140210811727824 [label=AccumulateGrad] + 140210811727680 -> 140210811727632 + 140210811727680 [label=ViewBackward0] + 140210811781280 -> 140210811727680 + 140210811781280 [label=ToCopyBackward0] + 140202223229152 -> 140210811781280 + 140210811727344 -> 140210811727632 + 140210811727344 [label=TBackward0] + 140210811781232 -> 140210811727344 + 140210811781232 [label=ToCopyBackward0] + 140210811781424 -> 140210811781232 + 140202228933872 [label="encoder.layer.8.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228933872 -> 140210811781424 + 140210811781424 [label=AccumulateGrad] + 140210811727200 -> 140202223250208 + 140210811727200 [label=TBackward0] + 140210811727584 -> 140210811727200 + 140210811727584 [label=ToCopyBackward0] + 140210811727776 -> 140210811727584 + 140202228933632 [label="encoder.layer.8.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228933632 -> 140210811727776 + 140210811727776 [label=AccumulateGrad] + 140202223229728 -> 140202223229440 + 140202223229728 [label=UnsqueezeBackward0] + 140202223230400 -> 140202223229728 + 140202223230400 [label=UnsqueezeBackward0] + 140202223248672 -> 140202223230400 + 140202223248672 [label=SumBackward1] + 140202223229824 -> 140202223248672 + 140202223229824 [label=MulBackward0] + 140210811727488 -> 140202223229824 + 140210811727488 [label=UnsqueezeBackward0] + 140210811781376 -> 140210811727488 + 140210811781376 [label=TopkBackward0] + 140210811781328 -> 140210811781376 + 140210811781328 [label=SoftmaxBackward0] + 140210811781616 -> 140210811781328 + 140210811781616 [label=MmBackward0] + 140210811781712 -> 140210811781616 + 140210811781712 [label=ToCopyBackward0] + 140210811781856 -> 140210811781712 + 140210811781856 [label=MeanBackward1] + 140210811781952 -> 140210811781856 + 140210811781952 [label=MulBackward0] + 140202223229152 -> 140210811781952 + 140210811781664 -> 140210811781616 + 140210811781664 [label=TBackward0] + 140210811782048 -> 140210811781664 + 140210811782048 [label=ToCopyBackward0] + 140210811781760 -> 140210811782048 + 140202228935872 [label="encoder.layer.8.experts.gate.weight + (3, 768)" fillcolor=lightblue] + 140202228935872 -> 140210811781760 + 140210811781760 [label=AccumulateGrad] + 140202223229152 -> 140202223229056 + 140202223228864 -> 140202223228672 + 140202228935792 [label="encoder.layer.8.expert_ln.weight + (768)" fillcolor=lightblue] + 140202228935792 -> 140202223228864 + 140202223228864 [label=AccumulateGrad] + 140202223228816 -> 140202223228672 + 140202228935552 [label="encoder.layer.8.expert_ln.bias + (768)" fillcolor=lightblue] + 140202228935552 -> 140202223228816 + 140202223228816 [label=AccumulateGrad] + 140202223228336 -> 140202223195040 + 140202223228336 [label=NativeLayerNormBackward0] + 140202223229536 -> 140202223228336 + 140202223229536 [label=AddBackward0] + 140202223248384 -> 140202223229536 + 140202223248384 [label=NativeDropoutBackward0] + 140210811727248 -> 140202223248384 + 140210811727248 [label=ViewBackward0] + 140210811781520 -> 140210811727248 + 140210811781520 [label=AddmmBackward0] + 140210811782000 -> 140210811781520 + 140210811782000 [label=ToCopyBackward0] + 140210811782192 -> 140210811782000 + 140202228949936 [label="encoder.layer.8.output.dense.bias + (768)" fillcolor=lightblue] + 140202228949936 -> 140210811782192 + 140210811782192 [label=AccumulateGrad] + 140210811781904 -> 140210811781520 + 140210811781904 [label=ViewBackward0] + 140210811782240 -> 140210811781904 + 140210811782240 [label=GeluBackward0] + 140210811782336 -> 140210811782240 + 140210811782336 [label=ViewBackward0] + 140210811782432 -> 140210811782336 + 140210811782432 [label=AddmmBackward0] + 140210811782528 -> 140210811782432 + 140210811782528 [label=ToCopyBackward0] + 140210811782720 -> 140210811782528 + 140202228950176 [label="encoder.layer.8.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202228950176 -> 140210811782720 + 140210811782720 [label=AccumulateGrad] + 140210811782480 -> 140210811782432 + 140210811782480 [label=ViewBackward0] + 140210811782768 -> 140210811782480 + 140210811782768 [label=ToCopyBackward0] + 140202223230112 -> 140210811782768 + 140202223230112 [label=SliceBackward0] + 140210811782912 -> 140202223230112 + 140210811782912 [label=SliceBackward0] + 140210811783008 -> 140210811782912 + 140210811783008 [label=SliceBackward0] + 140202223287696 -> 140210811783008 + 140210811781808 -> 140210811782432 + 140210811781808 [label=TBackward0] + 140210811782672 -> 140210811781808 + 140210811782672 [label=ToCopyBackward0] + 140210811783104 -> 140210811782672 + 140202228950096 [label="encoder.layer.8.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202228950096 -> 140210811783104 + 140210811783104 [label=AccumulateGrad] + 140210811781184 -> 140210811781520 + 140210811781184 [label=TBackward0] + 140210811782384 -> 140210811781184 + 140210811782384 [label=ToCopyBackward0] + 140210811782864 -> 140210811782384 + 140202228949856 [label="encoder.layer.8.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202228949856 -> 140210811782864 + 140210811782864 [label=AccumulateGrad] + 140202223230112 -> 140202223229536 + 140202223229248 -> 140202223228336 + 140202228949616 [label="encoder.layer.8.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228949616 -> 140202223229248 + 140202223229248 [label=AccumulateGrad] + 140202223228960 -> 140202223228336 + 140202228949696 [label="encoder.layer.8.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228949696 -> 140202223228960 + 140202223228960 [label=AccumulateGrad] + 140202223226992 -> 140202223227904 + 140202223226992 [label=TBackward0] + 140202223228192 -> 140202223226992 + 140202223228192 [label=ToCopyBackward0] + 140210811727056 -> 140202223228192 + 140202228936112 [label="encoder.layer.9.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228936112 -> 140210811727056 + 140210811727056 [label=AccumulateGrad] + 140202223226944 -> 140202223198016 + 140202223226944 [label=ReshapeAliasBackward0] + 140202223227616 -> 140202223226944 + 140202223227616 [label=ExpandBackward0] + 140202223227856 -> 140202223227616 + 140202223227856 [label=TransposeBackward0] + 140202223228768 -> 140202223227856 + 140202223228768 [label=PermuteBackward0] + 140202223228480 -> 140202223228768 + 140202223228480 [label=ViewBackward0] + 140202223227328 -> 140202223228480 + 140202223227328 [label=ViewBackward0] + 140210811782096 -> 140202223227328 + 140210811782096 [label=AddmmBackward0] + 140210811782624 -> 140210811782096 + 140210811782624 [label=ToCopyBackward0] + 140210811782816 -> 140210811782624 + 140202228936272 [label="encoder.layer.9.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202228936272 -> 140210811782816 + 140210811782816 [label=AccumulateGrad] + 140210811782576 -> 140210811782096 + 140210811782576 [label=ViewBackward0] + 140210811783152 -> 140210811782576 + 140210811783152 [label=ToCopyBackward0] + 140202223195040 -> 140210811783152 + 140210811781568 -> 140210811782096 + 140210811781568 [label=TBackward0] + 140210811782288 -> 140210811781568 + 140210811782288 [label=ToCopyBackward0] + 140210811783296 -> 140210811782288 + 140202228936352 [label="encoder.layer.9.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228936352 -> 140210811783296 + 140210811783296 [label=AccumulateGrad] + 140202223196720 -> 140202223196576 + 140202223196720 [label=ReshapeAliasBackward0] + 140202223197248 -> 140202223196720 + 140202223197248 [label=ExpandBackward0] + 140202223197536 -> 140202223197248 + 140202223197536 [label=PermuteBackward0] + 140202223197824 -> 140202223197536 + 140202223197824 [label=ViewBackward0] + 140202223196960 -> 140202223197824 + 140202223196960 [label=ViewBackward0] + 140202223227808 -> 140202223196960 + 140202223227808 [label=AddmmBackward0] + 140202223229344 -> 140202223227808 + 140202223229344 [label=ToCopyBackward0] + 140210811783248 -> 140202223229344 + 140202228935312 [label="encoder.layer.9.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202228935312 -> 140210811783248 + 140210811783248 [label=AccumulateGrad] + 140202223228288 -> 140202223227808 + 140202223228288 [label=ViewBackward0] + 140210811783056 -> 140202223228288 + 140210811783056 [label=ToCopyBackward0] + 140202223195040 -> 140210811783056 + 140202223227040 -> 140202223227808 + 140202223227040 [label=TBackward0] + 140210811782144 -> 140202223227040 + 140210811782144 [label=ToCopyBackward0] + 140210811783200 -> 140210811782144 + 140202228935632 [label="encoder.layer.9.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202228935632 -> 140210811783200 + 140210811783200 [label=AccumulateGrad] + 140202223195232 -> 140202223195520 + 140202223195232 [label=TBackward0] + 140202223196240 -> 140202223195232 + 140202223196240 [label=ToCopyBackward0] + 140202223196672 -> 140202223196240 + 140202228935392 [label="encoder.layer.9.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228935392 -> 140202223196672 + 140202223196672 [label=AccumulateGrad] + 140202223195040 -> 140202223194656 + 140202223194752 -> 140202223194464 + 140202228933952 [label="encoder.layer.9.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228933952 -> 140202223194752 + 140202223194752 [label=AccumulateGrad] + 140202223194272 -> 140202223194464 + 140202228933712 [label="encoder.layer.9.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228933712 -> 140202223194272 + 140202223194272 [label=AccumulateGrad] + 140202223172000 -> 140202223172480 + 140202223172000 [label=TBackward0] + 140202223173056 -> 140202223172000 + 140202223173056 [label=ToCopyBackward0] + 140202223173536 -> 140202223173056 + 140202228927840 [label="encoder.layer.9.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228927840 -> 140202223173536 + 140202223173536 [label=AccumulateGrad] + 140202223171376 -> 140202223171808 + 140202223171376 [label=TBackward0] + 140202223172576 -> 140202223171376 + 140202223172576 [label=ToCopyBackward0] + 140202223173296 -> 140202223172576 + 140202228927600 [label="encoder.layer.9.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228927600 -> 140202223173296 + 140202223173296 [label=AccumulateGrad] + 140202223171328 -> 140202223170944 + 140202223170896 -> 140202223170848 + 140202228927360 [label="encoder.layer.9.expert_ln.weight + (768)" fillcolor=lightblue] + 140202228927360 -> 140202223170896 + 140202223170896 [label=AccumulateGrad] + 140202223170752 -> 140202223170848 + 140202228927040 [label="encoder.layer.9.expert_ln.bias + (768)" fillcolor=lightblue] + 140202228927040 -> 140202223170752 + 140202223170752 [label=AccumulateGrad] + 140202223170272 -> 140202223136928 + 140202223170272 [label=NativeLayerNormBackward0] + 140202223171424 -> 140202223170272 + 140202223171424 [label=AddBackward0] + 140202223172816 -> 140202223171424 + 140202223172816 [label=NativeDropoutBackward0] + 140202223172336 -> 140202223172816 + 140202223172336 [label=ViewBackward0] + 140202223194176 -> 140202223172336 + 140202223194176 [label=AddmmBackward0] + 140202223194848 -> 140202223194176 + 140202223194848 [label=ToCopyBackward0] + 140202223195760 -> 140202223194848 + 140202228932912 [label="encoder.layer.9.output.dense.bias + (768)" fillcolor=lightblue] + 140202228932912 -> 140202223195760 + 140202223195760 [label=AccumulateGrad] + 140202223194800 -> 140202223194176 + 140202223194800 [label=ViewBackward0] + 140202223196480 -> 140202223194800 + 140202223196480 [label=GeluBackward0] + 140202223195808 -> 140202223196480 + 140202223195808 [label=ViewBackward0] + 140202223197344 -> 140202223195808 + 140202223197344 [label=AddmmBackward0] + 140202223196864 -> 140202223197344 + 140202223196864 [label=ToCopyBackward0] + 140210811782960 -> 140202223196864 + 140202228933152 [label="encoder.layer.9.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202228933152 -> 140210811782960 + 140210811782960 [label=AccumulateGrad] + 140202223197728 -> 140202223197344 + 140202223197728 [label=ViewBackward0] + 140210811783488 -> 140202223197728 + 140210811783488 [label=ToCopyBackward0] + 140202223172096 -> 140210811783488 + 140202223172096 [label=SliceBackward0] + 140210811783536 -> 140202223172096 + 140210811783536 [label=SliceBackward0] + 140210811783632 -> 140210811783536 + 140210811783632 [label=SliceBackward0] + 140202223194464 -> 140210811783632 + 140202223195616 -> 140202223197344 + 140202223195616 [label=TBackward0] + 140210811783344 -> 140202223195616 + 140210811783344 [label=ToCopyBackward0] + 140210811783728 -> 140210811783344 + 140202228933392 [label="encoder.layer.9.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202228933392 -> 140210811783728 + 140210811783728 [label=AccumulateGrad] + 140202223194560 -> 140202223194176 + 140202223194560 [label=TBackward0] + 140202223196192 -> 140202223194560 + 140202223196192 [label=ToCopyBackward0] + 140202223227376 -> 140202223196192 + 140202228933232 [label="encoder.layer.9.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202228933232 -> 140202223227376 + 140202223227376 [label=AccumulateGrad] + 140202223172096 -> 140202223171424 + 140202223171040 -> 140202223170272 + 140202228932992 [label="encoder.layer.9.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228932992 -> 140202223171040 + 140202223171040 [label=AccumulateGrad] + 140202223171136 -> 140202223170272 + 140202228932672 [label="encoder.layer.9.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228932672 -> 140202223171136 + 140202223171136 [label=AccumulateGrad] + 140202223169696 -> 140202223169936 + 140202223169696 [label=TBackward0] + 140202223170368 -> 140202223169696 + 140202223170368 [label=ToCopyBackward0] + 140202223171904 -> 140202223170368 + 140202228927120 [label="encoder.layer.10.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228927120 -> 140202223171904 + 140202223171904 [label=AccumulateGrad] + 140202223140240 -> 140202223140096 + 140202223140240 [label=ReshapeAliasBackward0] + 140202223140480 -> 140202223140240 + 140202223140480 [label=ExpandBackward0] + 140202223140384 -> 140202223140480 + 140202223140384 [label=TransposeBackward0] + 140202223170560 -> 140202223140384 + 140202223170560 [label=PermuteBackward0] + 140202223173152 -> 140202223170560 + 140202223173152 [label=ViewBackward0] + 140202223170656 -> 140202223173152 + 140202223170656 [label=ViewBackward0] + 140202223195328 -> 140202223170656 + 140202223195328 [label=AddmmBackward0] + 140202223197056 -> 140202223195328 + 140202223197056 [label=ToCopyBackward0] + 140210811783440 -> 140202223197056 + 140202228926560 [label="encoder.layer.10.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202228926560 -> 140210811783440 + 140210811783440 [label=AccumulateGrad] + 140202223194320 -> 140202223195328 + 140202223194320 [label=ViewBackward0] + 140210811783776 -> 140202223194320 + 140210811783776 [label=ToCopyBackward0] + 140202223136928 -> 140210811783776 + 140210811781472 -> 140202223195328 + 140210811781472 [label=TBackward0] + 140210811783392 -> 140210811781472 + 140210811783392 [label=ToCopyBackward0] + 140210811783920 -> 140210811783392 + 140202228926880 [label="encoder.layer.10.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228926880 -> 140210811783920 + 140210811783920 [label=AccumulateGrad] + 140202223138656 -> 140202223138752 + 140202223138656 [label=ReshapeAliasBackward0] + 140202223139280 -> 140202223138656 + 140202223139280 [label=ExpandBackward0] + 140202223139712 -> 140202223139280 + 140202223139712 [label=PermuteBackward0] + 140202223140000 -> 140202223139712 + 140202223140000 [label=ViewBackward0] + 140202223138848 -> 140202223140000 + 140202223138848 [label=ViewBackward0] + 140202223140576 -> 140202223138848 + 140202223140576 [label=AddmmBackward0] + 140202223171520 -> 140202223140576 + 140202223171520 [label=ToCopyBackward0] + 140202223195136 -> 140202223171520 + 140202228926320 [label="encoder.layer.10.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202228926320 -> 140202223195136 + 140202223195136 [label=AccumulateGrad] + 140202223170080 -> 140202223140576 + 140202223170080 [label=ViewBackward0] + 140210811783680 -> 140202223170080 + 140210811783680 [label=ToCopyBackward0] + 140202223136928 -> 140210811783680 + 140202223169792 -> 140202223140576 + 140202223169792 [label=TBackward0] + 140210811783584 -> 140202223169792 + 140210811783584 [label=ToCopyBackward0] + 140210811783824 -> 140210811783584 + 140202228926640 [label="encoder.layer.10.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202228926640 -> 140210811783824 + 140210811783824 [label=AccumulateGrad] + 140202223137024 -> 140202223137408 + 140202223137024 [label=TBackward0] + 140202223138176 -> 140202223137024 + 140202223138176 [label=ToCopyBackward0] + 140202223138464 -> 140202223138176 + 140202228926400 [label="encoder.layer.10.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228926400 -> 140202223138464 + 140202223138464 [label=AccumulateGrad] + 140202223136928 -> 140202223112096 + 140202223111904 -> 140202223112000 + 140202228926160 [label="encoder.layer.10.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228926160 -> 140202223111904 + 140202223111904 [label=AccumulateGrad] + 140202223111232 -> 140202223112000 + 140202228925840 [label="encoder.layer.10.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228925840 -> 140202223111232 + 140202223111232 [label=AccumulateGrad] + 140202223110080 -> 140202223110800 + 140202223110080 [label=TBackward0] + 140202223111136 -> 140202223110080 + 140202223111136 [label=ToCopyBackward0] + 140202223111808 -> 140202223111136 + 140202228925920 [label="encoder.layer.10.crossattention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228925920 -> 140202223111808 + 140202223111808 [label=AccumulateGrad] + 140202223109840 -> 140202223109696 + 140202223109840 [label=ReshapeAliasBackward0] + 140202223110368 -> 140202223109840 + 140202223110368 [label=ExpandBackward0] + 140202223110656 -> 140202223110368 + 140202223110656 [label=TransposeBackward0] + 140202223111424 -> 140202223110656 + 140202223111424 [label=PermuteBackward0] + 140202223111520 -> 140202223111424 + 140202223111520 [label=ViewBackward0] + 140202223109984 -> 140202223111520 + 140202223109984 [label=ViewBackward0] + 140202223137360 -> 140202223109984 + 140202223137360 [label=AddmmBackward0] + 140202223137696 -> 140202223137360 + 140202223137696 [label=ToCopyBackward0] + 140202223137984 -> 140202223137696 + 140202228925360 [label="encoder.layer.10.crossattention.self.key.bias + (768)" fillcolor=lightblue] + 140202228925360 -> 140202223137984 + 140202223137984 [label=AccumulateGrad] + 140202223137792 -> 140202223137360 + 140202223137792 [label=ViewBackward0] + 140202223139520 -> 140202223137792 + 140202223139520 [label=ToCopyBackward0] + 140210812052960 -> 140202223139520 + 140202223136880 -> 140202223137360 + 140202223136880 [label=TBackward0] + 140202223139232 -> 140202223136880 + 140202223139232 [label=ToCopyBackward0] + 140202223139040 -> 140202223139232 + 140202228925680 [label="encoder.layer.10.crossattention.self.key.weight + (768, 1408)" fillcolor=lightblue] + 140202228925680 -> 140202223139040 + 140202223139040 [label=AccumulateGrad] + 140202223108256 -> 140202223082800 + 140202223108256 [label=ReshapeAliasBackward0] + 140202223108880 -> 140202223108256 + 140202223108880 [label=ExpandBackward0] + 140202223109312 -> 140202223108880 + 140202223109312 [label=PermuteBackward0] + 140202223109600 -> 140202223109312 + 140202223109600 [label=ViewBackward0] + 140202223108448 -> 140202223109600 + 140202223108448 [label=ViewBackward0] + 140202223110464 -> 140202223108448 + 140202223110464 [label=AddmmBackward0] + 140202223111616 -> 140202223110464 + 140202223111616 [label=ToCopyBackward0] + 140202223138368 -> 140202223111616 + 140202228925120 [label="encoder.layer.10.crossattention.self.value.bias + (768)" fillcolor=lightblue] + 140202228925120 -> 140202223138368 + 140202223138368 [label=AccumulateGrad] + 140202223110944 -> 140202223110464 + 140202223110944 [label=ViewBackward0] + 140202223140192 -> 140202223110944 + 140202223140192 [label=ToCopyBackward0] + 140210812052960 -> 140202223140192 + 140202223108640 -> 140202223110464 + 140202223108640 [label=TBackward0] + 140202223136832 -> 140202223108640 + 140202223136832 [label=ToCopyBackward0] + 140202223137312 -> 140202223136832 + 140202228925440 [label="encoder.layer.10.crossattention.self.value.weight + (768, 1408)" fillcolor=lightblue] + 140202228925440 -> 140202223137312 + 140202223137312 [label=AccumulateGrad] + 140202223081984 -> 140202223082368 + 140202223081984 [label=TBackward0] + 140202223083136 -> 140202223081984 + 140202223083136 [label=ToCopyBackward0] + 140202223083328 -> 140202223083136 + 140202228925200 [label="encoder.layer.10.crossattention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228925200 -> 140202223083328 + 140202223083328 [label=AccumulateGrad] + 140202223081888 -> 140202223081792 + 140202223081504 -> 140202223081600 + 140202228924960 [label="encoder.layer.10.crossattention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228924960 -> 140202223081504 + 140202223081504 [label=AccumulateGrad] + 140202223080880 -> 140202223081600 + 140202228924640 [label="encoder.layer.10.crossattention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228924640 -> 140202223080880 + 140202223080880 [label=AccumulateGrad] + 140202223080160 -> 140202223080640 + 140202223080160 [label=TBackward0] + 140202223080928 -> 140202223080160 + 140202223080928 [label=ToCopyBackward0] + 140202223081696 -> 140202223080928 + 140202228905040 [label="encoder.layer.10.experts.experts.0.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228905040 -> 140202223081696 + 140202223081696 [label=AccumulateGrad] + 140202223079584 -> 140202223079536 + 140202223079584 [label=TBackward0] + 140202223080448 -> 140202223079584 + 140202223080448 [label=ToCopyBackward0] + 140202223081216 -> 140202223080448 + 140202228905120 [label="encoder.layer.10.experts.experts.0.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228905120 -> 140202223081216 + 140202223081216 [label=AccumulateGrad] + 140202223578608 -> 140202223578464 + 140202223578608 [label=UnsqueezeBackward0] + 140202223579040 -> 140202223578608 + 140202223579040 [label=NativeDropoutBackward0] + 140202223079968 -> 140202223579040 + 140202223079968 [label=ViewBackward0] + 140202223082272 -> 140202223079968 + 140202223082272 [label=AddmmBackward0] + 140202223080256 -> 140202223082272 + 140202223080256 [label=ToCopyBackward0] + 140202223082752 -> 140202223080256 + 140202228904800 [label="encoder.layer.10.experts.experts.1.dense2.bias + (768)" fillcolor=lightblue] + 140202228904800 -> 140202223082752 + 140202223082752 [label=AccumulateGrad] + 140202223081408 -> 140202223082272 + 140202223081408 [label=ViewBackward0] + 140202223082656 -> 140202223081408 + 140202223082656 [label=GeluBackward0] + 140202223081312 -> 140202223082656 + 140202223081312 [label=ViewBackward0] + 140202223108352 -> 140202223081312 + 140202223108352 [label=AddmmBackward0] + 140202223109360 -> 140202223108352 + 140202223109360 [label=ToCopyBackward0] + 140202223169600 -> 140202223109360 + 140202228905600 [label="encoder.layer.10.experts.experts.1.dense1.bias + (3072)" fillcolor=lightblue] + 140202228905600 -> 140202223169600 + 140202223169600 [label=AccumulateGrad] + 140202223109120 -> 140202223108352 + 140202223109120 [label=ViewBackward0] + 140202223139760 -> 140202223109120 + 140202223139760 [label=ToCopyBackward0] + 140202223577888 -> 140202223139760 + 140202223108832 -> 140202223108352 + 140202223108832 [label=TBackward0] + 140202223110176 -> 140202223108832 + 140202223110176 [label=ToCopyBackward0] + 140210811783968 -> 140202223110176 + 140202228904880 [label="encoder.layer.10.experts.experts.1.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228904880 -> 140210811783968 + 140210811783968 [label=AccumulateGrad] + 140202223079488 -> 140202223082272 + 140202223079488 [label=TBackward0] + 140202223082320 -> 140202223079488 + 140202223082320 [label=ToCopyBackward0] + 140202223137840 -> 140202223082320 + 140202228904640 [label="encoder.layer.10.experts.experts.1.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228904640 -> 140202223137840 + 140202223137840 [label=AccumulateGrad] + 140202223578272 -> 140202223578464 + 140202223578272 [label=UnsqueezeBackward0] + 140202223080736 -> 140202223578272 + 140202223080736 [label=NativeDropoutBackward0] + 140202223082944 -> 140202223080736 + 140202223082944 [label=ViewBackward0] + 140202223108160 -> 140202223082944 + 140202223108160 [label=AddmmBackward0] + 140202223079680 -> 140202223108160 + 140202223079680 [label=ToCopyBackward0] + 140210811784208 -> 140202223079680 + 140202228895552 [label="encoder.layer.10.experts.experts.2.dense2.bias + (768)" fillcolor=lightblue] + 140202228895552 -> 140210811784208 + 140210811784208 [label=AccumulateGrad] + 140210811784112 -> 140202223108160 + 140210811784112 [label=ViewBackward0] + 140210811784256 -> 140210811784112 + 140210811784256 [label=GeluBackward0] + 140210811784352 -> 140210811784256 + 140210811784352 [label=ViewBackward0] + 140210811784448 -> 140210811784352 + 140210811784448 [label=AddmmBackward0] + 140210811784544 -> 140210811784448 + 140210811784544 [label=ToCopyBackward0] + 140210811784736 -> 140210811784544 + 140202228904560 [label="encoder.layer.10.experts.experts.2.dense1.bias + (3072)" fillcolor=lightblue] + 140202228904560 -> 140210811784736 + 140210811784736 [label=AccumulateGrad] + 140210811784496 -> 140210811784448 + 140210811784496 [label=ViewBackward0] + 140210811784784 -> 140210811784496 + 140210811784784 [label=ToCopyBackward0] + 140202223577888 -> 140210811784784 + 140210811784160 -> 140210811784448 + 140210811784160 [label=TBackward0] + 140210811784640 -> 140210811784160 + 140210811784640 [label=ToCopyBackward0] + 140210811784928 -> 140210811784640 + 140202228904400 [label="encoder.layer.10.experts.experts.2.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228904400 -> 140210811784928 + 140210811784928 [label=AccumulateGrad] + 140210811784016 -> 140202223108160 + 140210811784016 [label=TBackward0] + 140210811784400 -> 140210811784016 + 140210811784400 [label=ToCopyBackward0] + 140210811784880 -> 140210811784400 + 140202228904320 [label="encoder.layer.10.experts.experts.2.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228904320 -> 140210811784880 + 140210811784880 [label=AccumulateGrad] + 140202223578176 -> 140202223578128 + 140202223578176 [label=UnsqueezeBackward0] + 140202223578848 -> 140202223578176 + 140202223578848 [label=UnsqueezeBackward0] + 140202223109792 -> 140202223578848 + 140202223109792 [label=SumBackward1] + 140202223079920 -> 140202223109792 + 140202223079920 [label=MulBackward0] + 140210811785024 -> 140202223079920 + 140210811785024 [label=UnsqueezeBackward0] + 140210811784304 -> 140210811785024 + 140210811784304 [label=TopkBackward0] + 140210811784832 -> 140210811784304 + 140210811784832 [label=SoftmaxBackward0] + 140210811785120 -> 140210811784832 + 140210811785120 [label=MmBackward0] + 140210811785168 -> 140210811785120 + 140210811785168 [label=ToCopyBackward0] + 140210811850960 -> 140210811785168 + 140210811850960 [label=MeanBackward1] + 140210811851056 -> 140210811850960 + 140210811851056 [label=MulBackward0] + 140202223577888 -> 140210811851056 + 140210811784064 -> 140210811785120 + 140210811784064 [label=TBackward0] + 140210811851152 -> 140210811784064 + 140210811851152 [label=ToCopyBackward0] + 140210811850864 -> 140210811851152 + 140202228906000 [label="encoder.layer.10.experts.gate.weight + (3, 768)" fillcolor=lightblue] + 140202228906000 -> 140210811850864 + 140210811850864 [label=AccumulateGrad] + 140202223577888 -> 140202223577504 + 140202223577600 -> 140202223577408 + 140202228906320 [label="encoder.layer.10.expert_ln.weight + (768)" fillcolor=lightblue] + 140202228906320 -> 140202223577600 + 140202223577600 [label=AccumulateGrad] + 140202223577312 -> 140202223577408 + 140202228906080 [label="encoder.layer.10.expert_ln.bias + (768)" fillcolor=lightblue] + 140202228906080 -> 140202223577312 + 140202223577312 [label=AccumulateGrad] + 140202223576832 -> 140202223540112 + 140202223576832 [label=NativeLayerNormBackward0] + 140202223577984 -> 140202223576832 + 140202223577984 [label=AddBackward0] + 140202223081840 -> 140202223577984 + 140202223081840 [label=NativeDropoutBackward0] + 140210811784976 -> 140202223081840 + 140210811784976 [label=ViewBackward0] + 140210811785072 -> 140210811784976 + 140210811785072 [label=AddmmBackward0] + 140210811851104 -> 140210811785072 + 140210811851104 [label=ToCopyBackward0] + 140210811851296 -> 140210811851104 + 140202228907680 [label="encoder.layer.10.output.dense.bias + (768)" fillcolor=lightblue] + 140202228907680 -> 140210811851296 + 140210811851296 [label=AccumulateGrad] + 140210811851008 -> 140210811785072 + 140210811851008 [label=ViewBackward0] + 140210811851344 -> 140210811851008 + 140210811851344 [label=GeluBackward0] + 140210811851440 -> 140210811851344 + 140210811851440 [label=ViewBackward0] + 140210811851536 -> 140210811851440 + 140210811851536 [label=AddmmBackward0] + 140210811851632 -> 140210811851536 + 140210811851632 [label=ToCopyBackward0] + 140210811851824 -> 140210811851632 + 140202228924480 [label="encoder.layer.10.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202228924480 -> 140210811851824 + 140210811851824 [label=AccumulateGrad] + 140210811851584 -> 140210811851536 + 140210811851584 [label=ViewBackward0] + 140210811851872 -> 140210811851584 + 140210811851872 [label=ToCopyBackward0] + 140202223578560 -> 140210811851872 + 140202223578560 [label=SliceBackward0] + 140210811852016 -> 140202223578560 + 140210811852016 [label=SliceBackward0] + 140210811852112 -> 140210811852016 + 140210811852112 [label=SliceBackward0] + 140202223112000 -> 140210811852112 + 140210811850912 -> 140210811851536 + 140210811850912 [label=TBackward0] + 140210811851776 -> 140210811850912 + 140210811851776 [label=ToCopyBackward0] + 140210811852208 -> 140210811851776 + 140202228924720 [label="encoder.layer.10.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202228924720 -> 140210811852208 + 140210811852208 [label=AccumulateGrad] + 140210811850816 -> 140210811785072 + 140210811850816 [label=TBackward0] + 140210811851488 -> 140210811850816 + 140210811851488 [label=ToCopyBackward0] + 140210811851968 -> 140210811851488 + 140202228907920 [label="encoder.layer.10.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202228907920 -> 140210811851968 + 140210811851968 [label=AccumulateGrad] + 140202223578560 -> 140202223577984 + 140202223577696 -> 140202223576832 + 140202228907760 [label="encoder.layer.10.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228907760 -> 140202223577696 + 140202223577696 [label=AccumulateGrad] + 140202223577648 -> 140202223576832 + 140202228907440 [label="encoder.layer.10.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228907440 -> 140202223577648 + 140202223577648 [label=AccumulateGrad] + 140202223575728 -> 140202223576640 + 140202223575728 [label=TBackward0] + 140202223576928 -> 140202223575728 + 140202223576928 [label=ToCopyBackward0] + 140202223578080 -> 140202223576928 + 140202228906240 [label="encoder.layer.11.attention.self.query.weight + (768, 768)" fillcolor=lightblue] + 140202228906240 -> 140202223578080 + 140202223578080 [label=AccumulateGrad] + 140202223575680 -> 140202223575392 + 140202223575680 [label=ReshapeAliasBackward0] + 140202223576064 -> 140202223575680 + 140202223576064 [label=ExpandBackward0] + 140202223576352 -> 140202223576064 + 140202223576352 [label=TransposeBackward0] + 140202223577216 -> 140202223576352 + 140202223577216 [label=PermuteBackward0] + 140202223577168 -> 140202223577216 + 140202223577168 [label=ViewBackward0] + 140210811783872 -> 140202223577168 + 140210811783872 [label=ViewBackward0] + 140210811784592 -> 140210811783872 + 140210811784592 [label=AddmmBackward0] + 140210811851728 -> 140210811784592 + 140210811851728 [label=ToCopyBackward0] + 140210811851920 -> 140210811851728 + 140202228906800 [label="encoder.layer.11.attention.self.key.bias + (768)" fillcolor=lightblue] + 140202228906800 -> 140210811851920 + 140210811851920 [label=AccumulateGrad] + 140210811851680 -> 140210811784592 + 140210811851680 [label=ViewBackward0] + 140210811852256 -> 140210811851680 + 140210811852256 [label=ToCopyBackward0] + 140202223540112 -> 140210811852256 + 140210811851248 -> 140210811784592 + 140210811851248 [label=TBackward0] + 140210811851392 -> 140210811851248 + 140210811851392 [label=ToCopyBackward0] + 140210811852400 -> 140210811851392 + 140202228906480 [label="encoder.layer.11.attention.self.key.weight + (768, 768)" fillcolor=lightblue] + 140202228906480 -> 140210811852400 + 140210811852400 [label=AccumulateGrad] + 140202223541168 -> 140202223541312 + 140202223541168 [label=ReshapeAliasBackward0] + 140202223541888 -> 140202223541168 + 140202223541888 [label=ExpandBackward0] + 140202223542176 -> 140202223541888 + 140202223542176 [label=PermuteBackward0] + 140202223541456 -> 140202223542176 + 140202223541456 [label=ViewBackward0] + 140202223575200 -> 140202223541456 + 140202223575200 [label=ViewBackward0] + 140202223576256 -> 140202223575200 + 140202223576256 [label=AddmmBackward0] + 140202223575776 -> 140202223576256 + 140202223575776 [label=ToCopyBackward0] + 140210811852352 -> 140202223575776 + 140202228905840 [label="encoder.layer.11.attention.self.value.bias + (768)" fillcolor=lightblue] + 140202228905840 -> 140210811852352 + 140210811852352 [label=AccumulateGrad] + 140202223576736 -> 140202223576256 + 140202223576736 [label=ViewBackward0] + 140210811852160 -> 140202223576736 + 140210811852160 [label=ToCopyBackward0] + 140202223540112 -> 140210811852160 + 140202223575104 -> 140202223576256 + 140202223575104 [label=TBackward0] + 140210811851200 -> 140202223575104 + 140210811851200 [label=ToCopyBackward0] + 140210811852304 -> 140210811851200 + 140202228905760 [label="encoder.layer.11.attention.self.value.weight + (768, 768)" fillcolor=lightblue] + 140202228905760 -> 140210811852304 + 140210811852304 [label=AccumulateGrad] + 140202223540208 -> 140202223540400 + 140202223540208 [label=TBackward0] + 140202223540880 -> 140202223540208 + 140202223540880 [label=ToCopyBackward0] + 140202223541072 -> 140202223540880 + 140202228905520 [label="encoder.layer.11.attention.output.dense.weight + (768, 768)" fillcolor=lightblue] + 140202228905520 -> 140202223541072 + 140202223541072 [label=AccumulateGrad] + 140202223540112 -> 140202223540160 + 140202223539920 -> 140202223540064 + 140202228904160 [label="encoder.layer.11.attention.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228904160 -> 140202223539920 + 140202223539920 [label=AccumulateGrad] + 140202223538576 -> 140202223540064 + 140202228895312 [label="encoder.layer.11.attention.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228895312 -> 140202223538576 + 140202223538576 [label=AccumulateGrad] + 140202223539056 -> 140202223539680 + 140202223539056 [label=TBackward0] + 140202223538384 -> 140202223539056 + 140202223538384 [label=ToCopyBackward0] + 140202223539632 -> 140202223538384 + 140202228893872 [label="encoder.layer.11.experts.dense1.weight + (3072, 768)" fillcolor=lightblue] + 140202228893872 -> 140202223539632 + 140202223539632 [label=AccumulateGrad] + 140202223539344 -> 140202223539008 + 140202223539344 [label=TBackward0] + 140202223539488 -> 140202223539344 + 140202223539488 [label=ToCopyBackward0] + 140202223538240 -> 140202223539488 + 140202228893632 [label="encoder.layer.11.experts.dense2.weight + (768, 3072)" fillcolor=lightblue] + 140202228893632 -> 140202223538240 + 140202223538240 [label=AccumulateGrad] + 140202223538480 -> 140202228614096 + 140202228614192 -> 140202228615488 + 140202228893392 [label="encoder.layer.11.expert_ln.weight + (768)" fillcolor=lightblue] + 140202228893392 -> 140202228614192 + 140202228614192 [label=AccumulateGrad] + 140202228614336 -> 140202228615488 + 140202228893472 [label="encoder.layer.11.expert_ln.bias + (768)" fillcolor=lightblue] + 140202228893472 -> 140202228614336 + 140202228614336 [label=AccumulateGrad] + 140202228614480 -> 140202228657312 + 140202228614480 [label=NativeLayerNormBackward0] + 140202228614432 -> 140202228614480 + 140202228614432 [label=AddBackward0] + 140202223538816 -> 140202228614432 + 140202223538816 [label=NativeDropoutBackward0] + 140202223539392 -> 140202223538816 + 140202223539392 [label=ViewBackward0] + 140202223538432 -> 140202223539392 + 140202223538432 [label=AddmmBackward0] + 140202223540256 -> 140202223538432 + 140202223540256 [label=ToCopyBackward0] + 140202223540592 -> 140202223540256 + 140202228895152 [label="encoder.layer.11.output.dense.bias + (768)" fillcolor=lightblue] + 140202228895152 -> 140202223540592 + 140202223540592 [label=AccumulateGrad] + 140202223540016 -> 140202223538432 + 140202223540016 [label=ViewBackward0] + 140202223540976 -> 140202223540016 + 140202223540976 [label=GeluBackward0] + 140202223540832 -> 140202223540976 + 140202223540832 [label=ViewBackward0] + 140202223541936 -> 140202223540832 + 140202223541936 [label=AddmmBackward0] + 140202223540736 -> 140202223541936 + 140202223540736 [label=ToCopyBackward0] + 140210811784688 -> 140202223540736 + 140202228895392 [label="encoder.layer.11.intermediate.dense.bias + (3072)" fillcolor=lightblue] + 140202228895392 -> 140210811784688 + 140210811784688 [label=AccumulateGrad] + 140202223575488 -> 140202223541936 + 140202223575488 [label=ViewBackward0] + 140210811852592 -> 140202223575488 + 140210811852592 [label=ToCopyBackward0] + 140202223539200 -> 140210811852592 + 140202223539200 [label=SliceBackward0] + 140210811852640 -> 140202223539200 + 140210811852640 [label=SliceBackward0] + 140210811852736 -> 140210811852640 + 140210811852736 [label=SliceBackward0] + 140202223540064 -> 140210811852736 + 140202223575248 -> 140202223541936 + 140202223575248 [label=TBackward0] + 140210811852064 -> 140202223575248 + 140210811852064 [label=ToCopyBackward0] + 140210811852832 -> 140210811852064 + 140202228895632 [label="encoder.layer.11.intermediate.dense.weight + (3072, 768)" fillcolor=lightblue] + 140202228895632 -> 140210811852832 + 140210811852832 [label=AccumulateGrad] + 140202223539824 -> 140202223538432 + 140202223539824 [label=TBackward0] + 140202223540784 -> 140202223539824 + 140202223540784 [label=ToCopyBackward0] + 140202223575872 -> 140202223540784 + 140202228895072 [label="encoder.layer.11.output.dense.weight + (768, 3072)" fillcolor=lightblue] + 140202228895072 -> 140202223575872 + 140202223575872 [label=AccumulateGrad] + 140202223539200 -> 140202228614432 + 140202223538672 -> 140202228614480 + 140202228894832 [label="encoder.layer.11.output.LayerNorm.weight + (768)" fillcolor=lightblue] + 140202228894832 -> 140202223538672 + 140202223538672 [label=AccumulateGrad] + 140202223538624 -> 140202228614480 + 140202228894912 [label="encoder.layer.11.output.LayerNorm.bias + (768)" fillcolor=lightblue] + 140202228894912 -> 140202223538624 + 140202223538624 [label=AccumulateGrad] + 140202228657312 -> 140202223089520 +} diff --git a/Pre_PromptMoE_RawProb_backward_graph.pdf b/Pre_PromptMoE_RawProb_backward_graph.pdf new file mode 100644 index 0000000..54f7e67 Binary files /dev/null and b/Pre_PromptMoE_RawProb_backward_graph.pdf differ diff --git a/command.txt b/command.txt new file mode 100644 index 0000000..cee63ca --- /dev/null +++ b/command.txt @@ -0,0 +1,2 @@ +chmod +x *.sh +tensorboard --bind_all --logdir \ No newline at end of file diff --git a/demo.py b/demo.py index c7646c4..5bb7665 100644 --- a/demo.py +++ b/demo.py @@ -116,6 +116,7 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): max_new_tokens=300, max_length=2000)[0] chatbot[-1][1] = llm_message + print(llm_message) return chatbot, chat_state, img_list diff --git a/demo_v2.py b/demo_v2.py index 5e2deeb..e77fc40 100644 --- a/demo_v2.py +++ b/demo_v2.py @@ -30,7 +30,7 @@ from minigpt4.tasks import * def parse_args(): parser = argparse.ArgumentParser(description="Demo") - parser.add_argument("--cfg-path", default='eval_configs/minigptv2_eval.yaml', + parser.add_argument("--cfg-path", default='minigpt4/projects/minigpt/eval/minigptv2_eval.yaml', help="path to configuration file.") parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.") parser.add_argument( diff --git a/CODE_OF_CONDUCT.md b/documents/CODE_OF_CONDUCT.md similarity index 100% rename from CODE_OF_CONDUCT.md rename to documents/CODE_OF_CONDUCT.md diff --git a/LICENSE.md b/documents/LICENSE.md similarity index 100% rename from LICENSE.md rename to documents/LICENSE.md diff --git a/LICENSE_Lavis.md b/documents/LICENSE_Lavis.md similarity index 100% rename from LICENSE_Lavis.md rename to documents/LICENSE_Lavis.md diff --git a/MiniGPT4_Train.md b/documents/MiniGPT4_Train.md similarity index 77% rename from MiniGPT4_Train.md rename to documents/MiniGPT4_Train.md index f9e8a5c..3fc0f41 100644 --- a/MiniGPT4_Train.md +++ b/documents/MiniGPT4_Train.md @@ -11,10 +11,10 @@ After the first stage, the visual features are mapped and can be understood by t model. To launch the first stage training, run the following command. In our experiments, we use 4 A100. You can change the save path in the config file -[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage1_pretrain.yaml) +[minigpt4/projects/minigpt/train/minigpt4_stage1_pretrain.yaml](minigpt4/projects/minigpt/train/minigpt4_stage1_pretrain.yaml) ```bash -torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage1_pretrain.yaml +torchrun --nproc-per-node NUM_GPU train.py --cfg-path minigpt4/projects/minigpt/train/minigpt4_stage1_pretrain.yaml ``` A MiniGPT-4 checkpoint with only stage one training can be downloaded @@ -30,12 +30,12 @@ To download and prepare our second stage dataset, please check our [second stage dataset preparation instruction](dataset/README_2_STAGE.md). To launch the second stage alignment, first specify the path to the checkpoint file trained in stage 1 in -[train_configs/minigpt4_stage1_pretrain.yaml](train_configs/minigpt4_stage2_finetune.yaml). +[minigpt4/projects/minigpt/train/minigpt4_stage1_pretrain.yaml](minigpt4/projects/minigpt/train/minigpt4_stage2_finetune.yaml). You can also specify the output path there. Then, run the following command. In our experiments, we use 1 A100. ```bash -torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigpt4_stage2_finetune.yaml +torchrun --nproc-per-node NUM_GPU train.py --cfg-path minigpt4/projects/minigpt/train/minigpt4_stage2_finetune.yaml ``` After the second stage alignment, MiniGPT-4 is able to talk about the image coherently and user-friendly. diff --git a/MiniGPTv2_Train.md b/documents/MiniGPTv2_Train.md similarity index 80% rename from MiniGPTv2_Train.md rename to documents/MiniGPTv2_Train.md index 254d680..f77171c 100644 --- a/MiniGPTv2_Train.md +++ b/documents/MiniGPTv2_Train.md @@ -4,7 +4,7 @@ You firstly need to prepare the dataset. you can follow this step to prepare the dataset. our [dataset preparation](dataset/README_MINIGPTv2_FINETUNE.md). -In the train_configs/minigptv2_finetune.yaml, you need to set up the following paths: +In the minigpt4/projects/minigpt/train/minigptv2_finetune.yaml, you need to set up the following paths: llama_model checkpoint path: "/path/to/llama_checkpoint" @@ -19,6 +19,6 @@ For ckpt, you may load from our pretrained model checkpoints: ```bash -torchrun --nproc-per-node NUM_GPU train.py --cfg-path train_configs/minigptv2_finetune.yaml +torchrun --nproc-per-node NUM_GPU train.py --cfg-path minigpt4/projects/minigpt/train/minigptv2_finetune.yaml ``` diff --git a/README.md b/documents/README.md similarity index 87% rename from README.md rename to documents/README.md index d24923d..9c79289 100644 --- a/README.md +++ b/documents/README.md @@ -82,13 +82,13 @@ Download the corresponding LLM weights from the following huggingface space via Then, set the variable *llama_model* in the model config file to the LLM weight path. * For MiniGPT-v2, set the LLM path -[here](minigpt4/configs/models/minigpt_v2.yaml#L15) at Line 14. +[here](minigpt4/configs/models/minigpt/minigpt_v2.yaml#L15) at Line 14. * For MiniGPT-4 (Llama2), set the LLM path -[here](minigpt4/configs/models/minigpt4_llama2.yaml#L15) at Line 15. +[here](minigpt4/configs/models/minigpt/minigpt4_llama2.yaml#L15) at Line 15. * For MiniGPT-4 (Vicuna), set the LLM path -[here](minigpt4/configs/models/minigpt4_vicuna0.yaml#L18) at Line 18 +[here](minigpt4/configs/models/minigpt/minigpt4_vicuna0.yaml#L18) at Line 18 **3. Prepare the pretrained model checkpoints** @@ -101,7 +101,7 @@ Download the pretrained model checkpoints For **MiniGPT-v2**, set the path to the pretrained checkpoint in the evaluation config file -in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at Line 8. +in [minigpt4/projects/minigpt/eval/minigptv2_eval.yaml](minigpt4/projects/minigpt/eval/minigptv2_eval.yaml#L10) at Line 8. @@ -110,7 +110,7 @@ in [eval_configs/minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#L10) at L | [Download](https://drive.google.com/file/d/1a4zLvaiDBr-36pasffmgpvH5P7CKmpze/view?usp=share_link) | [Download](https://drive.google.com/file/d/1RY9jV0dyqLX-o38LrumkKRh6Jtaop58R/view?usp=sharing) | [Download](https://drive.google.com/file/d/11nAPjEok8eAGGEG1N2vXo3kBLCg0WgUk/view?usp=sharing) | For **MiniGPT-4**, set the path to the pretrained checkpoint in the evaluation config file -in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [eval_configs/minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#L10) for LLama2 version. +in [minigpt4/projects/minigpt/eval/minigpt4_eval.yaml](minigpt4/projects/minigpt/eval/minigpt4_eval.yaml#L10) at Line 8 for Vicuna version or [minigpt4/projects/minigpt/eval/minigpt4_llama2_eval.yaml](minigpt4/projects/minigpt/eval/minigpt4_llama2_eval.yaml#L10) for LLama2 version. @@ -118,19 +118,19 @@ in [eval_configs/minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#L10) at Lin For MiniGPT-v2, run ``` -python demo_v2.py --cfg-path eval_configs/minigptv2_eval.yaml --gpu-id 0 +python demo_v2.py --cfg-path minigpt4/projects/minigpt/eval/minigptv2_eval.yaml --gpu-id 0 ``` For MiniGPT-4 (Vicuna version), run ``` -python demo.py --cfg-path eval_configs/minigpt4_eval.yaml --gpu-id 0 +python demo.py --cfg-path minigpt4/projects/minigpt/eval/minigpt4_eval.yaml --gpu-id 0 ``` For MiniGPT-4 (Llama2 version), run ``` -python demo.py --cfg-path eval_configs/minigpt4_llama2_eval.yaml --gpu-id 0 +python demo.py --cfg-path minigpt4/projects/minigpt/eval/minigpt4_llama2_eval.yaml --gpu-id 0 ``` @@ -139,9 +139,9 @@ This configuration requires about 23G GPU memory for 13B LLM and 11.5G GPU memor For more powerful GPUs, you can run the model in 16 bit by setting `low_resource` to `False` in the relevant config file: -* MiniGPT-v2: [minigptv2_eval.yaml](eval_configs/minigptv2_eval.yaml#6) -* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](eval_configs/minigpt4_llama2_eval.yaml#6) -* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](eval_configs/minigpt4_eval.yaml#6) +* MiniGPT-v2: [minigptv2_eval.yaml](minigpt4/projects/minigpt/eval/minigptv2_eval.yaml#6) +* MiniGPT-4 (Llama2): [minigpt4_llama2_eval.yaml](minigpt4/projects/minigpt/eval/minigpt4_llama2_eval.yaml#6) +* MiniGPT-4 (Vicuna): [minigpt4_eval.yaml](minigpt4/projects/minigpt/eval/minigpt4_eval.yaml#6) Thanks [@WangRongsheng](https://github.com/WangRongsheng), you can also run MiniGPT-4 on [Colab](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) diff --git a/SECURITY.md b/documents/SECURITY.md similarity index 100% rename from SECURITY.md rename to documents/SECURITY.md diff --git a/environment.yml b/environment.yml index 8f94afe..5230311 100644 --- a/environment.yml +++ b/environment.yml @@ -1,4 +1,4 @@ -name: minigptv +name: promptmoe channels: - pytorch - defaults @@ -31,3 +31,5 @@ dependencies: - accelerate==0.20.3 - bitsandbytes==0.37.0 - wandb + - visual_genome + - scikit-image diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..2ce4a3d --- /dev/null +++ b/evaluate.py @@ -0,0 +1,92 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank, init_distributed_mode +from minigpt4.common.logger import setup_logger +from minigpt4.common.optims import ( + LinearWarmupCosineLRScheduler, + LinearWarmupStepLRScheduler, +) +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners.runner_base import RunnerBase +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Training") + + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + + args = parser.parse_args() + # if 'LOCAL_RANK' not in os.environ: + # os.environ['LOCAL_RANK'] = str(args.local_rank) + + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +def main(): + # allow auto-dl completes on main process without timeout when using NCCL backend. + # os.environ["NCCL_BLOCKING_WAIT"] = "1" + + # set before init_distributed_mode() to ensure the same job_id shared across all ranks. + job_id = now() + + cfg = Config(parse_args()) + + init_distributed_mode(cfg.run_cfg) + + setup_seeds(cfg) + + # set after init_distributed_mode() to only log on master. + setup_logger() + + cfg.pretty_print() + + task = tasks.setup_task(cfg) + datasets = task.build_datasets(cfg) + model = task.build_model(cfg) + + runner = RunnerBase( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets + ) + runner.evaluate(skip_reload=True) + + +if __name__ == "__main__": + main() diff --git a/evaluation/coco_caption.py b/evaluation/coco_caption.py new file mode 100644 index 0000000..b6179b8 --- /dev/null +++ b/evaluation/coco_caption.py @@ -0,0 +1,94 @@ +import os +import json +import pandas as pd +from tqdm import tqdm + +from pycocoevalcap.eval import COCOEvalCap +from collections import defaultdict + +class COCO_Annotation: + def __init__(self, annotation_file): + self.coco_cn_file = annotation_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = defaultdict(list) + with open(self.coco_cn_file, "r", encoding="UTF-8") as fin: + for line in fin: + line = line.strip() + temp = eval(line) + annotations = temp['annotations'] + for ann in annotations: + image_id = str(ann['image_id']).zfill(6) + imgToAnns[image_id].append({'image_id':image_id,'caption':ann['caption'],'image': ann['image_id']}) + return imgToAnns + + def getImgIds(self): + return self.imgToAnns.keys() + +class COCO_Result: + def __init__(self,result_file): + self.coco_cn_file = result_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = dict() + data = json.load(open(self.coco_cn_file, "r")) + for d in data: + tmp = { + 'image_id':d['question_id'][-6:], + 'caption':d['answer'] + } + imgToAnns[d['question_id'][-6:]] = [tmp] + return imgToAnns + +def coco_caption_eval(results_file, split_name): + files = { + "val":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val_gt.json", + "test":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test_gt.json" + } + + # create coco object and coco_result object + annotation_file = files[split_name] + coco = COCO_Annotation(annotation_file) + coco_result = COCO_Result(results_file) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # evaluate on a subset of images by setting + # coco_eval.params['image_id'] = coco_result.getImgIds() + # please remove this line when evaluating the full validation set + # coco_eval.params['image_id'] = coco_result.getImgIds() + + # evaluate results + # SPICE will take a few minutes the first time, but speeds up due to caching + coco_eval.evaluate() + + # print output evaluation scores + for metric, score in coco_eval.eval.items(): + print(f"{metric}: {score:.3f}") + + return coco_eval + + +def main(): + result_file = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_cap_raw_QformerMoE_Post_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0302/20240302231/result/val_vqa_result_coco_cap.json" + split_name = "val" + coco_val = coco_caption_eval(result_file, split_name) + + agg_metrics = coco_val.eval["CIDEr"] + coco_val.eval["Bleu_4"] + + # log_stats = {split_name: {k: v for k, v in coco_val.eval.items()}} + # with open( + # os.path.join(registry.get_path("output_dir"), "evaluate.txt"), "a" + # ) as f: + # f.write(json.dumps(log_stats) + "\n") + + coco_res = {k: v for k, v in coco_val.eval.items()} + coco_res["agg_metrics"] = agg_metrics + + print(coco_res) + + +main() \ No newline at end of file diff --git a/examples/ad_1.png b/examples/ad_1.png deleted file mode 100644 index d0378e4..0000000 Binary files a/examples/ad_1.png and /dev/null differ diff --git a/examples/ad_2.png b/examples/ad_2.png deleted file mode 100644 index 674248b..0000000 Binary files a/examples/ad_2.png and /dev/null differ diff --git a/examples/cook_1.png b/examples/cook_1.png deleted file mode 100644 index d8cdb45..0000000 Binary files a/examples/cook_1.png and /dev/null differ diff --git a/examples/cook_2.png b/examples/cook_2.png deleted file mode 100644 index d08272b..0000000 Binary files a/examples/cook_2.png and /dev/null differ diff --git a/examples/describe_1.png b/examples/describe_1.png deleted file mode 100644 index 02f3c92..0000000 Binary files a/examples/describe_1.png and /dev/null differ diff --git a/examples/describe_2.png b/examples/describe_2.png deleted file mode 100644 index 20bf8c7..0000000 Binary files a/examples/describe_2.png and /dev/null differ diff --git a/examples/fact_1.png b/examples/fact_1.png deleted file mode 100644 index 1f75228..0000000 Binary files a/examples/fact_1.png and /dev/null differ diff --git a/examples/fact_2.png b/examples/fact_2.png deleted file mode 100644 index de6ef53..0000000 Binary files a/examples/fact_2.png and /dev/null differ diff --git a/examples/fix_1.png b/examples/fix_1.png deleted file mode 100644 index 023cfe6..0000000 Binary files a/examples/fix_1.png and /dev/null differ diff --git a/examples/fix_2.png b/examples/fix_2.png deleted file mode 100644 index f60da5f..0000000 Binary files a/examples/fix_2.png and /dev/null differ diff --git a/examples/fun_1.png b/examples/fun_1.png deleted file mode 100644 index f720ea6..0000000 Binary files a/examples/fun_1.png and /dev/null differ diff --git a/examples/fun_2.png b/examples/fun_2.png deleted file mode 100644 index 1d37a80..0000000 Binary files a/examples/fun_2.png and /dev/null differ diff --git a/examples/logo_1.png b/examples/logo_1.png deleted file mode 100644 index 8bbe438..0000000 Binary files a/examples/logo_1.png and /dev/null differ diff --git a/examples/op_1.png b/examples/op_1.png deleted file mode 100644 index 3dbb2ff..0000000 Binary files a/examples/op_1.png and /dev/null differ diff --git a/examples/op_2.png b/examples/op_2.png deleted file mode 100644 index 2cd3e1f..0000000 Binary files a/examples/op_2.png and /dev/null differ diff --git a/examples/people_1.png b/examples/people_1.png deleted file mode 100644 index 7e95c42..0000000 Binary files a/examples/people_1.png and /dev/null differ diff --git a/examples/people_2.png b/examples/people_2.png deleted file mode 100644 index aec6c83..0000000 Binary files a/examples/people_2.png and /dev/null differ diff --git a/examples/rhyme_1.png b/examples/rhyme_1.png deleted file mode 100644 index 7d13387..0000000 Binary files a/examples/rhyme_1.png and /dev/null differ diff --git a/examples/rhyme_2.png b/examples/rhyme_2.png deleted file mode 100644 index 6cf9bf8..0000000 Binary files a/examples/rhyme_2.png and /dev/null differ diff --git a/examples/story_1.png b/examples/story_1.png deleted file mode 100644 index 3eb6ccb..0000000 Binary files a/examples/story_1.png and /dev/null differ diff --git a/examples/story_2.png b/examples/story_2.png deleted file mode 100644 index 9d37142..0000000 Binary files a/examples/story_2.png and /dev/null differ diff --git a/examples/web_1.png b/examples/web_1.png deleted file mode 100644 index 8943842..0000000 Binary files a/examples/web_1.png and /dev/null differ diff --git a/examples/wop_1.png b/examples/wop_1.png deleted file mode 100644 index 88f37d6..0000000 Binary files a/examples/wop_1.png and /dev/null differ diff --git a/examples/wop_2.png b/examples/wop_2.png deleted file mode 100644 index 8255974..0000000 Binary files a/examples/wop_2.png and /dev/null differ diff --git a/examples_v2/2000x1372_wmkn_0012149409555.jpg b/examples_v2/2000x1372_wmkn_0012149409555.jpg deleted file mode 100755 index 1250f7f..0000000 Binary files a/examples_v2/2000x1372_wmkn_0012149409555.jpg and /dev/null differ diff --git a/examples_v2/KFC-20-for-20-Nuggets.jpg b/examples_v2/KFC-20-for-20-Nuggets.jpg deleted file mode 100755 index 0ec641c..0000000 Binary files a/examples_v2/KFC-20-for-20-Nuggets.jpg and /dev/null differ diff --git a/examples_v2/cockdial.png b/examples_v2/cockdial.png deleted file mode 100755 index 935f98e..0000000 Binary files a/examples_v2/cockdial.png and /dev/null differ diff --git a/examples_v2/float.png b/examples_v2/float.png deleted file mode 100755 index 900dcb0..0000000 Binary files a/examples_v2/float.png and /dev/null differ diff --git a/examples_v2/glip_test.jpg b/examples_v2/glip_test.jpg deleted file mode 100755 index f9198f2..0000000 Binary files a/examples_v2/glip_test.jpg and /dev/null differ diff --git a/examples_v2/office.jpg b/examples_v2/office.jpg deleted file mode 100755 index e35bdc2..0000000 Binary files a/examples_v2/office.jpg and /dev/null differ diff --git a/examples_v2/sofa.jpg b/examples_v2/sofa.jpg deleted file mode 100755 index 8610591..0000000 Binary files a/examples_v2/sofa.jpg and /dev/null differ diff --git a/examples_v2/thief.png b/examples_v2/thief.png deleted file mode 100755 index 579ee52..0000000 Binary files a/examples_v2/thief.png and /dev/null differ diff --git a/minigpt4/common/caption_tools/__init__.py b/minigpt4/common/caption_tools/__init__.py new file mode 100644 index 0000000..9b98da8 --- /dev/null +++ b/minigpt4/common/caption_tools/__init__.py @@ -0,0 +1,8 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" diff --git a/minigpt4/common/caption_tools/caption_utils.py b/minigpt4/common/caption_tools/caption_utils.py new file mode 100644 index 0000000..6a73c3d --- /dev/null +++ b/minigpt4/common/caption_tools/caption_utils.py @@ -0,0 +1,127 @@ + +from collections import defaultdict +from pycocoevalcap.eval import COCOEvalCap +import json + +class COCO_Annotation: + def __init__(self, annotation_file): + self.coco_cn_file = annotation_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = defaultdict(list) + with open(self.coco_cn_file, "r", encoding="UTF-8") as fin: + for line in fin: + line = line.strip() + temp = eval(line) + annotations = temp['annotations'] + for ann in annotations: + image_id = str(ann['image_id']).zfill(6) + imgToAnns[image_id].append({'image_id':image_id,'caption':ann['caption'],'image': ann['image_id']}) + return imgToAnns + + def getImgIds(self): + return self.imgToAnns.keys() + +class COCO_Result: + def __init__(self,result_file): + self.coco_cn_file = result_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = dict() + data = json.load(open(self.coco_cn_file, "r")) + for d in data: + tmp = { + 'image_id':d['question_id'][-6:], + 'caption':d['answer'] + } + imgToAnns[d['question_id'][-6:]] = [tmp] + return imgToAnns + +def coco_caption_eval(coco_gt_root, results_file, split_name): + files = { + "val":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val_gt.json", + "test":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test_gt.json" + } + + # create coco object and coco_result object + annotation_file = files[split_name] + coco = COCO_Annotation(annotation_file) + coco_result = COCO_Result(results_file) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # evaluate on a subset of images by setting + # coco_eval.params['image_id'] = coco_result.getImgIds() + # please remove this line when evaluating the full validation set + # coco_eval.params['image_id'] = coco_result.getImgIds() + + # evaluate results + # SPICE will take a few minutes the first time, but speeds up due to caching + coco_eval.evaluate() + + # print output evaluation scores + for metric, score in coco_eval.eval.items(): + print(f"{metric}: {score:.3f}") + + return coco_eval + + + +class TextCap_Annotation: + def __init__(self, annotation_file): + self.anno_file = annotation_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = defaultdict(list) + annotations = json.load(open(self.anno_file,"r"))['data'] + for ann in annotations: + image_id = str(ann['image_name']) + imgToAnns[image_id].append({ + 'image_id':image_id, + # 'caption':ann['reference_strs'], + 'caption':ann['caption_str'], + 'image': ann['image_path'] + }) + return imgToAnns + + def getImgIds(self): + return self.imgToAnns.keys() + +class TextCap_Result: + def __init__(self,result_file): + self.result_file = result_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = dict() + data = json.load(open(self.result_file, "r")) + for d in data: + tmp = { + 'image_id':d['question_id'], # actually image_id + 'caption':d['answer'] + } + imgToAnns[d['question_id']] = [tmp] + return imgToAnns + + +def textcaps_caption_eval(annotation_file, results_file): + + # create coco object and coco_result object + anno = TextCap_Annotation(annotation_file) + result = TextCap_Result(results_file) + + # create coco_eval object by taking coco and coco_result + text_eval = COCOEvalCap(anno, result) + + # SPICE will take a few minutes the first time, but speeds up due to caching + text_eval.evaluate() + + # print output evaluation scores + for metric, score in text_eval.eval.items(): + print(f"{metric}: {score:.3f}") + + return text_eval diff --git a/minigpt4/common/config.py b/minigpt4/common/config.py index e184b1f..50c97d3 100644 --- a/minigpt4/common/config.py +++ b/minigpt4/common/config.py @@ -29,6 +29,7 @@ class Config: runner_config = self.build_runner_config(config) model_config = self.build_model_config(config, **user_config) dataset_config = self.build_dataset_config(config) + evaluation_dataset_config = self.build_evaluation_dataset_config(config) # Validate the user-provided runner configuration # model and dataset configuration are supposed to be validated by the respective classes @@ -37,7 +38,7 @@ class Config: # Override the default configuration with user options. self.config = OmegaConf.merge( - runner_config, model_config, dataset_config, user_config + runner_config, model_config, dataset_config, evaluation_dataset_config, user_config ) def _validate_runner_config(self, runner_config): @@ -111,6 +112,29 @@ class Config: return dataset_config + @staticmethod + def build_evaluation_dataset_config(config): + # from Minigpt-v2 + datasets = config.get("evaluation_datasets", None) + # if datasets is None: + # raise KeyError( + # "Expecting 'datasets' as the root key for dataset configuration." + # ) + + dataset_config = OmegaConf.create() + + if datasets is not None: + for dataset_name in datasets: + builder_cls = registry.get_builder_class(dataset_name) + + # hierarchy override, customized config > default config + dataset_config = OmegaConf.merge( + dataset_config, + {"evaluation_datasets": {dataset_name: config["evaluation_datasets"][dataset_name]}}, + ) + + return dataset_config + def _convert_to_dot_list(self, opts): if opts is None: opts = [] @@ -136,6 +160,10 @@ class Config: def datasets_cfg(self): return self.config.datasets + @property + def evaluation_datasets_cfg(self): + return self.config.evaluation_datasets + @property def model_cfg(self): return self.config.model diff --git a/minigpt4/common/eval_utils.py b/minigpt4/common/eval_utils.py new file mode 100644 index 0000000..3087d2a --- /dev/null +++ b/minigpt4/common/eval_utils.py @@ -0,0 +1,76 @@ +import argparse +import numpy as np +from nltk.translate.bleu_score import sentence_bleu + +from minigpt4.common.registry import registry +from minigpt4.common.config import Config + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + + +def eval_parser(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument("--cfg-path", required=True, help="path to configuration file.") + parser.add_argument("--name", type=str, default='A2', help="evaluation name") + parser.add_argument("--ckpt", type=str, help="path to configuration file.") + parser.add_argument("--eval_opt", type=str, default='all', help="path to configuration file.") + parser.add_argument("--max_new_tokens", type=int, default=10, help="max number of generated tokens") + parser.add_argument("--batch_size", type=int, default=32) + parser.add_argument("--lora_r", type=int, default=64, help="lora rank of the model") + parser.add_argument("--lora_alpha", type=int, default=16, help="lora alpha") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + return parser + + +def prepare_texts(texts, conv_temp): + convs = [conv_temp.copy() for _ in range(len(texts))] + [conv.append_message( + conv.roles[0], ' {}'.format(text)) for conv, text in zip(convs, texts)] + [conv.append_message(conv.roles[1], None) for conv in convs] + texts = [conv.get_prompt() for conv in convs] + return texts + + +def init_model(args): + print('Initialization Model') + cfg = Config(args) + # cfg.model_cfg.ckpt = args.ckpt + # cfg.model_cfg.lora_r = args.lora_r + # cfg.model_cfg.lora_alpha = args.lora_alpha + + model_config = cfg.model_cfg + model_cls = registry.get_model_class(model_config.arch) + model = model_cls.from_config(model_config).to('cuda:0') + +# import pudb; pudb.set_trace() + key = list(cfg.datasets_cfg.keys())[0] + vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train + vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + print('Initialization Finished') + return model, vis_processor + +def computeIoU(bbox1, bbox2): + x1, y1, x2, y2 = bbox1 + x3, y3, x4, y4 = bbox2 + intersection_x1 = max(x1, x3) + intersection_y1 = max(y1, y3) + intersection_x2 = min(x2, x4) + intersection_y2 = min(y2, y4) + intersection_area = max(0, intersection_x2 - intersection_x1 + 1) * max(0, intersection_y2 - intersection_y1 + 1) + bbox1_area = (x2 - x1 + 1) * (y2 - y1 + 1) + bbox2_area = (x4 - x3 + 1) * (y4 - y3 + 1) + union_area = bbox1_area + bbox2_area - intersection_area + iou = intersection_area / union_area + return iou diff --git a/minigpt4/common/logger.py b/minigpt4/common/logger.py index 9a5a727..4a6a8ed 100644 --- a/minigpt4/common/logger.py +++ b/minigpt4/common/logger.py @@ -2,13 +2,14 @@ Copyright (c) 2022, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause - For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import datetime import logging import time from collections import defaultdict, deque +from torch.utils.tensorboard import SummaryWriter import torch import torch.distributed as dist @@ -80,9 +81,10 @@ class SmoothedValue(object): class MetricLogger(object): - def __init__(self, delimiter="\t"): + def __init__(self, delimiter="\t",writer: SummaryWriter=None): self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter + self.writer = writer def update(self, **kwargs): for k, v in kwargs.items(): @@ -90,6 +92,10 @@ class MetricLogger(object): v = v.item() assert isinstance(v, (float, int)) self.meters[k].update(v) + + def update_writer(self, it): + for name, meter in self.meters.items(): + self.writer.add_scalar(name, meter, ) def __getattr__(self, attr): if attr in self.meters: diff --git a/minigpt4/common/vqa_tools/__init__.py b/minigpt4/common/vqa_tools/__init__.py new file mode 100644 index 0000000..9b98da8 --- /dev/null +++ b/minigpt4/common/vqa_tools/__init__.py @@ -0,0 +1,8 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" diff --git a/minigpt4/common/vqa_tools/vqa.py b/minigpt4/common/vqa_tools/vqa.py new file mode 100644 index 0000000..a386b90 --- /dev/null +++ b/minigpt4/common/vqa_tools/vqa.py @@ -0,0 +1,211 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +__author__ = "aagrawal" +__version__ = "0.9" + +# Interface for accessing the VQA dataset. + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). + +# The following functions are defined: +# VQA - VQA class that loads VQA annotation file and prepares data structures. +# getQuesIds - Get question ids that satisfy given filter conditions. +# getImgIds - Get image ids that satisfy given filter conditions. +# loadQA - Load questions and answers with the specified question ids. +# showQA - Display the specified questions and answers. +# loadRes - Load result file and create result object. + +# Help on each function can be accessed by: "help(COCO.function)" + +import json +import datetime +import copy + + +class VQA: + def __init__(self, annotation_file=None, question_file=None): + """ + Constructor of VQA helper class for reading and visualizing questions and answers. + :param annotation_file (str): location of VQA annotation file + :return: + """ + # load dataset + self.dataset = {} + self.questions = {} + self.qa = {} + self.qqa = {} + self.imgToQA = {} + if not annotation_file == None and not question_file == None: + print("loading VQA annotations and questions into memory...") + time_t = datetime.datetime.utcnow() + dataset = json.load(open(annotation_file, "r")) + questions = json.load(open(question_file, "r")) + self.dataset = dataset + self.questions = questions + self.createIndex() + + def createIndex(self): + # create index + print("creating index...") + imgToQA = {ann["image_id"]: [] for ann in self.dataset["annotations"]} + qa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + qqa = {ann["question_id"]: [] for ann in self.dataset["annotations"]} + for ann in self.dataset["annotations"]: + imgToQA[ann["image_id"]] += [ann] + qa[ann["question_id"]] = ann + for ques in self.questions["questions"]: + qqa[ques["question_id"]] = ques + print("index created!") + + # create class members + self.qa = qa + self.qqa = qqa + self.imgToQA = imgToQA + + def info(self): + """ + Print information about the VQA annotation file. + :return: + """ + for key, value in self.datset["info"].items(): + print("%s: %s" % (key, value)) + + def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): + """ + Get question ids that satisfy given filter conditions. default skips that filter + :param imgIds (int array) : get question ids for given imgs + quesTypes (str array) : get question ids for given question types + ansTypes (str array) : get question ids for given answer types + :return: ids (int array) : integer array of question ids + """ + imgIds = imgIds if type(imgIds) == list else [imgIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(imgIds) == 0: + anns = sum( + [self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA], + [], + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["question_id"] for ann in anns] + return ids + + def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): + """ + Get image ids that satisfy given filter conditions. default skips that filter + :param quesIds (int array) : get image ids for given question ids + quesTypes (str array) : get image ids for given question types + ansTypes (str array) : get image ids for given answer types + :return: ids (int array) : integer array of image ids + """ + quesIds = quesIds if type(quesIds) == list else [quesIds] + quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] + ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] + + if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: + anns = self.dataset["annotations"] + else: + if not len(quesIds) == 0: + anns = sum( + [self.qa[quesId] for quesId in quesIds if quesId in self.qa], [] + ) + else: + anns = self.dataset["annotations"] + anns = ( + anns + if len(quesTypes) == 0 + else [ann for ann in anns if ann["question_type"] in quesTypes] + ) + anns = ( + anns + if len(ansTypes) == 0 + else [ann for ann in anns if ann["answer_type"] in ansTypes] + ) + ids = [ann["image_id"] for ann in anns] + return ids + + def loadQA(self, ids=[]): + """ + Load questions and answers with the specified question ids. + :param ids (int array) : integer ids specifying question ids + :return: qa (object array) : loaded qa objects + """ + if type(ids) == list: + return [self.qa[id] for id in ids] + elif type(ids) == int: + return [self.qa[ids]] + + def showQA(self, anns): + """ + Display the specified annotations. + :param anns (array of object): annotations to display + :return: None + """ + if len(anns) == 0: + return 0 + for ann in anns: + quesId = ann["question_id"] + print("Question: %s" % (self.qqa[quesId]["question"])) + for ans in ann["answers"]: + print("Answer %d: %s" % (ans["answer_id"], ans["answer"])) + + def loadRes(self, resFile, quesFile): + """ + Load result file and return a result object. + :param resFile (str) : file name of result file + :return: res (obj) : result api object + """ + res = VQA() + res.questions = json.load(open(quesFile)) + res.dataset["info"] = copy.deepcopy(self.questions["info"]) + res.dataset["task_type"] = copy.deepcopy(self.questions["task_type"]) + res.dataset["data_type"] = copy.deepcopy(self.questions["data_type"]) + res.dataset["data_subtype"] = copy.deepcopy(self.questions["data_subtype"]) + res.dataset["license"] = copy.deepcopy(self.questions["license"]) + + print("Loading and preparing results... ") + time_t = datetime.datetime.utcnow() + anns = json.load(open(resFile)) + assert type(anns) == list, "results is not an array of objects" + annsQuesIds = [ann["question_id"] for ann in anns] + assert set(annsQuesIds) == set( + self.getQuesIds() + ), "Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file." + for ann in anns: + quesId = ann["question_id"] + if res.dataset["task_type"] == "Multiple Choice": + assert ( + ann["answer"] in self.qqa[quesId]["multiple_choices"] + ), "predicted answer is not one of the multiple choices" + qaAnn = self.qa[quesId] + ann["image_id"] = qaAnn["image_id"] + ann["question_type"] = qaAnn["question_type"] + ann["answer_type"] = qaAnn["answer_type"] + print( + "DONE (t=%0.2fs)" % ((datetime.datetime.utcnow() - time_t).total_seconds()) + ) + + res.dataset["annotations"] = anns + res.createIndex() + return res diff --git a/minigpt4/common/vqa_tools/vqa_eval.py b/minigpt4/common/vqa_tools/vqa_eval.py new file mode 100644 index 0000000..5ab95d2 --- /dev/null +++ b/minigpt4/common/vqa_tools/vqa_eval.py @@ -0,0 +1,324 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +# coding=utf-8 + +__author__ = "aagrawal" + +# This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: +# (https://github.com/tylin/coco-caption/blob/master/pycocoevalcap/eval.py). +import sys +import re + + +class VQAEval: + def __init__(self, vqa=None, vqaRes=None, n=2): + self.n = n + self.accuracy = {} + self.evalQA = {} + self.evalQuesType = {} + self.evalAnsType = {} + self.vqa = vqa # annotation + self.vqaRes = vqaRes # predict answers + if vqa is not None: + self.params = {"question_id": vqa.getQuesIds()} + self.contractions = { + "aint": "ain't", + "arent": "aren't", + "cant": "can't", + "couldve": "could've", + "couldnt": "couldn't", + "couldn'tve": "couldn't've", + "couldnt've": "couldn't've", + "didnt": "didn't", + "doesnt": "doesn't", + "dont": "don't", + "hadnt": "hadn't", + "hadnt've": "hadn't've", + "hadn'tve": "hadn't've", + "hasnt": "hasn't", + "havent": "haven't", + "hed": "he'd", + "hed've": "he'd've", + "he'dve": "he'd've", + "hes": "he's", + "howd": "how'd", + "howll": "how'll", + "hows": "how's", + "Id've": "I'd've", + "I'dve": "I'd've", + "Im": "I'm", + "Ive": "I've", + "isnt": "isn't", + "itd": "it'd", + "itd've": "it'd've", + "it'dve": "it'd've", + "itll": "it'll", + "let's": "let's", + "maam": "ma'am", + "mightnt": "mightn't", + "mightnt've": "mightn't've", + "mightn'tve": "mightn't've", + "mightve": "might've", + "mustnt": "mustn't", + "mustve": "must've", + "neednt": "needn't", + "notve": "not've", + "oclock": "o'clock", + "oughtnt": "oughtn't", + "ow's'at": "'ow's'at", + "'ows'at": "'ow's'at", + "'ow'sat": "'ow's'at", + "shant": "shan't", + "shed've": "she'd've", + "she'dve": "she'd've", + "she's": "she's", + "shouldve": "should've", + "shouldnt": "shouldn't", + "shouldnt've": "shouldn't've", + "shouldn'tve": "shouldn't've", + "somebody'd": "somebodyd", + "somebodyd've": "somebody'd've", + "somebody'dve": "somebody'd've", + "somebodyll": "somebody'll", + "somebodys": "somebody's", + "someoned": "someone'd", + "someoned've": "someone'd've", + "someone'dve": "someone'd've", + "someonell": "someone'll", + "someones": "someone's", + "somethingd": "something'd", + "somethingd've": "something'd've", + "something'dve": "something'd've", + "somethingll": "something'll", + "thats": "that's", + "thered": "there'd", + "thered've": "there'd've", + "there'dve": "there'd've", + "therere": "there're", + "theres": "there's", + "theyd": "they'd", + "theyd've": "they'd've", + "they'dve": "they'd've", + "theyll": "they'll", + "theyre": "they're", + "theyve": "they've", + "twas": "'twas", + "wasnt": "wasn't", + "wed've": "we'd've", + "we'dve": "we'd've", + "weve": "we've", + "werent": "weren't", + "whatll": "what'll", + "whatre": "what're", + "whats": "what's", + "whatve": "what've", + "whens": "when's", + "whered": "where'd", + "wheres": "where's", + "whereve": "where've", + "whod": "who'd", + "whod've": "who'd've", + "who'dve": "who'd've", + "wholl": "who'll", + "whos": "who's", + "whove": "who've", + "whyll": "why'll", + "whyre": "why're", + "whys": "why's", + "wont": "won't", + "wouldve": "would've", + "wouldnt": "wouldn't", + "wouldnt've": "wouldn't've", + "wouldn'tve": "wouldn't've", + "yall": "y'all", + "yall'll": "y'all'll", + "y'allll": "y'all'll", + "yall'd've": "y'all'd've", + "y'alld've": "y'all'd've", + "y'all'dve": "y'all'd've", + "youd": "you'd", + "youd've": "you'd've", + "you'dve": "you'd've", + "youll": "you'll", + "youre": "you're", + "youve": "you've", + } + self.manualMap = { + "none": "0", + "zero": "0", + "one": "1", + "two": "2", + "three": "3", + "four": "4", + "five": "5", + "six": "6", + "seven": "7", + "eight": "8", + "nine": "9", + "ten": "10", + } + self.articles = ["a", "an", "the"] + + self.periodStrip = re.compile("(?!<=\d)(\.)(?!\d)") + self.commaStrip = re.compile("(\d)(,)(\d)") + self.punct = [ + ";", + r"/", + "[", + "]", + '"', + "{", + "}", + "(", + ")", + "=", + "+", + "\\", + "_", + "-", + ">", + "<", + "@", + "`", + ",", + "?", + "!", + ] + + def evaluate(self, quesIds=None): + if quesIds == None: + quesIds = [quesId for quesId in self.params["question_id"]] + gts = {} + res = {} + for quesId in quesIds: + gts[quesId] = self.vqa.qa[quesId] + res[quesId] = self.vqaRes.qa[quesId] + + # ================================================= + # Compute accuracy + # ================================================= + accQA = [] + accQuesType = {} + accAnsType = {} + print("computing accuracy") + step = 0 + for quesId in quesIds: + resAns = res[quesId]["answer"] + resAns = resAns.replace("\n", " ") + resAns = resAns.replace("\t", " ") + resAns = resAns.strip() + resAns = self.processPunctuation(resAns) + resAns = self.processDigitArticle(resAns) + gtAcc = [] + gtAnswers = [ans["answer"] for ans in gts[quesId]["answers"]] + if len(set(gtAnswers)) > 1: + for ansDic in gts[quesId]["answers"]: + ansDic["answer"] = self.processPunctuation(ansDic["answer"]) + for gtAnsDatum in gts[quesId]["answers"]: + otherGTAns = [ + item for item in gts[quesId]["answers"] if item != gtAnsDatum + ] + matchingAns = [item for item in otherGTAns if item["answer"] == resAns] + acc = min(1, float(len(matchingAns)) / 3) + gtAcc.append(acc) + quesType = gts[quesId]["question_type"] + ansType = gts[quesId]["answer_type"] + avgGTAcc = float(sum(gtAcc)) / len(gtAcc) + accQA.append(avgGTAcc) + if quesType not in accQuesType: + accQuesType[quesType] = [] + accQuesType[quesType].append(avgGTAcc) + if ansType not in accAnsType: + accAnsType[ansType] = [] + accAnsType[ansType].append(avgGTAcc) + self.setEvalQA(quesId, avgGTAcc) + self.setEvalQuesType(quesId, quesType, avgGTAcc) + self.setEvalAnsType(quesId, ansType, avgGTAcc) + if step % 100 == 0: + self.updateProgress(step / float(len(quesIds))) + step = step + 1 + + self.setAccuracy(accQA, accQuesType, accAnsType) + print("Done computing accuracy") + + def processPunctuation(self, inText): + outText = inText + for p in self.punct: + if (p + " " in inText or " " + p in inText) or ( + re.search(self.commaStrip, inText) != None + ): + outText = outText.replace(p, "") + else: + outText = outText.replace(p, " ") + outText = self.periodStrip.sub("", outText, re.UNICODE) + return outText + + def processDigitArticle(self, inText): + outText = [] + tempText = inText.lower().split() + for word in tempText: + word = self.manualMap.setdefault(word, word) + if word not in self.articles: + outText.append(word) + else: + pass + for wordId, word in enumerate(outText): + if word in self.contractions: + outText[wordId] = self.contractions[word] + outText = " ".join(outText) + return outText + + def setAccuracy(self, accQA, accQuesType, accAnsType): + self.accuracy["overall"] = round(100 * float(sum(accQA)) / len(accQA), self.n) + self.accuracy["perQuestionType"] = { + quesType: round( + 100 * float(sum(accQuesType[quesType])) / len(accQuesType[quesType]), + self.n, + ) + for quesType in accQuesType + } + self.accuracy["perAnswerType"] = { + ansType: round( + 100 * float(sum(accAnsType[ansType])) / len(accAnsType[ansType]), self.n + ) + for ansType in accAnsType + } + + def setEvalQA(self, quesId, acc): + self.evalQA[quesId] = round(100 * acc, self.n) + + def setEvalQuesType(self, quesId, quesType, acc): + if quesType not in self.evalQuesType: + self.evalQuesType[quesType] = {} + self.evalQuesType[quesType][quesId] = round(100 * acc, self.n) + + def setEvalAnsType(self, quesId, ansType, acc): + if ansType not in self.evalAnsType: + self.evalAnsType[ansType] = {} + self.evalAnsType[ansType][quesId] = round(100 * acc, self.n) + + def updateProgress(self, progress): + barLength = 20 + status = "" + if isinstance(progress, int): + progress = float(progress) + if not isinstance(progress, float): + progress = 0 + status = "error: progress var must be float\r\n" + if progress < 0: + progress = 0 + status = "Halt...\r\n" + if progress >= 1: + progress = 1 + status = "Done...\r\n" + block = int(round(barLength * progress)) + text = "\rFinshed Percent: [{0}] {1}% {2}".format( + "#" * block + "-" * (barLength - block), int(progress * 100), status + ) + sys.stdout.write(text) + sys.stdout.flush() diff --git a/minigpt4/configs/datasets/aokvqa/defaults.yaml b/minigpt4/configs/datasets/aokvqa/defaults.yaml index 767fdd4..4098696 100755 --- a/minigpt4/configs/datasets/aokvqa/defaults.yaml +++ b/minigpt4/configs/datasets/aokvqa/defaults.yaml @@ -15,6 +15,16 @@ datasets: url: - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_train.json storage: - - /path/to/aokvqa_v1p0_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/AOKVQA/aokvqa_v1p0_train.json + val: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/AOKVQA/aokvqa_v1p0_val.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/aokvqa/aokvqa_v1p0_val.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/AOKVQA/aokvqa_v1p0_val.json images: - storage: /path/to/coco/images \ No newline at end of file + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco/caption.yaml b/minigpt4/configs/datasets/coco/caption.yaml index ac072a4..8d62c89 100644 --- a/minigpt4/configs/datasets/coco/caption.yaml +++ b/minigpt4/configs/datasets/coco/caption.yaml @@ -14,8 +14,18 @@ datasets: annotations: train: url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json - md5: aa31ac474cf6250ebb81d18348a07ed8 - storage: /path/to/coco_caption/coco_karpathy_train.json + # md5: aa31ac474cf6250ebb81d18348a07ed8 + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_train.json + val: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json + images: - storage: /path/to/coco/images + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO diff --git a/minigpt4/configs/datasets/coco/caption_eval.yaml b/minigpt4/configs/datasets/coco/caption_eval.yaml new file mode 100644 index 0000000..5a2a17f --- /dev/null +++ b/minigpt4/configs/datasets/coco/caption_eval.yaml @@ -0,0 +1,26 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + coco_caption: # name of the dataset builder + # dataset_card: dataset_card/coco_caption.md + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + val: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val.json + test: + url: https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test.json + + images: + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO + diff --git a/minigpt4/configs/datasets/coco/defaults_vqa.yaml b/minigpt4/configs/datasets/coco/defaults_vqa.yaml index 457e0a3..9a62dbe 100755 --- a/minigpt4/configs/datasets/coco/defaults_vqa.yaml +++ b/minigpt4/configs/datasets/coco/defaults_vqa.yaml @@ -13,12 +13,36 @@ datasets: annotations: train: url: - - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json - - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json # 443752 + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json # 214352 storage: - - /path/to/vqav2/vqa_train.json - - /path/to/vqav2/vqa_val.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_train_part100.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_val_part100.json + val: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val_eval.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_val_eval_part100.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_mscoco_val2014_annotations.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_test.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_test.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_test_part100.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json + images: - storage: /path/to/coco/images + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco/defaults_vqa_eval.yaml b/minigpt4/configs/datasets/coco/defaults_vqa_eval.yaml new file mode 100755 index 0000000..7943d6a --- /dev/null +++ b/minigpt4/configs/datasets/coco/defaults_vqa_eval.yaml @@ -0,0 +1,39 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + coco_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + + annotations: + val: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val_eval.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_mscoco_val2014_annotations.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_test.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_test.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json + + images: + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO + + \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco/defaults_vqa_part.yaml b/minigpt4/configs/datasets/coco/defaults_vqa_part.yaml new file mode 100755 index 0000000..e538aba --- /dev/null +++ b/minigpt4/configs/datasets/coco/defaults_vqa_part.yaml @@ -0,0 +1,48 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + coco_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + + annotations: + train: + url: + # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_train.json # 443752 + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val.json # 214352 + storage: + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_train_part100.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_val_part100.json + val: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/v2_mscoco_val2014_annotations.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_val_eval.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_val_eval_part100.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_OpenEnded_mscoco_val2014_questions.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/v2_mscoco_val2014_annotations.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/vqa_test.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/vqav2/answer_list.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/vqa_test.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/test_part/vqa_test_part100.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VQAv2/answer_list.json + + images: + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO + + \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml b/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml index 8325efc..e9768d5 100755 --- a/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml +++ b/minigpt4/configs/datasets/coco_bbox/invrefcoco.yaml @@ -2,7 +2,7 @@ datasets: invrefcoco: data_type: images build_info: - image_path: /path/to/coco/images + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO ann_path: /path/to/refcoco_annotations dataset: invrefcoco splitBy: unc \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml b/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml index e562240..63d6b88 100755 --- a/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml +++ b/minigpt4/configs/datasets/coco_bbox/invrefcocog.yaml @@ -2,7 +2,7 @@ datasets: invrefcocog: data_type: images build_info: - image_path: /path/to/coco/images + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO ann_path: /path/to/refcoco_annotations dataset: invrefcocog splitBy: umd \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml b/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml index 1c57c8e..e044884 100755 --- a/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml +++ b/minigpt4/configs/datasets/coco_bbox/invrefcocop.yaml @@ -2,7 +2,7 @@ datasets: invrefcocop: data_type: images build_info: - image_path: /path/to/coco/images + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO ann_path: /path/to/refcoco_annotations dataset: invrefcoco+ splitBy: unc \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco_bbox/refcoco.yaml b/minigpt4/configs/datasets/coco_bbox/refcoco.yaml index fc96f6d..00f268f 100755 --- a/minigpt4/configs/datasets/coco_bbox/refcoco.yaml +++ b/minigpt4/configs/datasets/coco_bbox/refcoco.yaml @@ -2,7 +2,7 @@ datasets: refcoco: data_type: images build_info: - image_path: /path/to/coco/images + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO ann_path: /path/to/refcoco_annotations dataset: refcoco splitBy: unc \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco_bbox/refcocog.yaml b/minigpt4/configs/datasets/coco_bbox/refcocog.yaml index bb751cb..39ef142 100755 --- a/minigpt4/configs/datasets/coco_bbox/refcocog.yaml +++ b/minigpt4/configs/datasets/coco_bbox/refcocog.yaml @@ -2,7 +2,7 @@ datasets: refcocog: data_type: images build_info: - image_path: /path/to/coco/images + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO ann_path: /path/to/refcoco_annotations dataset: refcocog splitBy: umd \ No newline at end of file diff --git a/minigpt4/configs/datasets/coco_bbox/refcocop.yaml b/minigpt4/configs/datasets/coco_bbox/refcocop.yaml index 36c574e..ae062d6 100755 --- a/minigpt4/configs/datasets/coco_bbox/refcocop.yaml +++ b/minigpt4/configs/datasets/coco_bbox/refcocop.yaml @@ -2,7 +2,7 @@ datasets: refcocop: data_type: images build_info: - image_path: /path/to/coco/images + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO ann_path: /path/to/refcoco_annotations dataset: refcoco+ splitBy: unc \ No newline at end of file diff --git a/minigpt4/configs/datasets/gqa/balanced_sft_raw.yaml b/minigpt4/configs/datasets/gqa/balanced_sft_raw.yaml new file mode 100644 index 0000000..73c84a9 --- /dev/null +++ b/minigpt4/configs/datasets/gqa/balanced_sft_raw.yaml @@ -0,0 +1,30 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + gqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/train_balanced_questions.json + val: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/testdev_balanced_questions.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/test_balanced_questions.json + images: + storage: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images/ diff --git a/minigpt4/configs/datasets/gqa/balanced_val.yaml b/minigpt4/configs/datasets/gqa/balanced_sft_raw_eval.yaml similarity index 53% rename from minigpt4/configs/datasets/gqa/balanced_val.yaml rename to minigpt4/configs/datasets/gqa/balanced_sft_raw_eval.yaml index f4c8765..2960164 100644 --- a/minigpt4/configs/datasets/gqa/balanced_val.yaml +++ b/minigpt4/configs/datasets/gqa/balanced_sft_raw_eval.yaml @@ -11,11 +11,15 @@ datasets: build_info: # Be careful not to append minus sign (-) before split to avoid itemizing annotations: - train: + val: url: - - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json storage: - - /path/to/gqa/train_balanced_questions.json - + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/testdev_balanced_questions.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/test_balanced_questions.json images: - storage: /path/to/gqa/images + storage: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images/ diff --git a/minigpt4/configs/datasets/gqa/balanced_sft_raw_part.yaml b/minigpt4/configs/datasets/gqa/balanced_sft_raw_part.yaml new file mode 100644 index 0000000..8f912d7 --- /dev/null +++ b/minigpt4/configs/datasets/gqa/balanced_sft_raw_part.yaml @@ -0,0 +1,30 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + gqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + train: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/train_balanced_questions_90k.json + val: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/testdev_balanced_questions.json + test: + url: + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/test_balanced_questions.json + images: + storage: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images/ diff --git a/minigpt4/configs/datasets/llava/conversation.yaml b/minigpt4/configs/datasets/llava/conversation.yaml index 35c327b..1b48d7c 100755 --- a/minigpt4/configs/datasets/llava/conversation.yaml +++ b/minigpt4/configs/datasets/llava/conversation.yaml @@ -3,5 +3,6 @@ datasets: llava_conversation: data_type: images build_info: - image_path: /path/to/coco/images - ann_path: /path/to/llava/conversation_58k.json \ No newline at end of file + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO/train2014 + # ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/multimodal-sft/llava_150k/en/conversation_58k.json + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/multimodal-sft/llava_150k/en/llava_conver_single_turn_257k_clean_v2.json \ No newline at end of file diff --git a/minigpt4/configs/datasets/llava/detail.yaml b/minigpt4/configs/datasets/llava/detail.yaml index 896df39..99b7bef 100755 --- a/minigpt4/configs/datasets/llava/detail.yaml +++ b/minigpt4/configs/datasets/llava/detail.yaml @@ -2,5 +2,5 @@ datasets: llava_detail: data_type: images build_info: - image_path: /path/to/coco/images - ann_path: /path/to/llava/detail_23k.json \ No newline at end of file + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO/train2014 + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/multimodal-sft/llava_150k/en/detail_23k.json \ No newline at end of file diff --git a/minigpt4/configs/datasets/llava/mix.yaml b/minigpt4/configs/datasets/llava/mix.yaml new file mode 100755 index 0000000..e53b488 --- /dev/null +++ b/minigpt4/configs/datasets/llava/mix.yaml @@ -0,0 +1,12 @@ +datasets: + + llava_mix: + data_type: images + build_info: + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/llava_v1_5_mix665k/llava_v1_5_mix665k.json + image_path_coco: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO/train2014 + image_path_gqa: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images + image_path_ocr: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OCRVQA/images + image_path_text: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA/train_images + # image_path_vg: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VG + diff --git a/minigpt4/configs/datasets/llava/pretrain_cap.yaml b/minigpt4/configs/datasets/llava/pretrain_cap.yaml new file mode 100755 index 0000000..34986ec --- /dev/null +++ b/minigpt4/configs/datasets/llava/pretrain_cap.yaml @@ -0,0 +1,6 @@ +datasets: + llava_pretrain: + data_type: images + build_info: + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/llava-cc3m-595k/images + ann_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/llava-cc3m-595k/chat.json \ No newline at end of file diff --git a/minigpt4/configs/datasets/llava/reason.yaml b/minigpt4/configs/datasets/llava/reason.yaml index f5fb674..e12c31b 100755 --- a/minigpt4/configs/datasets/llava/reason.yaml +++ b/minigpt4/configs/datasets/llava/reason.yaml @@ -3,5 +3,5 @@ datasets: llava_reason: data_type: images build_info: - image_path: /path/to/coco/images - ann_path: /path/to/llava/complex_reasoning_77k.json \ No newline at end of file + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO/train2014 + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/multimodal-sft/llava_150k/en/complex_reasoning_77k.json \ No newline at end of file diff --git a/minigpt4/configs/datasets/mix_vqa/mix_vqa.yaml b/minigpt4/configs/datasets/mix_vqa/mix_vqa.yaml new file mode 100755 index 0000000..71e7476 --- /dev/null +++ b/minigpt4/configs/datasets/mix_vqa/mix_vqa.yaml @@ -0,0 +1,10 @@ +datasets: + + llava_mix: + data_type: images + build_info: + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/llava_v1_5_mix665k/mix_coco_gqa_162k.json + image_path_coco: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO/train2014 + image_path_gqa: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images + image_path_ocr: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OCRVQA/images + image_path_text: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA/train_images \ No newline at end of file diff --git a/minigpt4/configs/datasets/multitask_conversation/default.yaml b/minigpt4/configs/datasets/multitask_conversation/default.yaml index 9d5ee72..f7a1e01 100644 --- a/minigpt4/configs/datasets/multitask_conversation/default.yaml +++ b/minigpt4/configs/datasets/multitask_conversation/default.yaml @@ -3,5 +3,5 @@ datasets: data_type: images build_info: - image_path: /path/to/coco/images - ann_path: /path/to/multitask_conversation/multi_task_conversation.json \ No newline at end of file + image_path: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO/train2014 + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/minigptv2_llava_multitask_conv/multitask_conversation.json \ No newline at end of file diff --git a/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml b/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml index 67464e5..3217018 100644 --- a/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml +++ b/minigpt4/configs/datasets/nlp/unnatural_instruction.yaml @@ -2,4 +2,4 @@ datasets: unnatural_instruction: data_type: text build_info: - ann_path: /path/to/unnatural_instructions/filtered_unnatural_instruction.json \ No newline at end of file + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/unnatural_instructions/filtered_unnatural_instruction.json \ No newline at end of file diff --git a/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml b/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml index 0b651dd..153d010 100755 --- a/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml +++ b/minigpt4/configs/datasets/ocrvqa/ocrvqa.yaml @@ -2,5 +2,5 @@ datasets: ocrvqa: data_type: images build_info: - image_path: /path/to/ocrvqa/images - ann_path: /path/to/ocrvqa/dataset.json \ No newline at end of file + image_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OCRVQA/image # 207572 + ann_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OCRVQA/dataset.json \ No newline at end of file diff --git a/minigpt4/configs/datasets/okvqa/defaults.yaml b/minigpt4/configs/datasets/okvqa/defaults.yaml index e536366..d72cf55 100755 --- a/minigpt4/configs/datasets/okvqa/defaults.yaml +++ b/minigpt4/configs/datasets/okvqa/defaults.yaml @@ -15,7 +15,38 @@ datasets: url: # TODO make this order insensitive - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_train.json + # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_train2014_questions.json + # - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_train2014_annotations.json storage: - - /path/to/okvqa/okvqa_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train_part100.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_train2014_questions.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_train2014_annotations.json + val: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval_part100.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json + test: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval_part100.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json images: - storage: /path/to/coco/images \ No newline at end of file + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO \ No newline at end of file diff --git a/minigpt4/configs/datasets/okvqa/eval.yaml b/minigpt4/configs/datasets/okvqa/eval.yaml new file mode 100755 index 0000000..d58c446 --- /dev/null +++ b/minigpt4/configs/datasets/okvqa/eval.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +datasets: + ok_vqa: + # data_dir: ${env.data_dir}/datasets + data_type: images # [images|videos|features] + + build_info: + # Be careful not to append minus sign (-) before split to avoid itemizing + annotations: + val: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json + test: + url: + # TODO make this order insensitive + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_val_eval.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/okvqa_answer_list_train.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/OpenEnded_mscoco_val2014_questions.json + - https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/okvqa/mscoco_val2014_annotations.json + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_train.json + # - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_val_eval_part100.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/okvqa_answer_list_train.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/OpenEnded_mscoco_val2014_questions.json + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/OKVQA/mscoco_val2014_annotations.json + images: + storage: /mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO \ No newline at end of file diff --git a/minigpt4/configs/datasets/textcaps/caption.yaml b/minigpt4/configs/datasets/textcaps/caption.yaml index 9a732b4..010a4cc 100755 --- a/minigpt4/configs/datasets/textcaps/caption.yaml +++ b/minigpt4/configs/datasets/textcaps/caption.yaml @@ -3,7 +3,15 @@ datasets: data_type: images build_info: - image_path: /path/to/textcaps/train_images - ann_path: /path/to/textcaps/TextCaps_0.1_train.json - - + annotations: + train: + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextCap/TextCaps_0.1_train.json + val: + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextCap/TextCaps_0.1_val.json + test: + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextCap/TextCaps_0.1_test.json + images: + storage: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA diff --git a/minigpt4/configs/datasets/textvqa/vqa.yaml b/minigpt4/configs/datasets/textvqa/vqa.yaml new file mode 100755 index 0000000..32f4de3 --- /dev/null +++ b/minigpt4/configs/datasets/textvqa/vqa.yaml @@ -0,0 +1,17 @@ +datasets: + text_vqa: + data_type: images + + build_info: + annotations: + train: + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA/TextVQA_0.5.1_train.json + val: + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA/TextVQA_0.5.1_val.json + test: + storage: + - /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA/TextVQA_0.5.1_test.json + images: + storage: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA diff --git a/minigpt4/configs/models/blip2/blip2_instruct_flant5xxl_prompt_moe.yaml b/minigpt4/configs/models/blip2/blip2_instruct_flant5xxl_prompt_moe.yaml new file mode 100644 index 0000000..b039ad8 --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_instruct_flant5xxl_prompt_moe.yaml @@ -0,0 +1,59 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_finetuned: False + load_pretrained: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moe_position: "pre" # post (position to insert PromptMoE Part) + embed_extract: "blip2_pretrain" # t5, random (way to extract embeddings of task instruction if moe_position is pre) + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + moe_topk: 2 + eval_gate_save: False + train_gate_save: False + gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_train_qf_train_qt_linear_gate_textt5_20ex_3loss_textinqf_epo3_1012/" + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/blip2/blip2_instruct_flant5xxl_qformer_moe.yaml b/minigpt4/configs/models/blip2/blip2_instruct_flant5xxl_qformer_moe.yaml new file mode 100644 index 0000000..9287f53 --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_instruct_flant5xxl_qformer_moe.yaml @@ -0,0 +1,56 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_finetuned: False + load_pretrained: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moebert_expert_num: 5 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.1 + moe_topk: 2 + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/blip2/blip2_instruct_vicuna7b.yaml b/minigpt4/configs/models/blip2/blip2_instruct_vicuna7b.yaml new file mode 100644 index 0000000..106a2c6 --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_instruct_vicuna7b.yaml @@ -0,0 +1,44 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: instruct_vicuna7b + load_finetuned: False + load_pretrained: True + + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/instruct_blip_vicuna7b_trimmed/instruct_blip_vicuna7b_trimmed.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # path to Vicuna checkpoint + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + + # generation configs + prompt: "" + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + diff --git a/minigpt4/configs/models/blip2/blip2_pretrain.yaml b/minigpt4/configs/models/blip2/blip2_pretrain.yaml new file mode 100644 index 0000000..9fb4b0b --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_pretrain.yaml @@ -0,0 +1,42 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: pretrain + load_finetuned: False + + # pretrained: "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_pretrained/blip2_pretrained.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # moe + use_moeqformer: True + use_route_moe: True + moebert_expert_num: 5 + moebert_num_beams: 2 + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b.yaml b/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b.yaml new file mode 100644 index 0000000..3a0cab0 --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b.yaml @@ -0,0 +1,43 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: instruct_vicuna7b + load_finetuned: False + load_pretrained: True + + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + freeze_vit: True + + # Q-Former + num_query_token: 32 + + # path to Vicuna checkpoint + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + + # generation configs + prompt: "" + + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b_route_moe.yaml b/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b_route_moe.yaml new file mode 100644 index 0000000..0f032ff --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b_route_moe.yaml @@ -0,0 +1,59 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + load_finetuned: False + load_pretrained: True + + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # path to Vicuna checkpoint + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'route_moe' + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b_route_moe_universal.yaml b/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b_route_moe_universal.yaml new file mode 100644 index 0000000..19fd437 --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_pretrain_vicuna7b_route_moe_universal.yaml @@ -0,0 +1,59 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + load_finetuned: False + load_pretrained: True + + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # path to Vicuna checkpoint + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'uni_route_moe' + moebert_route_method: "post-route-uni" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "in" + +preprocess: + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" diff --git a/minigpt4/configs/models/blip2/blip2_qformer_moe_post_vicuna7b.yaml b/minigpt4/configs/models/blip2/blip2_qformer_moe_post_vicuna7b.yaml new file mode 100644 index 0000000..bf659f9 --- /dev/null +++ b/minigpt4/configs/models/blip2/blip2_qformer_moe_post_vicuna7b.yaml @@ -0,0 +1,60 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + load_finetuned: False + load_pretrained: True + + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/instruct_blip_vicuna7b_trimmed/instruct_blip_vicuna7b_trimmed.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # path to Vicuna checkpoint + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: "naive_moe" + moebert_expert_num: 5 + moebert_route_method: "gate-sentence-post" + moebert_load_balance: 0 + moe_topk: 1 + use_balance_loss: False + moe_weight_type: 'average' + + +preprocess: + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + diff --git a/minigpt4/configs/models/minigpt4_llama2.yaml b/minigpt4/configs/models/minigpt/minigpt4_llama2.yaml similarity index 100% rename from minigpt4/configs/models/minigpt4_llama2.yaml rename to minigpt4/configs/models/minigpt/minigpt4_llama2.yaml diff --git a/minigpt4/configs/models/minigpt4_vicuna0.yaml b/minigpt4/configs/models/minigpt/minigpt4_vicuna0.yaml similarity index 88% rename from minigpt4/configs/models/minigpt4_vicuna0.yaml rename to minigpt4/configs/models/minigpt/minigpt4_vicuna0.yaml index 718054c..f686185 100644 --- a/minigpt4/configs/models/minigpt4_vicuna0.yaml +++ b/minigpt4/configs/models/minigpt/minigpt4_vicuna0.yaml @@ -15,7 +15,7 @@ model: # generation configs prompt: "" - llama_model: "please set this value to the path of vicuna model" + llama_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" preprocess: vis_processor: diff --git a/minigpt4/configs/models/minigpt_v2.yaml b/minigpt4/configs/models/minigpt/minigpt_v2.yaml similarity index 87% rename from minigpt4/configs/models/minigpt_v2.yaml rename to minigpt4/configs/models/minigpt/minigpt_v2.yaml index 1d85d20..cb62ec1 100755 --- a/minigpt4/configs/models/minigpt_v2.yaml +++ b/minigpt4/configs/models/minigpt/minigpt_v2.yaml @@ -11,7 +11,7 @@ model: # generation configs prompt: "" - llama_model: "please set this value to the path of llama2-chat-7b" + llama_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/llama_2_7b_chat" lora_r: 64 lora_alpha: 16 diff --git a/minigpt4/datasets/builders/base_dataset_builder.py b/minigpt4/datasets/builders/base_dataset_builder.py index 4b607e3..79169a8 100644 --- a/minigpt4/datasets/builders/base_dataset_builder.py +++ b/minigpt4/datasets/builders/base_dataset_builder.py @@ -208,7 +208,8 @@ class BaseDatasetBuilder: ann_paths = abs_ann_paths # visual data storage path - vis_path = os.path.join(vis_info.storage, split) + # vis_path = os.path.join(vis_info.storage, split) + vis_path = os.path.join(vis_info.storage) if not os.path.isabs(vis_path): # vis_path = os.path.join(utils.get_cache_path(), vis_path) @@ -219,12 +220,14 @@ class BaseDatasetBuilder: # create datasets dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls + print(dataset_cls) datasets[split] = dataset_cls( vis_processor=vis_processor, text_processor=text_processor, ann_paths=ann_paths, vis_root=vis_path, ) + print("{} Length {} : {}".format(dataset_cls.__name__, split, len(datasets[split]))) # print class name return datasets diff --git a/minigpt4/datasets/builders/image_text_pair_builder.py b/minigpt4/datasets/builders/image_text_pair_builder.py index fb344f1..aec5170 100644 --- a/minigpt4/datasets/builders/image_text_pair_builder.py +++ b/minigpt4/datasets/builders/image_text_pair_builder.py @@ -6,19 +6,18 @@ from minigpt4.common.registry import registry from minigpt4.datasets.builders.base_dataset_builder import BaseDatasetBuilder from minigpt4.datasets.datasets.laion_dataset import LaionDataset from minigpt4.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset -from minigpt4.datasets.datasets.text_caps import TextCapDataset -from minigpt4.datasets.datasets.llava_dataset import LlavaDetailDataset, LlavaReasonDataset, LlavaConversationDataset +from minigpt4.datasets.datasets.text_caps import TextCapDataset, TextCapEvalDataset +from minigpt4.datasets.datasets.text_vqa_dataset import TextVQADataset, TextVQAEvalDataset +from minigpt4.datasets.datasets.llava_dataset import LlavaDetailDataset, LlavaReasonDataset, LlavaConversationDataset, LlavaMixDataset, LlavaPretrainDataset from minigpt4.datasets.datasets.unnatural_instruction import UnnaturalDataset from minigpt4.datasets.datasets.multitask_conversation import MultiTaskConversationDataset from minigpt4.datasets.datasets.flickr import GroundedDetailDataset,CaptionToObjectDataset,PhraseToObjectDataset -from minigpt4.datasets.datasets.vg_dataset import ReferVisualGenomeDataset -from minigpt4.datasets.datasets.coco_dataset import ReferCOCODataset, InvReferCOCODataset -from minigpt4.datasets.datasets.gqa_datasets import GQADataset -from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset -from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset +from minigpt4.datasets.datasets.gqa_datasets import GQADataset, GQAEvalDataset +from minigpt4.datasets.datasets.aok_vqa_datasets import AOKVQADataset, AOKVQAEvalDataset +from minigpt4.datasets.datasets.coco_vqa_datasets import COCOVQADataset, COCOVQAEvalDataset +from minigpt4.datasets.datasets.ok_vqa_datasets import OKVQADataset, OKVQAEvalDataset from minigpt4.datasets.datasets.ocrvqa_dataset import OCRVQADataset -from minigpt4.datasets.datasets.coco_caption import COCOCapDataset - +from minigpt4.datasets.datasets.coco_caption import COCOCapDataset, COCOCapEvalDataset @registry.register_builder("multitask_conversation") class MultitaskConversationBuilder(BaseDatasetBuilder): @@ -29,7 +28,7 @@ class MultitaskConversationBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[multitask_conversation]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -55,7 +54,7 @@ class UnnaturalInstructionBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[unnatural_instruction]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -66,6 +65,7 @@ class UnnaturalInstructionBuilder(BaseDatasetBuilder): text_processor=self.text_processors["train"], ann_path=build_info.ann_path, ) + print("{} Length: {}".format(dataset_cls.__name__, len(datasets['train']))) # print class name return datasets @@ -80,7 +80,7 @@ class LlavaDetailBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[llava_detail]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -93,11 +93,10 @@ class LlavaDetailBuilder(BaseDatasetBuilder): ann_path=build_info.ann_path, vis_root=build_info.image_path, ) + print("{} Length: {}".format(dataset_cls.__name__, len(datasets['train']))) # print class name return datasets - - @registry.register_builder("llava_reason") class LlavaReasonBuilder(BaseDatasetBuilder): train_dataset_cls = LlavaReasonDataset @@ -107,7 +106,7 @@ class LlavaReasonBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[llava_reason]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -120,9 +119,37 @@ class LlavaReasonBuilder(BaseDatasetBuilder): ann_path=build_info.ann_path, vis_root=build_info.image_path, ) + print("{} Length: {}".format(dataset_cls.__name__, len(datasets['train']))) # print class name return datasets +@registry.register_builder("llava_pretrain") +class LlavaPretrainBuilder(BaseDatasetBuilder): + train_dataset_cls = LlavaPretrainDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/llava/pretrain_cap.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("[llava_pretrain]: Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + vis_root=build_info.image_path, + ) + print("{} Length: {}".format(dataset_cls.__name__, len(datasets['train']))) # print class name + + return datasets + + @registry.register_builder("llava_conversation") class LlavaReasonBuilder(BaseDatasetBuilder): train_dataset_cls = LlavaConversationDataset @@ -132,7 +159,7 @@ class LlavaReasonBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[llava_conversation]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -145,15 +172,58 @@ class LlavaReasonBuilder(BaseDatasetBuilder): ann_path=build_info.ann_path, vis_root=build_info.image_path, ) + print("{} Length: {}".format(dataset_cls.__name__, len(datasets['train']))) # print class name return datasets +@registry.register_builder("llava_mix") +class LlavaMixBuilder(BaseDatasetBuilder): + train_dataset_cls = LlavaMixDataset + DATASET_CONFIG_DICT = { + "default": "configs/datasets/llava/mix.yaml", + "mix_coco_gqa": "configs/datasets/mix_vqa/mix_vqa.yaml", + } + + def build_datasets(self): + # at this point, all the annotations and image/videos should be all downloaded to the specified locations. + logging.info("[llava_mix]: Building datasets...") + self.build_processors() + build_info = self.config.build_info + datasets = dict() + + # create datasets + dataset_cls = self.train_dataset_cls + vis_roots = { + 'coco':build_info.image_path_coco, + 'gqa':build_info.image_path_gqa, + 'ocr':build_info.image_path_ocr, + 'text':build_info.image_path_text, + # 'vg':build_info.image_path_vg, + } + datasets['train'] = dataset_cls( + vis_processor=self.vis_processors["train"], + text_processor=self.text_processors["train"], + ann_path=build_info.ann_path, + vis_root=vis_roots, + ) + print("{} Length: {}".format(dataset_cls.__name__, len(datasets['train']))) # print class name + + # vis_roots = { + # 'coco':'/mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO/train2014', + # 'gqa':'/mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images', + # 'ocr':'/mnt/pfs-guan-ssai/nlu/wanghanzi/data/OCRVQA/images', + # 'text':'/mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextVQA/train_images', + # # 'vg':build_info.image_path_vg, + # } + + return datasets + class AllRefCOCOBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[AllRefCOCOBuilder]: Building datasets...") self.build_processors() build_info = self.config.build_info @@ -181,81 +251,10 @@ class AllRefCOCOBuilder(BaseDatasetBuilder): return datasets -@registry.register_builder("refcoco") -class RefCOCOBuilder(AllRefCOCOBuilder): - train_dataset_cls = ReferCOCODataset - DATASET_CONFIG_DICT = { - "default": "configs/datasets/coco_bbox/refcoco.yaml", - } - -@registry.register_builder("refcocop") -class RefCOCOPBuilder(AllRefCOCOBuilder): - train_dataset_cls = ReferCOCODataset - DATASET_CONFIG_DICT = { - "default": "configs/datasets/coco_bbox/refcocop.yaml", - } - - -@registry.register_builder("refcocog") -class RefCOCOGBuilder(AllRefCOCOBuilder): - train_dataset_cls = ReferCOCODataset - DATASET_CONFIG_DICT = { - "default": "configs/datasets/coco_bbox/refcocog.yaml", - } - -@registry.register_builder("invrefcoco") -class RefCOCOBuilder(AllRefCOCOBuilder): - train_dataset_cls = InvReferCOCODataset - DATASET_CONFIG_DICT = { - "default": "configs/datasets/coco_bbox/invrefcoco.yaml", - } - - -@registry.register_builder("invrefcocop") -class RefCOCOPBuilder(AllRefCOCOBuilder): - train_dataset_cls = InvReferCOCODataset - DATASET_CONFIG_DICT = { - "default": "configs/datasets/coco_bbox/invrefcocop.yaml", - } - - -@registry.register_builder("invrefcocog") -class RefCOCOGBuilder(AllRefCOCOBuilder): - train_dataset_cls = InvReferCOCODataset - DATASET_CONFIG_DICT = { - "default": "configs/datasets/coco_bbox/invrefcocog.yaml", - } - -@registry.register_builder("refvg") -class RefVisualGenomeBuilder(BaseDatasetBuilder): - train_dataset_cls = ReferVisualGenomeDataset - DATASET_CONFIG_DICT = { - "default": "configs/datasets/vg/ref.yaml", - } - - def build_datasets(self): - # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") - self.build_processors() - - build_info = self.config.build_info - data_dir = build_info.data_dir - datasets = dict() - - # create datasets - dataset_cls = self.train_dataset_cls - datasets['train'] = dataset_cls( - vis_processor=self.vis_processors["train"], - text_processor=self.text_processors["train"], - data_dir=data_dir, - ) - - return datasets - - @registry.register_builder("textcaps_caption") class TextcapCaptionBuilder(BaseDatasetBuilder): train_dataset_cls = TextCapDataset + eval_dataset_cls = TextCapEvalDataset DATASET_CONFIG_DICT = {"default": "configs/datasets/textcaps/caption.yaml"} @@ -265,44 +264,45 @@ class TextcapCaptionBuilder(BaseDatasetBuilder): def _download_vis(self): pass - def build(self): - self.build_processors() +@registry.register_builder("text_vqa") +class TextVQABuilder(BaseDatasetBuilder): + train_dataset_cls = TextVQADataset + eval_dataset_cls = TextVQAEvalDataset - build_info = self.config.build_info - - datasets = dict() - split = "train" - - # create datasets - # [NOTE] return inner_datasets (wds.DataPipeline) - dataset_cls = self.train_dataset_cls - datasets[split] = dataset_cls( - vis_processor=self.vis_processors[split], - text_processor=self.text_processors[split], - ann_path=build_info.ann_path, - vis_root=build_info.image_path, - ) - - return datasets + DATASET_CONFIG_DICT = {"default": "configs/datasets/textvqa/vqa.yaml"} + def _download_ann(self): + pass + + def _download_vis(self): + pass + @registry.register_builder("coco_vqa") class COCOVQABuilder(BaseDatasetBuilder): train_dataset_cls = COCOVQADataset + eval_dataset_cls = COCOVQAEvalDataset DATASET_CONFIG_DICT = { "default": "configs/datasets/coco/defaults_vqa.yaml", + "vqa_v2_eval": "configs/datasets/coco/defaults_vqa_eval.yaml", + "vqa_v2_part": "configs/datasets/coco/defaults_vqa_part.yaml", } @registry.register_builder("ok_vqa") class OKVQABuilder(COCOVQABuilder): + train_dataset_cls = OKVQADataset + eval_dataset_cls = OKVQAEvalDataset + DATASET_CONFIG_DICT = { "default": "configs/datasets/okvqa/defaults.yaml", + "ok_vqa_eval": "configs/datasets/okvqa/eval.yaml", } @registry.register_builder("aok_vqa") class AOKVQABuilder(BaseDatasetBuilder): train_dataset_cls = AOKVQADataset + eval_dataset_cls = AOKVQAEvalDataset DATASET_CONFIG_DICT = {"default": "configs/datasets/aokvqa/defaults.yaml"} @@ -310,13 +310,15 @@ class AOKVQABuilder(BaseDatasetBuilder): @registry.register_builder("gqa") class GQABuilder(BaseDatasetBuilder): train_dataset_cls = GQADataset + eval_dataset_cls = GQAEvalDataset + DATASET_CONFIG_DICT = { - "default": "configs/datasets/gqa/balanced_val.yaml", + "balanced_sft_raw": "configs/datasets/gqa/balanced_sft_raw.yaml", + "balanced_sft_raw_eval":"configs/datasets/gqa/balanced_sft_raw_eval.yaml", + "balanced_sft_raw_part":"configs/datasets/gqa/balanced_sft_raw_part.yaml", } - - @registry.register_builder("flickr_grounded_caption") class GroundedCaptionBuilder(BaseDatasetBuilder): train_dataset_cls = GroundedDetailDataset @@ -326,7 +328,7 @@ class GroundedCaptionBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[flickr_grounded_caption]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -352,7 +354,7 @@ class CaptionToPhraseBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[flickr_CaptionToPhrase]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -377,7 +379,7 @@ class CaptionToPhraseBuilder(BaseDatasetBuilder): def build_datasets(self): # at this point, all the annotations and image/videos should be all downloaded to the specified locations. - logging.info("Building datasets...") + logging.info("[flickr_ObjectToPhrase]: Building datasets...") self.build_processors() build_info = self.config.build_info datasets = dict() @@ -394,8 +396,6 @@ class CaptionToPhraseBuilder(BaseDatasetBuilder): return datasets - - class DocumentVQABuilder(BaseDatasetBuilder): def _download_ann(self): pass @@ -417,7 +417,8 @@ class DocumentVQABuilder(BaseDatasetBuilder): vis_root=build_info.image_path, ann_path=build_info.ann_path ) - + print("{} Length: {}".format(dataset_cls.__name__, len(datasets['train']))) # print class name + return datasets @@ -495,9 +496,11 @@ class LaionBuilder(BaseDatasetBuilder): @registry.register_builder("coco_caption") class COCOCapBuilder(BaseDatasetBuilder): train_dataset_cls = COCOCapDataset - + eval_dataset_cls = COCOCapEvalDataset + DATASET_CONFIG_DICT = { "default": "configs/datasets/coco/caption.yaml", + "coco_cap_eval": "configs/datasets/coco/caption_eval.yaml", } diff --git a/minigpt4/datasets/datasets/aok_vqa_datasets.py b/minigpt4/datasets/datasets/aok_vqa_datasets.py index 00ed06d..08f06d9 100755 --- a/minigpt4/datasets/datasets/aok_vqa_datasets.py +++ b/minigpt4/datasets/datasets/aok_vqa_datasets.py @@ -13,7 +13,7 @@ import torch from PIL import Image -from minigpt4.datasets.datasets.vqa_datasets import VQADataset #, VQAEvalDataset +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset class __DisplMixin: @@ -36,81 +36,192 @@ class AOKVQADataset(VQADataset, __DisplMixin): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): super().__init__(vis_processor, text_processor, vis_root, ann_paths) - self.instruction_pool =[ - "[vqa] {}", - "[vqa] Based on the image, respond to this question with a short answer: {}" + self.instruction_pool =[ + '{} Choose from {}.', + 'Q: {} Multi Choices: {} A: ', + 'Question: {} Multi Choices: {} Answer: ', + "{} Choose one from the following possible answers: {}. ", + '{} Choose from {}. The answer is', ] exist_annotation = [] for ann in self.annotation: - image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + # image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image_path = os.path.join(self.vis_root, ann["image"]) if os.path.exists(image_path): exist_annotation.append(ann) self.annotation = exist_annotation + self.source = 'aokvqa' def get_data(self, index): ann = self.annotation[index] - image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + # image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image_path = os.path.join(self.vis_root, ann["image"]) image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) question = self.text_processor(ann["question"]) - answer_key = "direct_answers" - - answer_weight = {} - for answer in ann[answer_key]: - if answer in answer_weight.keys(): - answer_weight[answer] += 1 / len(ann[answer_key]) - else: - answer_weight[answer] = 1 / len(ann[answer_key]) - - answers = list(answer_weight.keys()) - weights = list(answer_weight.values()) - - answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + answer_lst = ann["choices"] + direct_answers = ann["direct_answers"] + final_answer = random.choices(direct_answers, k=1)[0] + for answer in answer_lst: + if answer in direct_answers: + final_answer = answer return { "image": image, + "image_id": ann["image"], "question": question, - "answer": answer, + "answer": final_answer, + "choices": ", ".join(answer_lst) } def __getitem__(self, index): data = self.get_data(index) question = self.text_processor(data["question"]) - instruction = random.choice(self.instruction_pool).format(question) - instruction = " {} ".format(instruction) answer = self.text_processor(data['answer']) + q_input = question + llm_input = random.choice(self.instruction_pool).format(question, data["choices"]) return { "image": data['image'], - "instruction_input": instruction, + "image_id": data["image_id"], + # "q_input": q_input, + "q_input": llm_input, + "llm_input": llm_input, + "text_input": question, + "text_output": answer, "answer": answer, + "source": 'aokvqa', } -class AOKVQGDataset(AOKVQADataset): - +class AOKVQAEvalDataset(VQAEvalDataset, __DisplMixin): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): - super().__init__(vis_processor, text_processor, vis_root, ann_paths) - self.instruction_pool = [ - 'Given the image, generate a question whose answer is: {}', - 'Based on the image, provide a question with the answer: {}', - 'Given the visual representation, create a question for which the answer is "{}"', - 'From the image provided, craft a question that leads to the reply: {}', - 'Considering the picture, come up with a question where the answer is: {}', - 'Taking the image into account, generate an question that has the answer: {}' - ] + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ - def __getitem__(self, index): - data = self.get_data(index) - instruction = random.choice(self.instruction_pool).format(data['answer']) + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + + self.instruction_pool =[ + '{} Choose from {}.', + 'Q: {} Multi Choices: {} A: ', + 'Question: {} Multi Choices: {} Answer: ', + "{} Choose one from the following possible answers: {}. ", + '{} Choose from {}. The answer is', + ] + + try: + self.coco_fmt_qust_file = ann_paths[2] + self.coco_fmt_anno_file = ann_paths[3] + except IndexError: + self.coco_fmt_qust_file = None + self.coco_fmt_anno_file = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.source = 'aokvqa' + self.annotation_add = self.get_data() + + def collater(self, samples): + ( + image_list, + question_list, + question_id_list, + choices_list, + correct_choice_idx_list, + direct_answers_list, + llm_input_list, + q_input_list, + gt_answers_list, + source_list, + ) = ([], [], [], [], [], [], [], [], [], []) + + for sample in samples: + image_list.append(sample["image"]) + question_list.append(sample["text_input"]) + question_id_list.append(sample["question_id"]) + choices_list.append(sample["choices"]) + correct_choice_idx_list.append(sample["correct_choice_idx"]) + direct_answers_list.append(sample["direct_answers"]) + llm_input_list.append(sample["llm_input"]) + q_input_list.append(sample["q_input"]) + gt_answers_list.append(sample["gt_answers"]) + source_list.append(sample["source"]) return { - "image": data['image'], - "instruction_input": instruction, - "answer": data['question'], + "image": torch.stack(image_list, dim=0), + "text_input": question_list, + "question_id": question_id_list, + "choices": choices_list, + "correct_choice_idx": correct_choice_idx_list, + "direct_answers": direct_answers_list, + "llm_input": llm_input_list, + "q_input": llm_input_list, + # "q_input": q_input_list, + "gt_answers": gt_answers_list, + "source": source_list, } + + def get_data(self): + import numpy as np + ann_instruct = list() + for i in range(len(self.annotation)): + ann = self.annotation[i].copy() + j = i % len(self.instruction_pool) + question = self.text_processor(ann["question"]) + choices = ann["choices"] + llm_input = self.instruction_pool[j].format(question, ", ".join(choices)) + ann['llm_input'] = llm_input + ann_instruct.append(ann) + np.random.seed(10) + np.random.shuffle(ann_instruct) + return ann_instruct + + def __getitem__(self, index): + # ann = self.annotation[index] + ann = self.annotation_add[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + if "direct_answers" in ann: + direct_answers = ann["direct_answers"] + else: + direct_answers = None + + choices = ann["choices"] + if "correct_choice_idx" in ann: + correct_choice_idx = ann["correct_choice_idx"] + correct_answer = choices[correct_choice_idx] + else: + correct_choice_idx = None + correct_answer = direct_answers + + llm_input = ann.get("llm_input",random.choice(self.instruction_pool).format(question)) + # llm_input = random.choice(self.instruction_pool).format(question, ", ".join(choices)) + + return { + "image": image, + # "q_input": question, + "q_input": llm_input, + "llm_input": llm_input, + "text_input": question, + "question_id": ann["question_id"], + "choices": choices, + "correct_choice_idx": correct_choice_idx, + "gt_answers": correct_answer, + "direct_answers": direct_answers, + "source": 'aokvqa', + } + \ No newline at end of file diff --git a/minigpt4/datasets/datasets/base_dataset.py b/minigpt4/datasets/datasets/base_dataset.py index 97aed82..95ba76f 100644 --- a/minigpt4/datasets/datasets/base_dataset.py +++ b/minigpt4/datasets/datasets/base_dataset.py @@ -30,7 +30,10 @@ class BaseDataset(Dataset): # print("ann_path", ann_path) ann = json.load(open(ann_path, "r")) if isinstance(ann, dict): - self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) + if 'annotations' in ann.keys(): + self.annotation.extend(json.load(open(ann_path, "r"))['annotations']) + elif 'data' in ann.keys(): + self.annotation.extend(json.load(open(ann_path, "r"))['data']) # self.annotation.extend(json.load(open(ann_path, "r"))) else: self.annotation.extend(json.load(open(ann_path, "r"))) diff --git a/minigpt4/datasets/datasets/caption_datasets.py b/minigpt4/datasets/datasets/caption_datasets.py index 6a432a7..9354adc 100644 --- a/minigpt4/datasets/datasets/caption_datasets.py +++ b/minigpt4/datasets/datasets/caption_datasets.py @@ -59,73 +59,7 @@ class CaptionDataset(BaseDataset, __DisplMixin): "text_input": caption, "image_id": self.img_ids[ann["image_id"]], } - - - -class COCOCaptionDataset(BaseDataset, __DisplMixin): - def __init__(self, vis_processor, text_processor, vis_root, ann_paths): - """ - vis_root (string): Root directory of images (e.g. coco/images/) - ann_root (string): directory to store the annotation file - """ - super().__init__(vis_processor, text_processor, vis_root, ann_paths) - - self.img_ids = {} - n = 0 - - self.filter_anntation = [] - - for ann in self.annotation: - if "train" in ann["image"]: - self.filter_anntation.append(ann) - self.annotation = self.filter_anntation - - for ann in self.annotation: - img_id = ann["image_id"] - if img_id not in self.img_ids.keys(): - self.img_ids[img_id] = n - n += 1 - - self.instruction_pool = [ - 'Briefly describe this image.', - 'Provide a concise depiction of this image.', - 'Present a short description of this image.', - 'Summarize this image in a few words.', - 'A short image caption:', - 'A short image description:', - 'A photo of ', - 'An image that shows ', - 'Write a short description for the image. ', - 'Write a description for the photo.', - 'Provide a description of what is presented in the photo.', - 'Briefly describe the content of the image.', - 'Can you briefly explain what you see in the image?', - 'Could you use a few words to describe what you perceive in the photo?', - 'Please provide a short depiction of the picture.', - 'Using language, provide a short account of the image.', - 'Use a few words to illustrate what is happening in the picture.', - ] - def __getitem__(self, index): - - # TODO this assumes image input, not general enough - ann = self.annotation[index] - - img_file = ann["image"].split("/")[-1] - image_path = os.path.join(self.vis_root, img_file) - image = Image.open(image_path).convert("RGB") - - image = self.vis_processor(image) - caption = self.text_processor(ann["caption"]) - - instruction = random.choice(self.instruction_pool) - instruction = " [caption] {} ".format(instruction) - - return { - "image": image, - "answer": caption, - "instruction_input": instruction, - } - + class CaptionEvalDataset(BaseDataset, __DisplMixin): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): """ @@ -141,7 +75,7 @@ class CaptionEvalDataset(BaseDataset, __DisplMixin): image_path = os.path.join(self.vis_root, ann["image"]) image = Image.open(image_path).convert("RGB") - + image = self.vis_processor(image) return { @@ -149,3 +83,4 @@ class CaptionEvalDataset(BaseDataset, __DisplMixin): "image_id": ann["image_id"], "instance_id": ann["instance_id"], } + diff --git a/minigpt4/datasets/datasets/coco_caption.py b/minigpt4/datasets/datasets/coco_caption.py index 76f86e4..c097034 100755 --- a/minigpt4/datasets/datasets/coco_caption.py +++ b/minigpt4/datasets/datasets/coco_caption.py @@ -9,18 +9,102 @@ import os import json import torch import numpy as np +import random from PIL import Image from PIL import ImageFile +from collections import OrderedDict ImageFile.LOAD_TRUNCATED_IMAGES = True -from minigpt4.datasets.datasets.caption_datasets import COCOCaptionDataset, CaptionEvalDataset +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset -COCOCapDataset = COCOCaptionDataset +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + return OrderedDict( + { + "file": ann["image"], + "caption": ann["caption"], + "image": sample["image"], + } + ) + +class COCOCapDataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.img_ids = {} + n = 0 + self.filter_anntation = [] + + for ann in self.annotation: + if "train" in ann["image"]: + self.filter_anntation.append(ann) + self.annotation = self.filter_anntation + + for ann in self.annotation: + img_id = ann["image_id"] + if img_id not in self.img_ids.keys(): + self.img_ids[img_id] = n + n += 1 + + self.instruction_pool = [ + 'Briefly describe this image.', + 'Provide a concise depiction of this image.', + 'Present a short description of this image.', + 'Summarize this image in a few words.', + 'A short image caption:', + 'A short image description:', + 'A photo of ', + 'An image that shows ', + 'Write a short description for the image. ', + 'Write a description for the photo.', + 'Provide a description of what is presented in the photo.', + 'Briefly describe the content of the image.', + 'Can you briefly explain what you see in the image?', + 'Could you use a few words to describe what you perceive in the photo?', + 'Please provide a short depiction of the picture.', + 'Using language, provide a short account of the image.', + 'Use a few words to illustrate what is happening in the picture.', + ] + self.source = 'coco_cap' + + def __getitem__(self, index): + + # TODO this assumes image input, not general enough + ann = self.annotation[index] + + # img_file = ann["image"].split("/")[-1] + img_file = ann["image"] + image_path = os.path.join(self.vis_root, img_file) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + caption = self.text_processor(ann["caption"]) + + instruction = random.choice(self.instruction_pool) + # q_input = "" + q_input = instruction + llm_input = instruction + + return { + "image": image, + "image_id": ann["image"], + "answer": caption, + "q_input": q_input, + "llm_input": llm_input, + "text_input": llm_input, + "text_output": caption, + "source": 'coco_cap', + } class COCOCapEvalDataset(CaptionEvalDataset): @@ -32,20 +116,51 @@ class COCOCapEvalDataset(CaptionEvalDataset): """ super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool = [ + 'Briefly describe this image.', + 'Provide a concise depiction of this image.', + 'Present a short description of this image.', + 'Summarize this image in a few words.', + 'A short image caption:', + 'A short image description:', + 'A photo of ', + 'An image that shows ', + 'Write a short description for the image. ', + 'Write a description for the photo.', + 'Provide a description of what is presented in the photo.', + 'Briefly describe the content of the image.', + 'Can you briefly explain what you see in the image?', + 'Could you use a few words to describe what you perceive in the photo?', + 'Please provide a short depiction of the picture.', + 'Using language, provide a short account of the image.', + 'Use a few words to illustrate what is happening in the picture.', + ] + self.source = 'coco_cap' + def __getitem__(self, index): ann = self.annotation[index] image_path = os.path.join(self.vis_root, ann["image"]) image = Image.open(image_path).convert("RGB") - - image = self.vis_processor(image) + try: + image = self.vis_processor(image) + except Exception as e: + print(e) + print(image_path) img_id = ann["image"].split("/")[-1].strip(".jpg").split("_")[-1] + instruction = random.choice(self.instruction_pool) + # q_input = "" + q_input = instruction + llm_input = instruction return { "image": image, "image_id": img_id, - "instance_id": ann["instance_id"], + "text_input":llm_input, + "q_input": q_input, + "llm_input": llm_input, + "source": self.source, } @@ -75,25 +190,6 @@ class NoCapsEvalDataset(CaptionEvalDataset): } -class RefCOCOEvalData(torch.utils.data.Dataset): - def __init__(self, loaded_data, vis_processor, root_path): - self.loaded_data = loaded_data - self.root_path = root_path - self.vis_processor = vis_processor - - def __len__(self): - return len(self.loaded_data) - - def __getitem__(self, idx): - data = self.loaded_data[idx] - img_id = data['img_id'] - sent = data['sents'] - image_path = os.path.join(self.root_path, f'{img_id[:27]}.jpg') - image = Image.open(image_path).convert('RGB') - image = self.vis_processor(image) - question = f"[refer] tell me the location of {sent}?" - return image, question, img_id - class EvalCaptionData(torch.utils.data.Dataset): def __init__(self, loaded_data, vis_processor, root_path): self.loaded_data = loaded_data diff --git a/minigpt4/datasets/datasets/coco_vqa_datasets.py b/minigpt4/datasets/datasets/coco_vqa_datasets.py index 2dbe056..53fd5d3 100755 --- a/minigpt4/datasets/datasets/coco_vqa_datasets.py +++ b/minigpt4/datasets/datasets/coco_vqa_datasets.py @@ -8,7 +8,7 @@ import os import json import random - +import numpy as np from PIL import Image from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset @@ -35,23 +35,28 @@ class COCOVQADataset(VQADataset, __DisplMixin): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): super().__init__(vis_processor, text_processor, vis_root, ann_paths) - self.instruction_pool =[ - "[vqa] {}", - "[vqa] Based on the image, respond to this question with a short answer: {}" + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', ] - exist_annotation = [] for ann in self.annotation: - image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + # image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image_path = os.path.join(self.vis_root, ann["image"]) if os.path.exists(image_path): exist_annotation.append(ann) self.annotation = exist_annotation + self.source = 'vqav2' def get_data(self, index): ann = self.annotation[index] - image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + # image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image_path = os.path.join(self.vis_root, ann["image"]) image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) @@ -70,9 +75,9 @@ class COCOVQADataset(VQADataset, __DisplMixin): answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights - return { "image": image, + "image_id": ann["image"], "question": question, "question_id": question_id, "answer": answer, @@ -80,14 +85,24 @@ class COCOVQADataset(VQADataset, __DisplMixin): def __getitem__(self, index): data = self.get_data(index) - instruction = random.choice(self.instruction_pool).format(data['question']) - instruction = " {} ".format(instruction) + question = data['question'] + # instruction = random.choice(self.instruction_pool).format(question) + # instruction = " {} ".format(instruction) + answer = self.text_processor(data['answer']) + q_input = question + llm_input = random.choice(self.instruction_pool).format(question) return { "image": data['image'], + "image_id": data["image_id"], "question_id": data["question_id"], - "instruction_input": instruction, - "answer": self.text_processor(data['answer']), + # "q_input": q_input, + "q_input": llm_input, + "llm_input": llm_input, + "text_input": question, + "text_output": answer, + "answer": answer, + "source": 'vqav2', } @@ -98,12 +113,22 @@ class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin): ann_root (string): directory to store the annotation file """ - self.instruction_pool = [ + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', 'Question: {} Short answer:', ] self.vis_root = vis_root self.annotation = json.load(open(ann_paths[0])) + exist_annotation = [] + for ann in self.annotation: + image_path = os.path.join(self.vis_root, ann["image"]) + if os.path.exists(image_path): + exist_annotation.append(ann) + self.annotation = exist_annotation answer_list_path = ann_paths[1] if os.path.exists(answer_list_path): @@ -121,25 +146,45 @@ class COCOVQAEvalDataset(VQAEvalDataset, __DisplMixin): self.vis_processor = vis_processor self.text_processor = text_processor + self.source = 'vqav2' + self.annotation_add = self.get_data() self._add_instance_ids() + def get_data(self): + ann_instruct = list() + for i in range(len(self.annotation)): + ann = self.annotation[i].copy() + j = i % len(self.instruction_pool) + question = self.text_processor(ann["question"]) + llm_input = self.instruction_pool[j].format(question) + ann['llm_input'] = llm_input + ann_instruct.append(ann) + np.random.seed(10) + np.random.shuffle(ann_instruct) + return ann_instruct + def __getitem__(self, index): - ann = self.annotation[index] + ann = self.annotation_add[index] image_path = os.path.join(self.vis_root, ann["image"]) image = Image.open(image_path).convert("RGB") - image = self.vis_processor(image) + question = self.text_processor(ann["question"]) - - instruction = random.choice(self.instruction_pool).format(question) - instruction = " {} ".format(instruction) - + q_input = question + llm_input = ann.get("llm_input",random.choice(self.instruction_pool).format(question)) + return { "image": image, + "image_id": ann["image"], 'image_path': image_path, - "question": question, "question_id": ann["question_id"], - "instruction_input": instruction, - "instance_id": ann["instance_id"], + # "instance_id": ann["instance_id"], + # "question": question, + "q_input": llm_input, + "q_input": q_input, + "llm_input": llm_input, + "text_input": question, + # "answer": ann["answer"], + "source": 'vqav2', } diff --git a/minigpt4/datasets/datasets/dataloader_utils.py b/minigpt4/datasets/datasets/dataloader_utils.py index 8eaa3a5..08f64da 100644 --- a/minigpt4/datasets/datasets/dataloader_utils.py +++ b/minigpt4/datasets/datasets/dataloader_utils.py @@ -23,10 +23,10 @@ class MultiIterLoader: def __init__(self, loaders, ratios=None): # assert all loaders has __next__ method - for loader in loaders: - assert hasattr( - loader, "__next__" - ), "Loader {} has no __next__ method.".format(loader) + # for loader in loaders: + # assert hasattr( + # loader, "__next__" + # ), "Loader {} has no __next__ method.".format(loader) if ratios is None: ratios = [1.0] * len(loaders) @@ -42,6 +42,9 @@ class MultiIterLoader: loader_idx = random.choices(range(len(self.loaders)), self.ratios, k=1)[0] return next(self.loaders[loader_idx]) + def __len__(self): + return sum([len(loader) for loader in self.loaders]) + class PrefetchLoader(object): """ diff --git a/minigpt4/datasets/datasets/gqa_datasets.py b/minigpt4/datasets/datasets/gqa_datasets.py index b5e835a..22e98c9 100755 --- a/minigpt4/datasets/datasets/gqa_datasets.py +++ b/minigpt4/datasets/datasets/gqa_datasets.py @@ -10,10 +10,11 @@ import json from PIL import Image -from minigpt4.datasets.datasets.vqa_datasets import VQADataset +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset from collections import OrderedDict import random +import numpy as np class __DisplMixin: def displ_item(self, index): @@ -33,10 +34,23 @@ class __DisplMixin: class GQADataset(VQADataset, __DisplMixin): def __init__(self, vis_processor, text_processor, vis_root, ann_paths): super().__init__(vis_processor, text_processor, vis_root, ann_paths) - self.instruction_pool =[ - "[vqa] {}", - "[vqa] Based on the image, respond to this question with a short answer: {}" + # self.instruction_pool =[ + # "[vqa] {}", + # "[vqa] Based on the image, respond to this question with a short answer: {}" + # ] + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + # 'Question: {}', + # 'Answer the question: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', + # 'Question: {} Answer: ', + # 'Based on the image, respond to this question with a short answer: {}', ] + self.source = 'gqa' + def __getitem__(self, index): ann = self.annotation[index] @@ -47,14 +61,99 @@ class GQADataset(VQADataset, __DisplMixin): image = self.vis_processor(image) question = self.text_processor(ann["question"]) - instruction = random.choice(self.instruction_pool).format(question) - instruction = " {} ".format(instruction) - + # instruction = random.choice(self.instruction_pool).format(question) + # instruction = " {} ".format(instruction) + answers = self.text_processor(ann["answer"]) + q_input = question + llm_input = random.choice(self.instruction_pool).format(question) + + return { + "image": image, + 'image_id': ann["image"], + "text_input": question, + # "text_output": ann["fullAnswer"], + "text_output": answers, + # "instruction_input": instruction, + "q_input": llm_input, + # "q_input": q_input, + "llm_input": llm_input, + "gt_answers": answers, + "source": "gqa", + } + +class GQAEvalDataset(VQAEvalDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. gqa/images/) + ann_root (string): directory to store the annotation file + """ + + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + + ## TODO: support inference method == 'ranking' + answer_list_path = ann_paths[1] if len(ann_paths) > 1 else '' + if os.path.exists(answer_list_path): + self.answer_list = json.load(open(answer_list_path)) + else: + self.answer_list = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', + ] + + self.annotation_add = self.get_data() + self.source = 'gqa' + self._add_instance_ids() + + def get_data(self): + ann_instruct = list() + for i in range(len(self.annotation)): + ann = self.annotation[i].copy() + j = i % len(self.instruction_pool) + question = self.text_processor(ann["question"]) + llm_input = self.instruction_pool[j].format(question) + ann['llm_input'] = llm_input + ann_instruct.append(ann) + np.random.seed(10) + np.random.shuffle(ann_instruct) + return ann_instruct + + + def __getitem__(self, index): + ann = self.annotation_add[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + answer = ann.get("answer", "") + fullAnswer = ann.get("fullAnswer","") + llm_input = ann.get("llm_input",random.choice(self.instruction_pool).format(question)) + q_input = question return { "image": image, - "instruction_input": instruction, - "answer": answers, - } - + "image_id": ann["image"], + "text_input": question, + "gt_answers": answer, + "fullAnswer": fullAnswer, + "text_output": answer, + # "q_input": q_input, + "q_input": llm_input, + "llm_input": llm_input, + "question_id": ann["question_id"], + # "instance_id": ann["instance_id"], + "source": "gqa", + } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/llava_dataset.py b/minigpt4/datasets/datasets/llava_dataset.py index 2766189..1803959 100755 --- a/minigpt4/datasets/datasets/llava_dataset.py +++ b/minigpt4/datasets/datasets/llava_dataset.py @@ -26,6 +26,7 @@ class LlavaDetailDataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor + self.source = 'llava_detail' with open(ann_path, 'r') as f: self.ann = json.load(f) @@ -44,15 +45,19 @@ class LlavaDetailDataset(Dataset): answer = info['conversations'][1]['value'] instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip() - instruction = ' {} '.format(self.text_processor(instruction)) + question = self.text_processor(instruction) return { "image": image, - "instruction_input": instruction, - "answer": answer, "image_id": info['id'], + "q_input": question, + "llm_input": question, + "text_input": question, + "text_output": answer, + "answer": answer, + "source": 'llava_detail', } - + class LlavaReasonDataset(Dataset): def __init__(self, vis_processor, text_processor, vis_root, ann_path): """ @@ -66,6 +71,7 @@ class LlavaReasonDataset(Dataset): with open(ann_path, 'r') as f: self.ann = json.load(f) + self.source = 'llava_reason' def __len__(self): return len(self.ann) @@ -81,18 +87,19 @@ class LlavaReasonDataset(Dataset): answer = info['conversations'][1]['value'] instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip() - instruction = ' {} '.format(self.text_processor(instruction)) + question = self.text_processor(instruction) return { "image": image, - "instruction_input": instruction, - "answer": answer, "image_id": info['id'], + "q_input": question, + "llm_input": question, + "text_input": question, + "text_output": answer, + "answer": answer, + "source": 'llava_reason', } - - - - + class LlavaConversationDataset(Dataset): def __init__(self, vis_processor, text_processor, vis_root, ann_path): """ @@ -103,28 +110,126 @@ class LlavaConversationDataset(Dataset): self.vis_processor = vis_processor self.text_processor = text_processor + self.source = 'llava_conver' self.ann=[] - with open(ann_path, 'r') as f: self.ann = json.load(f) - self.connect_sym = "!@#" - def __len__(self): return len(self.ann) def __getitem__(self, index): info = self.ann[index] - - image_file = 'COCO_train2014_{}.jpg'.format(info['id']) + + image_file = info['image'].split('/')[-1] image_path = os.path.join(self.vis_root, image_file) image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) + question = self.text_processor(info['text_input']) + answer = self.text_processor(info['text_output']) + + return { + "image": image, + "image_id": info['id'], + "q_input": question, + "llm_input": question, + "text_input": question, + "text_output": answer, + "answer": answer, + "source": 'llava_conver', + } + +class LlavaPretrainDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + self.source = 'llava_pretrain' + + self.ann=[] + + with open(ann_path, 'r') as f: + self.ann = json.load(f) + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + info = self.ann[index] + + image_path = os.path.join(self.vis_root, info['image']) + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + + instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip() + answer = info['conversations'][1]['value'] + + instruction = self.text_processor(instruction) + answer = self.text_processor(answer) + + return { + "image": image, + "image_id": info['id'], + "q_input": instruction, + "llm_input": instruction, + "text_input": instruction, + "text_output": answer, + "answer": answer, + "source": 'llava_pretrain', + } + + +class LlavaMixDataset(Dataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_path): + """ + vis_root (string): Root directory of images (e.g. coco/images/) init with build_datasets() in data_builder + ann_root (string): directory to store the annotation file + """ + self.vis_root = vis_root + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.tmp_ann=[] + + with open(ann_path, 'r') as f: + self.tmp_ann = json.load(f) + + self.ann = self.filter_process_data() + self.source = 'llava_mix' + + def build_image_path(self, info): + if 'image' not in info.keys(): # pure text data + return None + + image_name = info['image'] + if 'coco' in image_name: + image_file = 'COCO_train2014_{}.jpg'.format(info['id']) + image_path = os.path.join(self.vis_root['coco'], image_file) + elif 'gqa' in image_name: + image_path = os.path.join(self.vis_root['gqa'], '{}.jpg'.format(info['id'])) + elif 'ocr' in image_name: + image_path = os.path.join(self.vis_root['ocr'], '{}.jpg'.format(info['id'])) + elif 'textvqa' in image_name: + image_path = os.path.join(self.vis_root['text'], '{}.jpg'.format(info['id'])) + elif 'vg' in image_name: + # TODO + # image_path = os.path.join(self.vis_root['vg'], '{}.jpg'.format(info['id'])) + image_path = None + + return image_path + + def process_convers(self, info): first_instruction = info['conversations'][0]['value'].replace('', '').replace('\n', '').strip() - first_instruction = ' {} '.format(first_instruction) + # first_instruction = ' {} '.format(first_instruction) questions = [first_instruction] answers = [] @@ -137,14 +242,46 @@ class LlavaConversationDataset(Dataset): human_instruction = item["value"]+" " questions.append(human_instruction) - questions = self.connect_sym.join(questions) - answers = self.connect_sym.join(answers) + return questions, answers + def filter_process_data(self): + ann = list() + for info in self.tmp_ann: + image_path = self.build_image_path(info) + if image_path != None: + questions, answers = self.process_convers(info) + assert len(questions) == len(answers) + for i in range(len(questions)): + ann.append({ + 'question_id': '{}_{}'.format(info['id'],str(i)), + 'image_id': str(info['id']), + 'image_path': image_path, + 'text_input': questions[i], + 'text_output': answers[i], + }) + return ann + + + def __len__(self): + return len(self.ann) + + def __getitem__(self, index): + + info = self.ann[index] + try: + image = Image.open(info['image_path']).convert("RGB") + except: + image = Image.open(info['image_path'].replace('train','val')).convert("RGB") + image = self.vis_processor(image) + instruction = 'Answer the question using a single word or phrase.' + return { "image": image, - "conv_q": questions, - 'conv_a': answers, - "image_id": info['id'], - "connect_sym": self.connect_sym + "image_id": info['image_id'], + "text_input": info['text_input'], + "text_output": info['text_output'], + "q_input": info['text_input'].replace(instruction,''), + "llm_input": info['text_input'], + "question_id": info["question_id"], } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/ocrvqa_dataset.py b/minigpt4/datasets/datasets/ocrvqa_dataset.py index 00ce03d..020858e 100755 --- a/minigpt4/datasets/datasets/ocrvqa_dataset.py +++ b/minigpt4/datasets/datasets/ocrvqa_dataset.py @@ -30,10 +30,14 @@ class OCRVQADataset(Dataset): self.text_processor = text_processor self.data = self.create_data(ann_path) - self.instruction_pool =[ - "[vqa] {}", - "[vqa] Based on the image, respond to this question with a short answer: {}" + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', ] + self.source = 'ocrvqa' def create_data(self, ann_path): processed_data = [] @@ -41,19 +45,21 @@ class OCRVQADataset(Dataset): data = json.load(f) for k in data.keys(): if data[k]['split'] != 1: continue # 1 for training, 2 for validation, 3 for test - ext = os.path.splitext(data[k]['imageURL'])[1] + # ext = os.path.splitext(data[k]['imageURL'])[1] + ext = '.jpg' imageFile = k + ext - assert len(data[k]['questions']) == len(data[k]['answers']) - for q, a in zip(data[k]['questions'], data[k]['answers']): - processed_data.append( - {'question': q, - 'answer': a, - 'image_path': imageFile, - 'image_id': k, - 'title': data[k]['title'], - 'genre': data[k]['genre'], - } - ) + if os.path.exists(os.path.join(self.vis_root, imageFile)): + assert len(data[k]['questions']) == len(data[k]['answers']) + for q, a in zip(data[k]['questions'], data[k]['answers']): + processed_data.append( + {'question': q, + 'answer': a, + 'image_path': imageFile, + 'image_id': k, + 'title': data[k]['title'], + 'genre': data[k]['genre'], + } + ) return processed_data def __len__(self): @@ -66,12 +72,17 @@ class OCRVQADataset(Dataset): question = self.text_processor(sample["question"]) answer = self.text_processor(sample["answer"]) - instruction = random.choice(self.instruction_pool).format(question) - instruction = " {} ".format(instruction) + q_input = question + llm_input = random.choice(self.instruction_pool).format(question) + return { "image": image, - "instruction_input": instruction, + "image_id": sample["image_id"], + "q_input": q_input, + "llm_input": llm_input, + "text_input": question, + "text_output": answer, "answer": answer, - "image_id": sample['image_id'] + "source": 'ocrvqa', } diff --git a/minigpt4/datasets/datasets/ok_vqa_datasets.py b/minigpt4/datasets/datasets/ok_vqa_datasets.py new file mode 100755 index 0000000..ae09a29 --- /dev/null +++ b/minigpt4/datasets/datasets/ok_vqa_datasets.py @@ -0,0 +1,189 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import os +import json +import random +import numpy as np +from PIL import Image + +from minigpt4.datasets.datasets.vqa_datasets import VQADataset, VQAEvalDataset + +from collections import OrderedDict + + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +class OKVQADataset(VQADataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', + ] + + exist_annotation = [] + for ann in self.annotation: + # image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image_path = os.path.join(self.vis_root, ann["image"]) + if os.path.exists(image_path): + exist_annotation.append(ann) + self.annotation = exist_annotation + self.source = 'okvqa' + + + def get_data(self, index): + ann = self.annotation[index] + + # image_path = os.path.join(self.vis_root, ann["image"].split('/')[-1]) + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + question_id = ann["question_id"] + + answer_weight = {} + for answer in ann["answer"]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann["answer"]) + else: + answer_weight[answer] = 1 / len(ann["answer"]) + + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + + answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + + + return { + "image": image, + "image_id": ann["image"], + "question": question, + "question_id": question_id, + "answer": answer, + } + + def __getitem__(self, index): + data = self.get_data(index) + question = data['question'] + answer = self.text_processor(data['answer']) + + q_input = question + llm_input = random.choice(self.instruction_pool).format(question) + + return { + "image": data['image'], + "image_id": data["image_id"], + "question_id": data["question_id"], + # "instruction_input": instruction, + "q_input": llm_input, + # "q_input": q_input, + "llm_input": llm_input, + "text_input": question, + "text_output": answer, + "answer": answer, + "source": 'okvqa', + } + + +class OKVQAEvalDataset(VQAEvalDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + """ + vis_root (string): Root directory of images (e.g. coco/images/) + ann_root (string): directory to store the annotation file + """ + + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', + ] + self.vis_root = vis_root + + self.annotation = json.load(open(ann_paths[0])) + exist_annotation = [] + for ann in self.annotation: + image_path = os.path.join(self.vis_root, ann["image"]) + if os.path.exists(image_path): + exist_annotation.append(ann) + self.annotation = exist_annotation + + answer_list_path = ann_paths[1] + if os.path.exists(answer_list_path): + self.answer_list = json.load(open(answer_list_path)) + else: + self.answer_list = None + + try: + self.coco_fmt_qust_file = ann_paths[2] + self.coco_fmt_anno_file = ann_paths[3] + except IndexError: + self.coco_fmt_qust_file = None + self.coco_fmt_anno_file = None + + self.vis_processor = vis_processor + self.text_processor = text_processor + + self.source = 'okvqa' + self.annotation_add = self.get_data() + + def get_data(self): + ann_instruct = list() + for i in range(len(self.annotation)): + ann = self.annotation[i].copy() + j = i % len(self.instruction_pool) + question = self.text_processor(ann["question"]) + llm_input = self.instruction_pool[j].format(question) + ann['llm_input'] = llm_input + ann_instruct.append(ann) + np.random.seed(10) + np.random.shuffle(ann_instruct) + return ann_instruct + + def __getitem__(self, index): + ann = self.annotation_add[index] + + image_path = os.path.join(self.vis_root, ann["image"]) + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + + question = self.text_processor(ann["question"]) + q_input = question + llm_input = ann.get("llm_input",random.choice(self.instruction_pool).format(question)) + + return { + "image": image, + "image_id": ann["image"], + 'image_path': image_path, + "question_id": ann["question_id"], + "question": question, + # "q_input": q_input, + "q_input": llm_input, + "llm_input": llm_input, + "text_input": question, + "source": 'okvqa', + } diff --git a/minigpt4/datasets/datasets/text_caps.py b/minigpt4/datasets/datasets/text_caps.py index 47a87f1..edbcd4a 100755 --- a/minigpt4/datasets/datasets/text_caps.py +++ b/minigpt4/datasets/datasets/text_caps.py @@ -15,20 +15,17 @@ from torch.utils.data import Dataset import webdataset as wds from minigpt4.datasets.datasets.base_dataset import BaseDataset -from minigpt4.datasets.datasets.caption_datasets import CaptionDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset -class TextCapDataset(Dataset): - def __init__(self, vis_processor, text_processor, vis_root, ann_path): +class TextCapDataset(BaseDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): """ vis_root (string): Root directory of images (e.g. coco/images/) ann_root (string): directory to store the annotation file """ - self.vis_root = vis_root - - self.vis_processor = vis_processor - self.text_processor = text_processor + super().__init__(vis_processor, text_processor, vis_root, ann_paths) self.instruction_pool = [ 'Briefly describe this image.', @@ -49,19 +46,12 @@ class TextCapDataset(Dataset): 'Using language, provide a short account of the image.', 'Use a few words to illustrate what is happening in the picture.', ] - - with open(ann_path, 'r') as f: - self.ann = json.load(f) - - - def __len__(self): - return len(self.ann["data"]) - + self.source = 'text_cap' def __getitem__(self, index): - info = self.ann["data"][index] + info = self.annotation[index] - image_file = '{}.jpg'.format(info['image_id']) + image_file = info['image_path'] image_path = os.path.join(self.vis_root, image_file) image = Image.open(image_path).convert("RGB") @@ -69,9 +59,67 @@ class TextCapDataset(Dataset): caption = info["caption_str"] caption = self.text_processor(caption) - instruction = " [caption] {} ".format(random.choice(self.instruction_pool)) + instruction = random.choice(self.instruction_pool) + q_input = instruction + llm_input = instruction + return { "image": image, - "instruction_input": instruction, + "image_id": info["image_name"], "answer": caption, + "q_input": q_input, + "llm_input": llm_input, + "text_input": llm_input, + "text_output": caption, + "source": self.source, } + +class TextCapEvalDataset(CaptionEvalDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.instruction_pool = [ + 'Briefly describe this image.', + 'Provide a concise depiction of this image.', + 'Present a short description of this image.', + 'Summarize this image in a few words.', + 'A short image caption:', + 'A short image description:', + 'A photo of ', + 'An image that shows ', + 'Write a short description for the image. ', + 'Write a description for the photo.', + 'Provide a description of what is presented in the photo.', + 'Briefly describe the content of the image.', + 'Can you briefly explain what you see in the image?', + 'Could you use a few words to describe what you perceive in the photo?', + 'Please provide a short depiction of the picture.', + 'Using language, provide a short account of the image.', + 'Use a few words to illustrate what is happening in the picture.', + ] + self.source = 'text_cap' + + def __getitem__(self, index): + info = self.annotation[index] + + # image_file = '{}.jpg'.format(info['image_id']) + image_file = info['image_path'] + + image_path = os.path.join(self.vis_root, image_file) + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + + caption = info["caption_str"] + caption = self.text_processor(caption) + instruction = random.choice(self.instruction_pool) + q_input = instruction + llm_input = instruction + + return { + "image": image, + "image_id": info["image_name"], + "text_input":llm_input, + "q_input": q_input, + "llm_input": llm_input, + "source": self.source, + } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/text_vqa_dataset.py b/minigpt4/datasets/datasets/text_vqa_dataset.py new file mode 100755 index 0000000..d6c6a59 --- /dev/null +++ b/minigpt4/datasets/datasets/text_vqa_dataset.py @@ -0,0 +1,127 @@ +import os +import json +import pickle +import random +import time +import itertools + +import numpy as np +from PIL import Image +import skimage.io as io +import matplotlib.pyplot as plt +from matplotlib.collections import PatchCollection +from matplotlib.patches import Polygon, Rectangle +from torch.utils.data import Dataset +import webdataset as wds + +from minigpt4.datasets.datasets.base_dataset import BaseDataset +from minigpt4.datasets.datasets.caption_datasets import CaptionDataset, CaptionEvalDataset + +from collections import OrderedDict + +class __DisplMixin: + def displ_item(self, index): + sample, ann = self.__getitem__(index), self.annotation[index] + + return OrderedDict( + { + "file": ann["image"], + "question": ann["question"], + "question_id": ann["question_id"], + "answers": "; ".join(ann["answer"]), + "image": sample["image"], + } + ) + + +class TextVQADataset(BaseDataset, __DisplMixin): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', + ] + self.source = 'text_vqa' + + def __getitem__(self, index): + ann = self.annotation[index] + + image_path = os.path.join(self.vis_root, "train/{}.jpg".format(ann["image_id"])) + image = Image.open(image_path).convert("RGB") + + image = self.vis_processor(image) + question = self.text_processor(ann["question"]) + + q_input = question + llm_input = random.choice(self.instruction_pool).format(question) + + answer_weight = {} + for answer in ann["answers"]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(ann["answers"]) + else: + answer_weight[answer] = 1 / len(ann["answers"]) + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + + return { + "image": image, + 'image_id': ann["image_id"], + "text_input": question, + "text_output": answer, + "q_input": q_input, + "llm_input": llm_input, + "gt_answers": answer, + "source": "text_vqa", + } + +class TextVQAEvalDataset(CaptionEvalDataset): + def __init__(self, vis_processor, text_processor, vis_root, ann_paths): + super().__init__(vis_processor, text_processor, vis_root, ann_paths) + + self.instruction_pool =[ + '{}', + 'Q: {} A: ', + 'Based on the image, respond to this question with a short answer: {}', + '{} A short answer to the question is ', + 'Question: {} Short answer:', + ] + self.source = 'text_vqa' + + def __getitem__(self, index): + info = self.annotation[index] + + image_path = os.path.join(self.vis_root, "train/{}.jpg".format(info["image_id"])) + + image = Image.open(image_path).convert("RGB") + image = self.vis_processor(image) + question = self.text_processor(info["question"]) + + q_input = question + llm_input = random.choice(self.instruction_pool).format(question) + + answer_weight = {} + for answer in info["answers"]: + if answer in answer_weight.keys(): + answer_weight[answer] += 1 / len(info["answers"]) + else: + answer_weight[answer] = 1 / len(info["answers"]) + answers = list(answer_weight.keys()) + weights = list(answer_weight.values()) + answer = random.choices(answers, weights=weights, k=1)[0] # random sample an answer according to weights + + return { + "image": image, + "image_id": info["image_id"], + "question": question, + # "q_input": llm_input, + "q_input": q_input, + "llm_input": llm_input, + "text_input": question, + "gt_answers": answer, + "source": 'text_vqa', + } \ No newline at end of file diff --git a/minigpt4/datasets/datasets/vqa_datasets.py b/minigpt4/datasets/datasets/vqa_datasets.py index 5cdc0fa..4a9c2c7 100755 --- a/minigpt4/datasets/datasets/vqa_datasets.py +++ b/minigpt4/datasets/datasets/vqa_datasets.py @@ -87,8 +87,16 @@ class VizWizEvalData(torch.utils.data.Dataset): image = Image.open(image_path).convert('RGB') image = self.vis_processor(image) # question = f"[vqa] Based on the image, respond to this question with a short answer: {question} " - question = f"[vqa] Based on the image, respond to this question with a short answer: {question} and reply 'unanswerable' if you could not answer it" - return image, question, answers + # question = f"Based on the image, respond to this question with a short answer: {question} and reply 'unanswerable' if you could not answer it" + question = question + sample = { + 'image':image, + 'q_input': question, + 'llm_input': question, + 'image_id': img_id, + 'gt_ans': answers, + } + return sample class AOKVQADAEvalData(torch.utils.data.Dataset): def __init__(self, loaded_data, vis_processor, root_path): @@ -129,9 +137,9 @@ class AOKVQAMCEvalData(torch.utils.data.Dataset): image = Image.open(image_path).convert('RGB') image = self.vis_processor(image).half().cuda() candidates=data['choices'] - # question = f"Given this image, choose one correct answer from {candidates} for this question: {question} " - question = f"[vqa] Based on the image, respond to this question with a short answer: {question}" - # question = f"[vqa] {question} " + question = f"Given this image, choose one correct answer from {candidates} for this question: {question} " + # question = f"Based on the image, respond to this question with a short answer: {question}" + # question = f"{question} " return image, question, question_id, candidates class IconQAEvalData(torch.utils.data.Dataset): @@ -152,33 +160,9 @@ class IconQAEvalData(torch.utils.data.Dataset): image = self.vis_processor(image).half().cuda() candidates = '_'.join(data['choices']) answer = data['answer'] - # question = f"Given this image, choose one correct answer from {candidates} for this question: {question} " question = f"[vqa] Based on the image, respond to this question with a short answer: {question}" - # question = f"[vqa] {question} " return image, question, candidates, answer -class GQAEvalData(torch.utils.data.Dataset): - def __init__(self, loaded_data, vis_processor, root_path): - self.loaded_data = loaded_data - self.root_path = root_path - self.vis_processor = vis_processor - - def __len__(self): - return len(self.loaded_data) - - def __getitem__(self, idx): - ann = self.loaded_data[idx] - image_id = ann["image"] - image_path = os.path.join(self.root_path, f"{image_id}") - image = Image.open(image_path).convert("RGB") - image = self.vis_processor(image) - question = ann["question"] - # question = f'Question: {question} Short answer: ' - question = f"[vqa] Based on the image, respond to this question with a short answer: {question}" - # question = f"[vqa] {question} " - labels = ann["answer"] - - return image, question, labels class HMEvalData(torch.utils.data.Dataset): def __init__(self, loaded_data, vis_processor, root_path): @@ -196,10 +180,18 @@ class HMEvalData(torch.utils.data.Dataset): image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) question = ann["text"] - question = f"[vqa] This is an image with: \"{question}\" written on it. Is it hateful? Answer:" + question = f"This is an image writting '{question}'. Is this image hateful? Answer yes or no. Answer:" labels = ann["label"] + + sample = { + 'image':image, + 'q_input': question, + 'llm_input': question, + 'image_id': image_id, + 'gt_ans': labels, + } - return image, question, labels + return sample class VSREvalData(torch.utils.data.Dataset): def __init__(self, loaded_data, vis_processor, root_path): @@ -216,8 +208,15 @@ class VSREvalData(torch.utils.data.Dataset): image = Image.open(image_path).convert("RGB") image = self.vis_processor(image) question = ann["caption"] - question = f'[vqa] Based on the image, is this statement true or false? {question}' - question_id = ann["image"].split('.')[0] + question = f'Based on the image, is this statement true or false? {question}' labels = 'true' if ann["label"] == 1 else 'false' - return image, question, labels \ No newline at end of file + sample = { + 'image':image, + 'q_input': question, + 'llm_input': question, + 'image_id': ann["image"], + 'gt_ans': labels, + } + + return sample \ No newline at end of file diff --git a/eval_configs/minigpt4_eval.yaml b/minigpt4/eval_configs/minigpt4_eval.yaml similarity index 100% rename from eval_configs/minigpt4_eval.yaml rename to minigpt4/eval_configs/minigpt4_eval.yaml diff --git a/eval_configs/minigpt4_llama2_eval.yaml b/minigpt4/eval_configs/minigpt4_llama2_eval.yaml similarity index 100% rename from eval_configs/minigpt4_llama2_eval.yaml rename to minigpt4/eval_configs/minigpt4_llama2_eval.yaml diff --git a/minigpt4/eval_configs/minigptv2_benchmark_evaluation.yaml b/minigpt4/eval_configs/minigptv2_benchmark_evaluation.yaml new file mode 100644 index 0000000..c0e3a26 --- /dev/null +++ b/minigpt4/eval_configs/minigptv2_benchmark_evaluation.yaml @@ -0,0 +1,79 @@ +model: + arch: minigpt_v2 + model_type: pretrain + max_txt_len: 500 + end_sym: "" + low_resource: False + prompt_template: '[INST] {} [/INST]' + llama_model: "" + ckpt: "" + lora_r: 64 + lora_alpha: 16 + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + +evaluation_datasets: + refcoco: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + refcocog: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + refcoco+: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + gqa: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + okvqa: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + vizwiz: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + iconvqa: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + vsr: + eval_file_path: cambridgeltl/vsr_zeroshot + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 10 + hm: + eval_file_path: /path/to/eval/annotation/path + img_path: /path/to/eval/image/path + max_new_tokens: 20 + batch_size: 100 + +run: + task: image_text_pretrain + name: minigptv2_evaluation + save_path: /path/to/save/folder_path + + + + + diff --git a/minigpt4/eval_configs/minigptv2_eval.yaml b/minigpt4/eval_configs/minigptv2_eval.yaml new file mode 100644 index 0000000..46ac2fa --- /dev/null +++ b/minigpt4/eval_configs/minigptv2_eval.yaml @@ -0,0 +1,24 @@ +model: + arch: minigpt_v2 + model_type: pretrain + max_txt_len: 500 + end_sym: "" + low_resource: True + prompt_template: '[INST] {} [/INST]' + ckpt: "please set this value to the path of pretrained checkpoint" + lora_r: 64 + lora_alpha: 16 + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 448 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/minigpt4/eval_scripts/EVAL_README.md b/minigpt4/eval_scripts/EVAL_README.md new file mode 100644 index 0000000..4d6681f --- /dev/null +++ b/minigpt4/eval_scripts/EVAL_README.md @@ -0,0 +1,104 @@ +## Evaluation Instruction for MiniGPT-v2 + +### Data preparation +Images download +Image source | Download path +--- | :---: +OKVQA| annotations    images +gqa | annotations    images +hateful meme | images and annotations +iconqa | images and annotation +vizwiz | images and annotation +RefCOCO | annotations +RefCOCO+ | annotations +RefCOCOg | annotations + +### Evaluation dataset structure + +``` +${MINIGPTv2_EVALUATION_DATASET} +├── gqa +│ └── test_balanced_questions.json +│ ├── testdev_balanced_questions.json +│ ├── gqa_images +├── hateful_meme +│ └── hm_images +│ ├── dev.jsonl +├── iconvqa +│ └── iconvqa_images +│ ├── choose_text_val.json +├── vizwiz +│ └── vizwiz_images +│ ├── val.json +├── vsr +│ └── vsr_images +├── okvqa +│ ├── okvqa_test_split.json +│ ├── mscoco_val2014_annotations_clean.json +│ ├── OpenEnded_mscoco_val2014_questions_clean.json +├── refcoco +│ └── instances.json +│ ├── refs(google).p +│ ├── refs(unc).p +├── refcoco+ +│ └── instances.json +│ ├── refs(unc).p +├── refercocog +│ └── instances.json +│ ├── refs(google).p +│ ├── refs(und).p +... +``` + + +### environment setup + +``` +export PYTHONPATH=$PYTHONPATH:/path/to/directory/of/MiniGPT-4 +``` + +### config file setup + +Set **llama_model** to the path of LLaMA model. +Set **ckpt** to the path of our pretrained model. +Set **eval_file_path** to the path of the annotation files for each evaluation data. +Set **img_path** to the img_path for each evaluation dataset. +Set **save_path** to the save_path for evch evaluation dataset. + +in [eval_configs/minigptv2_benchmark_evaluation.yaml](../eval_configs/minigptv2_benchmark_evaluation.yaml) + + + + +### start evalauting RefCOCO, RefCOCO+, RefCOCOg +port=port_number +cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml + +dataset names: +| refcoco | refcoco+ | refcocog | +| ------- | -------- | -------- | + +``` +torchrun --master-port ${port} --nproc_per_node 1 eval_ref.py \ + --cfg-path ${cfg_path} --dataset refcoco,refcoco+,refcocog --resample +``` + + +### start evaluating visual question answering + +port=port_number +cfg_path=/path/to/eval_configs/minigptv2_benchmark_evaluation.yaml + +dataset names: +| okvqa | vizwiz | iconvqa | gqa | vsr | hm | +| ------- | -------- | -------- |-------- | -------- | -------- | + + +``` +torchrun --master-port ${port} --nproc_per_node 1 eval_vqa.py \ + --cfg-path ${cfg_path} --dataset okvqa,vizwiz,iconvqa,gqa,vsr,hm +``` + + + + diff --git a/minigpt4/eval_scripts/eval_ref.py b/minigpt4/eval_scripts/eval_ref.py new file mode 100644 index 0000000..28d55a2 --- /dev/null +++ b/minigpt4/eval_scripts/eval_ref.py @@ -0,0 +1,128 @@ +import os +import re +import json +import argparse +from collections import defaultdict +import random +import numpy as np +from PIL import Image +from tqdm import tqdm +import torch +from torch.utils.data import DataLoader +from minigpt4.common.config import Config +from minigpt4.common.eval_utils import prepare_texts, init_model, eval_parser, computeIoU +from minigpt4.conversation.conversation import CONV_VISION_minigptv2 + +from minigpt4.datasets.datasets.coco_caption import RefCOCOEvalData + +def list_of_str(arg): + return list(map(str, arg.split(','))) + +parser = eval_parser() +parser.add_argument("--dataset", type=list_of_str, default='refcoco', help="dataset to evaluate") +parser.add_argument("--res", type=float, default=100.0, help="resolution used in refcoco") +parser.add_argument("--resample", action='store_true', help="resolution used in refcoco") +args = parser.parse_args() + +cfg = Config(args) + +eval_dict = {'refcoco': ['val','testA','testB'], + 'refcoco+': ['val','testA','testB'], + 'refcocog': ['val','test']} + + +model, vis_processor = init_model(args) +model.eval() +CONV_VISION = CONV_VISION_minigptv2 +conv_temp = CONV_VISION.copy() +conv_temp.system = "" + +# +model.eval() +save_path = cfg.run_cfg.save_path + + + +for dataset in args.dataset: + for split in eval_dict[dataset]: + + eval_file_path = cfg.evaluation_datasets_cfg[dataset]["eval_file_path"] + img_path = cfg.evaluation_datasets_cfg[dataset]["img_path"] + batch_size = cfg.evaluation_datasets_cfg[dataset]["batch_size"] + max_new_tokens = cfg.evaluation_datasets_cfg[dataset]["max_new_tokens"] + + with open(os.path.join(eval_file_path,f"{dataset}/{dataset}_{split}.json"), 'r') as f: + refcoco = json.load(f) + + data = RefCOCOEvalData(refcoco, vis_processor, img_path) + eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) + minigpt4_predict = defaultdict(list) + resamples = [] + + for images, questions, img_ids in tqdm(eval_dataloader): + texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template + answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False) + for answer, img_id, question in zip(answers, img_ids, questions): + answer = answer.replace("","").replace(" ","").strip() + pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}' + if re.match(pattern, answer): + minigpt4_predict[img_id].append(answer) + else: + resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of','').strip()]}) + if args.resample: + for i in range(20): + data = RefCOCOEvalData(resamples, vis_processor, img_path) + resamples = [] + eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) + for images, questions, img_ids in tqdm(eval_dataloader): + texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template + answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False) + for answer, img_id, question in zip(answers, img_ids, questions): + answer = answer.replace("","").replace(" ","").strip() + pattern = r'\{<\d{1,3}><\d{1,3}><\d{1,3}><\d{1,3}>\}' + if re.match(pattern, answer) or i == 4: + minigpt4_predict[img_id].append(answer) + else: + resamples.append({'img_id': img_id, 'sents': [question.replace('[refer] give me the location of','').strip()]}) + + if len(resamples) == 0: + break + + file_save_path = os.path.join(save_path,f"{args.dataset}_{split}.json") + with open(file_save_path,'w') as f: + json.dump(minigpt4_predict, f) + + count=0 + total=len(refcoco) + res=args.res + refcoco_dict = defaultdict() + for item in refcoco: + refcoco_dict[item['img_id']] = item + for img_id in refcoco_dict: + item = refcoco_dict[img_id] + bbox = item['bbox'] + outputs = minigpt4_predict[img_id] + for output in outputs: + try: + integers = re.findall(r'\d+', output) + pred_bbox = [int(num) for num in integers] + height = item['height'] + width = item['width'] + pred_bbox[0] = pred_bbox[0] / res * width + pred_bbox[1] = pred_bbox[1] / res * height + pred_bbox[2] = pred_bbox[2] / res * width + pred_bbox[3] = pred_bbox[3] / res * height + + gt_bbox = [0,0,0,0] + gt_bbox[0] = bbox[0] + gt_bbox[1] = bbox[1] + gt_bbox[2] = bbox[0] + bbox[2] + gt_bbox[3] = bbox[1] + bbox[3] + + iou_score = computeIoU(pred_bbox, gt_bbox) + if iou_score > 0.5: + count+=1 + except: + continue + + print(f'{dataset} {split}:', count / total * 100, flush=True) diff --git a/minigpt4/eval_scripts/eval_vqa.py b/minigpt4/eval_scripts/eval_vqa.py new file mode 100644 index 0000000..19d51b0 --- /dev/null +++ b/minigpt4/eval_scripts/eval_vqa.py @@ -0,0 +1,346 @@ +# python eval_vqa.py --dataset vizwiz +import os +import re +import json +import argparse +from collections import defaultdict +import random + +import numpy as np +from PIL import Image +from tqdm import tqdm +import torch +from torch.utils.data import DataLoader +import torch.backends.cudnn as cudnn +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + +from minigpt4.common.logger import setup_logger +from minigpt4.common.dist_utils import get_rank +from minigpt4.common.registry import registry +from minigpt4.datasets.datasets.vqa_datasets import OKVQAEvalData,VizWizEvalData,IconQAEvalData,VSREvalData,HMEvalData +from minigpt4.common.vqa_tools.vqa import VQA +from minigpt4.common.vqa_tools.vqa_eval import VQAEval +from minigpt4.common.config import Config + + +def list_of_str(arg): + return list(map(str, arg.split(','))) + +def eval_parser(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument( + "--cfg-path", + default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/qformer_moe_vicuna/eval/vqa_benchmark_evaluation.yaml", + help="path to configuration file.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + return parser + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + +def init_model(cfg, device): + print('Initialization Model') + model_config = cfg.model_cfg + model_cls = registry.get_model_class(model_config.arch) + model = model_cls.from_config(model_config).to(device) + +# import pudb; pudb.set_trace() + key = list(cfg.datasets_cfg.keys())[0] + vis_processor_cfg = cfg.datasets_cfg.get(key).vis_processor.train + vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg) + txt_processor_cfg = cfg.datasets_cfg.get(key).text_processor.train + text_processor = registry.get_processor_class(txt_processor_cfg.name).from_config(txt_processor_cfg) + print('Initialization Finished') + return model, vis_processor, text_processor + +parser = eval_parser() +parser.add_argument("--dataset", type=list_of_str, default=['vizwiz','hm'], help="dataset to evaluate") +args = parser.parse_args() +cfg = Config(args) +setup_seeds(cfg) +print(cfg._convert_node_to_json(cfg.config)) +setup_logger() +device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu") + +model, vis_processor, _ = init_model(cfg, device) +model.eval() + +run_cfg = cfg.run_cfg +save_path = cfg.run_cfg.save_path +num_beams = run_cfg.get("num_beams", 3) +max_len = run_cfg.get("max_len", 20) +min_len = run_cfg.get("min_len", 1) +inference_method = run_cfg.get("inference_method", "rank") +num_ans_candidates = run_cfg.get("num_ans_candidates", 128) +prompt = run_cfg.get("prompt", "") +if not os.path.exists(save_path): + os.mkdir(save_path) + +if 'vizwiz' in args.dataset: + + eval_file_path = cfg.evaluation_datasets_cfg["vizwiz"]["eval_file_path"] + img_path = cfg.evaluation_datasets_cfg["vizwiz"]["img_path"] + batch_size = cfg.evaluation_datasets_cfg["vizwiz"]["batch_size"] + max_new_tokens = cfg.evaluation_datasets_cfg["vizwiz"]["max_new_tokens"] + + vizwiz = json.load(open(eval_file_path, 'r')) + + data = VizWizEvalData(vizwiz, vis_processor, img_path) + eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) + predicts = [] + total_acc = [] + for samples in tqdm(eval_dataloader): + samples['image'] = samples['image'].half().to(device) + texts = samples['q_input'] + gt_answers = samples['gt_ans'] + image_ids = samples['image_id'] + + answers = model.predict_answers( + samples=samples, + inference_method=inference_method, + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + num_ans_candidates=num_ans_candidates, + prompt=prompt, + ) + + for i in range(len(answers)): + question, answer, gt_answer, img_id = texts[i], answers[i], gt_answers[i], image_ids[i] + result = {'img_id':img_id, 'question':question} + result['answer'] = answer.replace('','').strip() + count=0 + gt_answer = gt_answer.split('_') + for gt in gt_answer: + if gt.lower() == answer.lower(): + count += 1 + acc = min(count/3.0, 1.0) + total_acc.append(acc) + result['gt_ans'] = gt_answer + predicts.append(result) + + vizwiz_acc = np.average(total_acc)* 100.0 + print('vizwiz Acc: ', vizwiz_acc, flush=True) + + file_save_path = os.path.join(save_path, "vizwiz.json") + with open(file_save_path,'a+') as f: + json.dump(predicts, f) + + with open(os.path.join(save_path, f"evaluate_vizwiz.txt"), "a") as f: + f.write(json.dumps({'agg_metrics': vizwiz_acc}) + "\n") + +if 'okvqa' in args.dataset: + + eval_file_path = cfg.evaluation_datasets_cfg["okvqa"]["eval_file_path"] + img_path = cfg.evaluation_datasets_cfg["okvqa"]["img_path"] + batch_size = cfg.evaluation_datasets_cfg["okvqa"]["batch_size"] + max_new_tokens = cfg.evaluation_datasets_cfg["okvqa"]["max_new_tokens"] + + + evaluation_annntation_path = os.path.join(eval_file_path, "okvqa_test_split.json") + with open(evaluation_annntation_path) as f: + ok_vqa_test_split = json.load(f) + + data = OKVQAEvalData(ok_vqa_test_split, vis_processor, img_path) + eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) + minigpt4_predict = [] + + for images, questions, question_ids, img_ids in eval_dataloader: + texts = prepare_texts(questions, conv_temp) # warp the texts with conversation template + answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False) + + for answer, question_id, question, img_id in zip(answers, question_ids, questions, img_ids): + result = dict() + answer = answer.lower().replace('','').strip() + result['answer'] = answer + result['question_id'] = int(question_id) + minigpt4_predict.append(result) + + file_save_path= os.path.join(save_path,"okvqa.json") + with open(file_save_path,'w') as f: + json.dump(minigpt4_predict, f) + + annFile = os.path.join(eval_file_path,"mscoco_val2014_annotations_clean.json") + quesFile = os.path.join(eval_file_path,"OpenEnded_mscoco_val2014_questions_clean.json" ) + + vqa = VQA(annFile, quesFile) + vqaRes = vqa.loadRes(file_save_path, quesFile) + + vqaEval = VQAEval(vqa, vqaRes, n=2) + vqaEval.evaluate() + print ("Overall OKVQA Accuracy is: %.02f\n" %(vqaEval.accuracy['overall']), flush=True) + +if 'iconvqa' in args.dataset: + + eval_file_path = cfg.evaluation_datasets_cfg["iconvqa"]["eval_file_path"] + img_path = cfg.evaluation_datasets_cfg["iconvqa"]["img_path"] + batch_size = cfg.evaluation_datasets_cfg["iconvqa"]["batch_size"] + max_new_tokens = cfg.evaluation_datasets_cfg["iconvqa"]["max_new_tokens"] + + iconqa_text_val = json.load(open(eval_file_path,"r")) + + data = IconQAEvalData(iconqa_text_val, vis_processor, img_path) + eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) + + count = 0 + for images, texts, candidates, answers in tqdm(eval_dataloader): + candidates = [candidate.split('_') for candidate in candidates] + num_cand = [len(candidate) for candidate in candidates] + for candidate in candidates: + candidate.extend(['none'] * (max(num_cand) - len(candidate))) + candidates = [list(x) for x in zip(*candidates)] + instructions = ["[INST] {} [/INST]".format(text) for text in texts] + answer_ranks = model.multi_select(images, instructions, candidates, num_cand=num_cand) + for idx, answer in enumerate(answers): + if answer_ranks[idx][0] == answer: + count += 1 + + print('iconqa Acc: ', count / len(iconqa_text_val) * 100.0, flush=True) + +if 'gqa' in args.dataset: + + eval_file_path = cfg.evaluation_datasets_cfg["gqa"]["eval_file_path"] + img_path = cfg.evaluation_datasets_cfg["gqa"]["img_path"] + batch_size = cfg.evaluation_datasets_cfg["gqa"]["batch_size"] + max_new_tokens = cfg.evaluation_datasets_cfg["gqa"]["max_new_tokens"] + + gqa = json.load(open(eval_file_path)) + data = GQAEvalData(gqa, vis_processor, img_path) + eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) + count=0 + total=0 + minigpt4_predict = [] + for images, texts, labels in tqdm(eval_dataloader): + texts = prepare_texts(texts, conv_temp) # warp the texts with conversation template + answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False) + + for answer, label in zip(answers, labels): + result = dict() + result['pred'] = answer.lower().replace('','').strip() + result['gt'] = label + minigpt4_predict.append(result) + if answer.lower() == label: + count+=1 + total+=1 + print('gqa val:', count / total * 100, flush=True) + + file_save_path = os.path.join(save_path, "gqa.json") + with open(file_save_path,'w') as f: + json.dump(minigpt4_predict, f) + +if 'vsr' in args.dataset: + + img_path = cfg.evaluation_datasets_cfg["vsr"]["img_path"] + batch_size = cfg.evaluation_datasets_cfg["vsr"]["batch_size"] + max_new_tokens = cfg.evaluation_datasets_cfg["vsr"]["max_new_tokens"] + from datasets import load_dataset + + annotation = load_dataset("cambridgeltl/vsr_zeroshot", split='test') + data = VSREvalData(annotation, vis_processor, img_path) + eval_dataloader = DataLoader(data, batch_size=batch_size, shuffle=False) + count=0 + total=0 + + minigpt4_predict = [] + + for samples in tqdm(eval_dataloader): + + texts = samples['q_input'] + labels = samples['gt_ans'] + image_ids = samples['image_id'] + + answers = model.predict_answers( + samples=samples, + inference_method=inference_method, + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + num_ans_candidates=num_ans_candidates, + prompt=prompt, + ) + + # answers = model.generate(images, texts, max_new_tokens=max_new_tokens, do_sample=False) + + for answer, label in zip(answers, labels): + result = dict() + result['pred'] = answer.replace('','').strip() + result['gt'] = label + minigpt4_predict.append(result) + if answer.lower() == label.lower(): + count+=1 + total+=1 + print('vsr test:', count / total * 100, flush=True) + # file_save_path = os.path.join(save_path,"vsr.json") + # with open(file_save_path,'w') as f: + # json.dump(minigpt4_predict, f) + +if 'hm' in args.dataset: + + eval_file_path = cfg.evaluation_datasets_cfg["hm"]["eval_file_path"] + img_path = cfg.evaluation_datasets_cfg["hm"]["img_path"] + batch_size = cfg.evaluation_datasets_cfg["hm"]["batch_size"] + max_new_tokens = cfg.evaluation_datasets_cfg["hm"]["max_new_tokens"] + + annotation = [] + with open(eval_file_path, 'r') as jsonl_file: + for line in jsonl_file: + json_obj = json.loads(line) + annotation.append(json_obj) + + data = HMEvalData(annotation, vis_processor, img_path) + eval_dataloader = DataLoader(data, batch_size=20, shuffle=False) + count=0 + total=0 + + predict = [] + + for samples in tqdm(eval_dataloader): + samples['image'] = samples['image'].half().to(device) + texts = samples['q_input'] + labels = samples['gt_ans'] + + answers = model.predict_answers( + samples=samples, + inference_method=inference_method, + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + num_ans_candidates=num_ans_candidates, + prompt=prompt, + ) + + for answer, label in zip(answers, labels): + result = dict() + if answer.lower().strip() =="yes": + answer=1 + elif answer.lower().strip()=="no": + answer=0 + else: + print("non-matching answer",answer) + + result['pred'] = answer + result['gt'] = int(label) + predict.append(result) + if answer == label: + count+=1 + total+=1 + print(answers) + + print('hm val:', count / total * 100, flush=True) + file_save_path = os.path.join(save_path, "hm.json") + with open(file_save_path,'w') as f: + json.dump(predict, f) diff --git a/minigpt4/models/Qformer.py b/minigpt4/models/Qformer.py index e71b123..d26226f 100644 --- a/minigpt4/models/Qformer.py +++ b/minigpt4/models/Qformer.py @@ -17,7 +17,6 @@ from typing import Optional, Tuple, Dict, Any import torch from torch import Tensor, device, dtype, nn import torch.utils.checkpoint -from torch import nn from torch.nn import CrossEntropyLoss import torch.nn.functional as F @@ -45,8 +44,10 @@ from transformers.modeling_utils import ( from transformers.utils import logging from transformers.models.bert.configuration_bert import BertConfig +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... logger = logging.get_logger(__name__) +# from visualizer import get_local class BertEmbeddings(nn.Module): """Construct the embeddings from word and position embeddings.""" @@ -120,17 +121,17 @@ class BertSelfAttention(nn.Module): "heads (%d)" % (config.hidden_size, config.num_attention_heads) ) - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 - self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) if is_cross_attention: - self.key = nn.Linear(config.encoder_width, self.all_head_size) - self.value = nn.Linear(config.encoder_width, self.all_head_size) + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) else: - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) self.dropout = nn.Dropout(config.attention_probs_dropout_prob) self.position_embedding_type = getattr( @@ -162,7 +163,7 @@ class BertSelfAttention(nn.Module): new_x_shape = x.size()[:-1] + ( self.num_attention_heads, self.attention_head_size, - ) + ) # torch.Size([1, 257, 12, 64]) x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) @@ -183,8 +184,8 @@ class BertSelfAttention(nn.Module): is_cross_attention = encoder_hidden_states is not None if is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) attention_mask = encoder_attention_mask elif past_key_value is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -197,9 +198,9 @@ class BertSelfAttention(nn.Module): mixed_query_layer = self.query(hidden_states) - query_layer = self.transpose_for_scores(mixed_query_layer) + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) - past_key_value = (key_layer, value_layer) + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -244,6 +245,7 @@ class BertSelfAttention(nn.Module): attention_scores = attention_scores / math.sqrt(self.attention_head_size) if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask attention_scores = attention_scores + attention_mask # Normalize the attention scores to probabilities. @@ -255,17 +257,17 @@ class BertSelfAttention(nn.Module): # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs_dropped = self.dropout(attention_probs) + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) # Mask heads if we want to if head_mask is not None: attention_probs_dropped = attention_probs_dropped * head_mask - context_layer = torch.matmul(attention_probs_dropped, value_layer) + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) context_layer = context_layer.permute(0, 2, 1, 3).contiguous() new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) outputs = ( (context_layer, attention_probs) if output_attentions else (context_layer,) @@ -361,7 +363,7 @@ class BertIntermediate(nn.Module): return hidden_states -class BertOutput(nn.Module): +class BertOutput(nn.Module): # Add & Norm def __init__(self, config): super().__init__() self.dense = nn.Linear(config.intermediate_size, config.hidden_size) @@ -375,7 +377,7 @@ class BertOutput(nn.Module): return hidden_states -class BertLayer(nn.Module): +class BertLayer(nn.Module): def __init__(self, config, layer_num): super().__init__() self.config = config @@ -492,6 +494,7 @@ class BertEncoder(nn.Module): [BertLayer(config, i) for i in range(config.num_hidden_layers)] ) + # @get_local('all_cross_attentions') def forward( self, hidden_states, @@ -557,7 +560,6 @@ class BertEncoder(nn.Module): output_attentions, query_length, ) - hidden_states = layer_outputs[0] if use_cache: next_decoder_cache += (layer_outputs[-1],) @@ -799,6 +801,7 @@ class BertModel(BertPreTrainedModel): dtype=self.dtype ) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask def forward( diff --git a/minigpt4/models/QformerMoE.py b/minigpt4/models/QformerMoE.py new file mode 100644 index 0000000..7478044 --- /dev/null +++ b/minigpt4/models/QformerMoE.py @@ -0,0 +1,1288 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts, +) +from minigpt4.models.moe.moe_layer import MoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FeedForward(nn.Module): + # remove LayerNorm + def __init__(self, config): + super().__init__() + self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2 + # self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2 + + def forward(self, hidden_states: Tensor): + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts(layer_num) + ffn = FeedForward(config) + if self.use_experts: + self.experts = MoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + route_method=config.moebert_route_method, + topk=config.moe_topk, + use_balance_loss=config.use_balance_loss, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + # add moe query ffn + # query_attention_output size: [bz, query_length+seq_len, 768] + # attention_mask size: [bz, 1, 1, query_length+seq_len] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + cls_hidden=attention_output[:, query_length, :] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, cls_hidden) # layer_output, gate_loss, gate_load + # import pdb; pdb.set_trace() # test0107 + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2]) + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, 0.0, []) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, cls_hidden=None): + if not self.use_experts: + hidden_states = self.experts(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + return layer_output, 0.0, [] + + hidden_states, gate_loss, gate_load = self.experts( + attention_output, expert_attention_mask, cls_hidden + ) + layer_output = self.expert_ln(hidden_states + attention_output) + return layer_output, gate_loss, gate_load + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + gate_loss = 0.0 + gate_loads = list() + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + ) + hidden_states = layer_outputs[0][0] + gate_loss = gate_loss + layer_outputs[0][1] + gate_loads.append(layer_outputs[0][2]) + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + gate_loss=gate_loss, + gate_loads=gate_loads, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + gate_loss=encoder_outputs.gate_loss, + gate_loads=encoder_outputs.gate_loads, + ) + + +class BertMoELMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.load_balance_alpha = config.moebert_load_balance + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + gate_loss * self.load_balance_alpha + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerMoELN.py b/minigpt4/models/QformerMoELN.py new file mode 100644 index 0000000..9ef1f6b --- /dev/null +++ b/minigpt4/models/QformerMoELN.py @@ -0,0 +1,1276 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts, +) +from minigpt4.models.moe.moe_layer import MoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class FeedForward(nn.Module): + # Add LayerNorm + def __init__(self, config): + super().__init__() + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def forward(self, hidden_states: Tensor): + intermediate_output = self.intermediate(hidden_states) + layer_output = self.output(intermediate_output, hidden_states) + return layer_output + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts(layer_num) + ffn = FeedForward(config) + if self.use_experts: + self.experts = MoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + route_method=config.moebert_route_method, + topk=config.moe_topk, + use_balance_loss=config.use_balance_loss, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + # add moe query ffn + # query_attention_output size: [bz, query_length+seq_len, 768] + # attention_mask size: [bz, 1, 1, query_length+seq_len] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask) # layer_output, gate_loss, gate_load + # import pdb; pdb.set_trace() # test0107 + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2]) + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, 0.0, []) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask): + if not self.use_experts: + hidden_states = self.experts(attention_output) + return hidden_states, 0.0, [] + + hidden_states, gate_loss, gate_load = self.experts( + attention_output, expert_attention_mask + ) + return hidden_states, gate_loss, gate_load + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + gate_loss = 0.0 + gate_loads = list() + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + ) + hidden_states = layer_outputs[0][0] + gate_loss = gate_loss + layer_outputs[0][1] + gate_loads.append(layer_outputs[0][2]) + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + gate_loss=gate_loss, + gate_loads=gate_loads, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + gate_loss=encoder_outputs.gate_loss, + gate_loads=encoder_outputs.gate_loads, + ) + + +class BertMoELMHeadModelLNIn(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.load_balance_alpha = config.moebert_load_balance + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + gate_loss * self.load_balance_alpha + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerRouteMoE.py b/minigpt4/models/QformerRouteMoE.py new file mode 100644 index 0000000..5cdd983 --- /dev/null +++ b/minigpt4/models/QformerRouteMoE.py @@ -0,0 +1,1374 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts_route, + moe_layer_judge, +) +from minigpt4.models.moe.route_moe_layer import RouteMoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1 + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Move LayerNorm & ResNet out of FFN After MoEFFN + hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1 + return hidden_states + + +class FeedForward(nn.Module): + # remove LayerNorm + def __init__(self, config): + super().__init__() + self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2 + # self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2 + + def forward(self, hidden_states: Tensor): + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts_route(layer_num) + self.layer_judge = moe_layer_judge(layer_num) + self.num_beams = config.moebert_num_beams + ffn = FeedForward(config) + if self.use_experts: + self.experts = RouteMoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = self.layer_judge, + route_method=config.route_method, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + beam_scores=None, + expert_route=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # import pdb; pdb.set_trace() # 0107test + + # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same + if self.num_beams > 1: + if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams: + # attention_mask dimension to be bz*num_beams + attention_mask = self.adjust_attention_mask(attention_mask) + encoder_attention_mask = self.adjust_attention_mask(encoder_attention_mask) + + if hidden_states.shape[0]*self.num_beams == attention_mask.shape[0]: + # attention_mask dimension back to bz + batch_size = attention_mask.shape[0] + attention_mask = attention_mask[[ i for i in range(0, batch_size, self.num_beams)]] + + if hidden_states.shape[0] == encoder_hidden_states.shape[0]*self.num_beams: + batch_size, visual_tokens, vision_dim = encoder_hidden_states.shape + tmp = encoder_hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, visual_tokens, vision_dim ) + encoder_hidden_states = tmp.contiguous().view(batch_size* self.num_beams, visual_tokens, vision_dim) # torch.Size([bz, 257, 1408]) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + # add moe query ffn + # query_attention_output size: [bz, query_length+seq_len, 768] + # attention_mask size: [bz, 1, 1, query_length+seq_len] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route) + # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) + # import pdb; pdb.set_trace() # 0107test + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + if self.layer_judge == 'first' and self.num_beams>1: + # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1: + # adjust the dimension of layer_output_text to bz*num_beams + layer_output_text = self.adjust_hidden_states_by_num_beams(layer_output_text) + + if self.layer_judge == 'mid' and self.num_beams > 1: + # layer_output_text [bz*num_beams, len, hidden_size] + beam_idx = layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + if self.layer_judge == 'last' and self.num_beams>1: + # select top1 for each sample among beams + # layer_output = (hidden_states, beam_scores, expert_route) + # layer_output & layer_output_text dimen_0 from bz*num_beams to bz + layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text) + + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4]) + # import pdb; pdb.set_trace() # 0107test + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, None, None, None, 0.0) + + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def adjust_attention_mask(self, attention_mask): + batch_size = attention_mask.shape[0] + tmp = attention_mask.unsqueeze(1).expand(batch_size, self.num_beams, 1, 1, attention_mask.shape[3]) + attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len]) + return attention_mask + + def adjust_hidden_states_by_num_beams(self, hidden_states): + batch_size, text_length, hidden_size = hidden_states.shape + tmp_text = hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size) + hidden_states = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768] + return hidden_states + + def route_moe_last_layer_top1(self, layer_output, layer_output_text): + batch_size = layer_output[0].shape[0] + raw_batch_size = int(batch_size / self.num_beams) + hidden_states, beam_scores, expert_route, beam_idx = layer_output[0], layer_output[1], layer_output[2], layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + scores = beam_scores.view(raw_batch_size, self.num_beams) + _, gate = torch.topk(scores, 1, dim=1) + selects = [ (bz_idx * self.num_beams + gate[bz_idx].item()) for bz_idx in range(raw_batch_size)] + + layer_output_text = layer_output_text[selects] + hidden_states_new = hidden_states[selects] + beam_scores_new = beam_scores[selects] + expert_route_new = expert_route[selects] + + return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text + + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route): + if not self.use_experts: + hidden_states = self.experts(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + return layer_output, None, None, None, 0.0 + + hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts( + attention_output, expert_attention_mask, beam_scores, expert_route + ) + if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1: + attention_output = self.adjust_hidden_states_by_num_beams(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + + return layer_output, beam_scores, expert_route, beam_idx, importance_loss + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + beam_scores=None + expert_route=None + importance_loss = 0 + for i in range(self.config.num_hidden_layers): + + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length, beam_scores, expert_route + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + beam_scores, # None + expert_route, # None + ) + hidden_states = layer_outputs[0][0] + beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1] + expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2] + importance_loss += layer_outputs[0][4] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + beam_scores=beam_scores, + expert_route=expert_route, + gate_loss=importance_loss, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + beam_scores=encoder_outputs.beam_scores, + expert_route=encoder_outputs.expert_route, + gate_loss=encoder_outputs.gate_loss + ) + + +class BertMoERouteLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + # gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerRouteMoELN.py b/minigpt4/models/QformerRouteMoELN.py new file mode 100644 index 0000000..62473fd --- /dev/null +++ b/minigpt4/models/QformerRouteMoELN.py @@ -0,0 +1,1367 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts_route, + moe_layer_judge, +) +from minigpt4.models.moe.route_moe_layer import RouteMoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1 + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Move LayerNorm & ResNet out of FFN After MoEFFN + hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1 + return hidden_states + + +class FeedForward(nn.Module): + def __init__(self, config): + nn.Module.__init__(self) + # first layer + self.intermediate_query = BertIntermediate(config) + # second layer + self.output_query = BertOutput(config) + + def forward(self, hidden_states: Tensor): + input_tensor = hidden_states + intermediate_output = self.intermediate_query(hidden_states) + hidden_states = self.output_query(intermediate_output, input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts_route(layer_num) + self.layer_judge = moe_layer_judge(layer_num) + self.num_beams = config.moebert_num_beams + ffn = FeedForward(config) + + if self.use_experts: + self.experts = RouteMoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = self.layer_judge, + route_method=config.route_method, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + beam_scores=None, + expert_route=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # import pdb; pdb.set_trace() # 0107test + + # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same + if self.num_beams > 1: + if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams: + # attention_mask dimension to be bz*num_beams + attention_mask = self.adjust_attention_mask(attention_mask) + encoder_attention_mask = self.adjust_attention_mask(encoder_attention_mask) + + if hidden_states.shape[0]*self.num_beams == attention_mask.shape[0]: + # attention_mask dimension back to bz + batch_size = attention_mask.shape[0] + attention_mask = attention_mask[[ i for i in range(0, batch_size, self.num_beams)]] + + if hidden_states.shape[0] == encoder_hidden_states.shape[0]*self.num_beams: + batch_size, visual_tokens, vision_dim = encoder_hidden_states.shape + tmp = encoder_hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, visual_tokens, vision_dim ) + encoder_hidden_states = tmp.contiguous().view(batch_size* self.num_beams, visual_tokens, vision_dim) # torch.Size([bz, 257, 1408]) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + # add moe query ffn + # query_attention_output size: [bz, query_length+seq_len, 768] + # attention_mask size: [bz, 1, 1, query_length+seq_len] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route) + # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) + # import pdb; pdb.set_trace() # 0107test + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + if self.layer_judge == 'first' and self.num_beams>1: + # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1: + # adjust the dimension of layer_output_text to bz*num_beams + layer_output_text = self.adjust_layer_output_text(layer_output_text) + + if self.layer_judge == 'mid' and self.num_beams > 1: + # layer_output_text [bz*num_beams, len, hidden_size] + beam_idx = layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + if self.layer_judge == 'last' and self.num_beams>1: + # select top1 for each sample among beams + # layer_output = (hidden_states, beam_scores, expert_route) + # layer_output & layer_output_text dimen_0 from bz*num_beams to bz + layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text) + + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4]) + # import pdb; pdb.set_trace() # 0107test + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, None, None, None, 0.0) + + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def adjust_attention_mask(self, attention_mask): + batch_size = attention_mask.shape[0] + tmp = attention_mask.unsqueeze(1).expand(batch_size, self.num_beams, 1, 1, attention_mask.shape[3]) + attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len]) + return attention_mask + + def adjust_layer_output_text(self, layer_output_text): + batch_size, text_length, hidden_size = layer_output_text.shape + tmp_text = layer_output_text.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size) + layer_output_text = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768] + return layer_output_text + + def route_moe_last_layer_top1(self, layer_output, layer_output_text): + batch_size = layer_output[0].shape[0] + raw_batch_size = int(batch_size / self.num_beams) + hidden_states, beam_scores, expert_route, beam_idx = layer_output[0], layer_output[1], layer_output[2], layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + scores = beam_scores.view(raw_batch_size, self.num_beams) + _, gate = torch.topk(scores, 1, dim=1) + selects = [ (bz_idx * self.num_beams + gate[bz_idx].item()) for bz_idx in range(raw_batch_size)] + + layer_output_text = layer_output_text[selects] + hidden_states_new = hidden_states[selects] + beam_scores_new = beam_scores[selects] + expert_route_new = expert_route[selects] + + return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text + + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + # layer_output = self.LayerNorm(layer_output + attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route): + if not self.use_experts: + layer_output = self.experts(attention_output) + # layer_output = self.LayerNorm(layer_output + attention_output) + return layer_output, None, None, None, 0.0 + + layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts( + attention_output, expert_attention_mask, beam_scores, expert_route + ) + + # layer_output = self.LayerNorm(layer_output + attention_output) + return layer_output, beam_scores, expert_route, beam_idx, importance_loss + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + beam_scores=None + expert_route=None + importance_loss = 0 + for i in range(self.config.num_hidden_layers): + + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length, beam_scores, expert_route + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + beam_scores, # None + expert_route, # None + ) + hidden_states = layer_outputs[0][0] + beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1] + expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2] + importance_loss += layer_outputs[0][4] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + beam_scores=beam_scores, + expert_route=expert_route, + gate_loss=importance_loss, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + beam_scores=encoder_outputs.beam_scores, + expert_route=encoder_outputs.expert_route, + gate_loss=encoder_outputs.gate_loss + ) + + +class BertMoERouteLMHeadModelLNIn(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + # gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerRouteMoELNUni.py b/minigpt4/models/QformerRouteMoELNUni.py new file mode 100644 index 0000000..6fbbd06 --- /dev/null +++ b/minigpt4/models/QformerRouteMoELNUni.py @@ -0,0 +1,1383 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts_route, + moe_layer_judge, +) +from minigpt4.models.moe.uniroute_moe_layer import UniRouteMoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1 + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Move LayerNorm & ResNet out of FFN After MoEFFN + hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1 + return hidden_states + + +class FeedForward(nn.Module): + def __init__(self, config): + nn.Module.__init__(self) + # first layer + self.intermediate_query = BertIntermediate(config) + # second layer + self.output_query = BertOutput(config) + + def forward(self, hidden_states: Tensor): + input_tensor = hidden_states + intermediate_output = self.intermediate_query(hidden_states) + hidden_states = self.output_query(intermediate_output, input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts_route(layer_num) + self.layer_judge = moe_layer_judge(layer_num) + self.num_beams = config.moebert_num_beams + ffn = FeedForward(config) + + if self.use_experts: + self.experts = UniRouteMoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = self.layer_judge, + route_method=config.route_method, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + beam_scores=None, + expert_route=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # import pdb; pdb.set_trace() # 0107test + + # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same + if self.num_beams > 1: + if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams: + # attention_mask dimension to be bz*num_beams + attention_mask = self.adjust_attention_mask(attention_mask) + encoder_attention_mask = self.adjust_attention_mask(encoder_attention_mask) + + if hidden_states.shape[0]*self.num_beams == attention_mask.shape[0]: + # attention_mask dimension back to bz + batch_size = attention_mask.shape[0] + attention_mask = attention_mask[[ i for i in range(0, batch_size, self.num_beams)]] + + if hidden_states.shape[0] == encoder_hidden_states.shape[0]*self.num_beams: + batch_size, visual_tokens, vision_dim = encoder_hidden_states.shape + tmp = encoder_hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, visual_tokens, vision_dim ) + encoder_hidden_states = tmp.contiguous().view(batch_size* self.num_beams, visual_tokens, vision_dim) # torch.Size([bz, 257, 1408]) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + # add moe query ffn + # query_attention_output size: [bz, query_length+seq_len, 768] + # attention_mask size: [bz, 1, 1, query_length+seq_len] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route) + # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) + import pdb; pdb.set_trace() # 0107test + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + if self.layer_judge == 'first' and self.num_beams>1: + # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1: + # adjust the dimension of layer_output_text to bz*num_beams + layer_output_text = self.adjust_layer_output_text(layer_output_text) + + if self.layer_judge == 'mid' and self.num_beams > 1: + # layer_output_text [bz*num_beams, len, hidden_size] + beam_idx = layer_output[3] + bz = layer_output[0].shape[0] / self.num_beams + num_expert_beam = self.num_beams-1 + select_idx = list() + for i in range(bz): + for k in range(num_expert_beam): + raw_beam_idx = beam_idx[i*num_expert_beam+k].item() + select_idx += [raw_beam_idx+i] + select_idx += [(i*self.num_beams)+num_expert_beam] + layer_output_text = layer_output_text[select_idx] + + if self.layer_judge == 'last' and self.num_beams>1: + # combine output of universal expert and route experts + # layer_output = (hidden_states, beam_scores, expert_route) + # layer_output & layer_output_text dimen_0 from bz*num_beams to bz + layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text) + + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4]) + # import pdb; pdb.set_trace() # 0107test + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, None, None, None, 0.0) + + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def adjust_attention_mask(self, attention_mask): + batch_size = attention_mask.shape[0] + tmp = attention_mask.unsqueeze(1).expand(batch_size, self.num_beams, 1, 1, attention_mask.shape[3]) + attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len]) + return attention_mask + + def adjust_layer_output_text(self, layer_output_text): + batch_size, text_length, hidden_size = layer_output_text.shape + tmp_text = layer_output_text.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size) + layer_output_text = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768] + return layer_output_text + + def route_moe_last_layer_top1(self, layer_output, layer_output_text): + batch_size = int(layer_output[0].shape[0] / self.num_beams) + hidden_states, beam_scores, expert_route, beam_idx = layer_output[0], layer_output[1], layer_output[2], layer_output[3] + num_route_experts = self.num_beams-1 + + ### universal expert + select_universal = [i*self.num_beams+num_route_experts for i in range(batch_size)] + universal_output = hidden_states[select_universal] # [bz, 32, 768] + layer_output_text_universal = layer_output_text[select_universal] # [bz, 32, 768] + + ## route expert + select_expert = [ x for x in range(batch_size*self.num_beams) if x not in select_universal] # length = bz * num_route_experts + route_expert_output = hidden_states[select_expert] # [bz*num_route_experts, 32, 768] + scores = beam_scores.view(batch_size, num_route_experts) + _, gate = torch.topk(scores, 1, dim=1) + selects = [ (bz_idx * num_route_experts + gate[bz_idx].item()) for bz_idx in range(batch_size)] + layer_output_text_route = layer_output_text[select_expert][beam_idx] + + layer_output_text_route = layer_output_text_route[selects] + route_expert_hidden = route_expert_output[selects] + beam_scores_new = beam_scores[selects] + expert_route_new = expert_route[selects] + + # combine universal & route experts + final_hidden_states = universal_output + route_expert_hidden + final_layer_text = layer_output_text_universal + layer_output_text_route + + return (final_hidden_states, beam_scores_new, expert_route_new, beam_idx, layer_output[4]), final_layer_text + + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route): + if not self.use_experts: + layer_output = self.experts(attention_output) + return layer_output, None, None, None, 0.0 + + layer_output, beam_scores, expert_route, beam_idx, importance_loss = self.experts( + attention_output, expert_attention_mask, beam_scores, expert_route + ) + return layer_output, beam_scores, expert_route, beam_idx, importance_loss + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + beam_scores=None + expert_route=None + importance_loss = 0 + for i in range(self.config.num_hidden_layers): + + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length, beam_scores, expert_route + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + beam_scores, # None + expert_route, # None + ) + hidden_states = layer_outputs[0][0] + beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1] + expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2] + importance_loss += layer_outputs[0][4] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + beam_scores=beam_scores, + expert_route=expert_route, + gate_loss=importance_loss, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + beam_scores=encoder_outputs.beam_scores, + expert_route=encoder_outputs.expert_route, + gate_loss=encoder_outputs.gate_loss + ) + + +class BertMoERouteLMHeadModelLNInUniversal(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + # gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerRouteMoEUni.py b/minigpt4/models/QformerRouteMoEUni.py new file mode 100644 index 0000000..cb29d41 --- /dev/null +++ b/minigpt4/models/QformerRouteMoEUni.py @@ -0,0 +1,1394 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts_route, + moe_layer_judge, +) +from minigpt4.models.moe.uniroute_moe_layer import UniRouteMoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1 + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Move LayerNorm & ResNet out of FFN After MoEFFN + hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1 + return hidden_states + + +class FeedForward(nn.Module): + # remove LayerNorm + def __init__(self, config): + super().__init__() + self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2 + # self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2 + + def forward(self, hidden_states: Tensor): + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts_route(layer_num) + self.layer_judge = moe_layer_judge(layer_num) + self.num_beams = config.moebert_num_beams + ffn = FeedForward(config) + if self.use_experts: + self.experts = UniRouteMoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = self.layer_judge, + route_method=config.route_method, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + beam_scores=None, + expert_route=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # import pdb; pdb.set_trace() # 0107test + + # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same + if self.num_beams > 1: + if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams: + # attention_mask dimension to be bz*num_beams + attention_mask = self.adjust_attention_mask(attention_mask) + encoder_attention_mask = self.adjust_attention_mask(encoder_attention_mask) + + if hidden_states.shape[0]*self.num_beams == attention_mask.shape[0]: + # attention_mask dimension back to bz + batch_size = attention_mask.shape[0] + attention_mask = attention_mask[[ i for i in range(0, batch_size, self.num_beams)]] + + if hidden_states.shape[0] == encoder_hidden_states.shape[0]*self.num_beams: + batch_size, visual_tokens, vision_dim = encoder_hidden_states.shape + tmp = encoder_hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, visual_tokens, vision_dim ) + encoder_hidden_states = tmp.contiguous().view(batch_size* self.num_beams, visual_tokens, vision_dim) # torch.Size([bz, 257, 1408]) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + # add moe query ffn + # query_attention_output size: [bz, query_length+seq_len, 768] + # attention_mask size: [bz, 1, 1, query_length+seq_len] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route) + # layer_output = (layer_output, beam_scores, expert_route, beam_idx, importance_loss) + # import pdb; pdb.set_trace() # 0107test + + if attention_output.shape[1] > query_length: # have text input in Qformer + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + if self.layer_judge == 'first' and self.num_beams>1: + # if layer_output[0].shape[0] == layer_output_text.shape[0]*self.num_beams and self.num_beams>1: + # adjust the dimension of layer_output_text to bz*num_beams + layer_output_text = self.adjust_hidden_states_by_num_beams(layer_output_text) + + if self.layer_judge == 'mid' and self.num_beams > 1: + # layer_output_text [bz*num_beams, len, hidden_size] + beam_idx = layer_output[3] + bz = int(layer_output[0].shape[0] / self.num_beams) + num_expert_beam = self.num_beams-1 + select_idx = list() + for i in range(bz): + for k in range(num_expert_beam): + raw_beam_idx = beam_idx[i*num_expert_beam+k].item() + select_idx += [raw_beam_idx+i] + select_idx += [(i*self.num_beams)+num_expert_beam] + layer_output_text = layer_output_text[select_idx] + + if self.layer_judge == 'last' and self.num_beams>1: + # select top1 for each sample among beams + # layer_output = (hidden_states, beam_scores, expert_route) + # layer_output & layer_output_text dimen_0 from bz*num_beams to bz + layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text) + + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4]) + # import pdb; pdb.set_trace() # 0107test + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, None, None, None, 0.0) + + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def adjust_attention_mask(self, attention_mask): + batch_size = attention_mask.shape[0] + tmp = attention_mask.unsqueeze(1).expand(batch_size, self.num_beams, 1, 1, attention_mask.shape[3]) + attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len]) + return attention_mask + + def adjust_hidden_states_by_num_beams(self, hidden_states): + batch_size, text_length, hidden_size = hidden_states.shape + tmp_text = hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size) + hidden_states = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768] + return hidden_states + + def route_moe_last_layer_top1(self, layer_output, layer_output_text): + batch_size = int(layer_output[0].shape[0] / self.num_beams) + hidden_states, beam_scores, expert_route, beam_idx = layer_output[0], layer_output[1], layer_output[2], layer_output[3] + num_route_experts = self.num_beams-1 + + ### universal expert + select_universal = [i*self.num_beams+num_route_experts for i in range(batch_size)] + universal_output = hidden_states[select_universal] # [bz, 32, 768] + layer_output_text_universal = layer_output_text[select_universal] # [bz, 32, 768] + + ## route expert + select_expert = [ x for x in range(batch_size*self.num_beams) if x not in select_universal] # length = bz * num_route_experts + route_expert_output = hidden_states[select_expert] # [bz*num_route_experts, 32, 768] + scores = beam_scores.view(batch_size, num_route_experts) + _, gate = torch.topk(scores, 1, dim=1) + selects = [ (bz_idx * num_route_experts + gate[bz_idx].item()) for bz_idx in range(batch_size)] + layer_output_text_route = layer_output_text[select_expert][beam_idx] + + layer_output_text_route = layer_output_text_route[selects] + route_expert_hidden = route_expert_output[selects] + beam_scores_new = beam_scores[selects] + expert_route_new = expert_route[selects] + + # combine universal & route experts + final_hidden_states = universal_output + route_expert_hidden + final_layer_text = layer_output_text_universal + layer_output_text_route + # import pdb; pdb.set_trace() # 0107test + + return (final_hidden_states, beam_scores_new, expert_route_new, beam_idx, layer_output[4]), final_layer_text + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route): + if not self.use_experts: + hidden_states = self.experts(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + return layer_output, None, None, None, 0.0 + + hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts( + attention_output, expert_attention_mask, beam_scores, expert_route + ) + if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1: + attention_output = self.adjust_hidden_states_by_num_beams(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + + return layer_output, beam_scores, expert_route, beam_idx, importance_loss + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + beam_scores=None + expert_route=None + importance_loss = 0 + for i in range(self.config.num_hidden_layers): + + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length, beam_scores, expert_route + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + beam_scores, # None + expert_route, # None + ) + hidden_states = layer_outputs[0][0] + beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1] + expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2] + importance_loss += layer_outputs[0][4] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + beam_scores=beam_scores, + expert_route=expert_route, + gate_loss=importance_loss, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + beam_scores=encoder_outputs.beam_scores, + expert_route=encoder_outputs.expert_route, + gate_loss=encoder_outputs.gate_loss + ) + + +class BertMoERouteLMHeadModelUniversal(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + # gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/QformerRouteMoEUniback.py b/minigpt4/models/QformerRouteMoEUniback.py new file mode 100644 index 0000000..06fb4d3 --- /dev/null +++ b/minigpt4/models/QformerRouteMoEUniback.py @@ -0,0 +1,1322 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" + +import math +import os +import warnings +import copy +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import logging +from transformers.models.bert.configuration_bert import BertConfig + +from minigpt4.models.moe.utils import ( + MoEModelOutput, + MoEModelOutputWithPooling, + use_experts_route, + moe_layer_judge, +) +from minigpt4.models.moe.route_moe_layer import RouteMoELayer + +logging.set_verbosity_error() # ignore warning : Some weights of BertLMHeadModel were not initialized from the model checkpoint... +logger = logging.get_logger(__name__) + +# from visualizer import get_local + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads # 12 + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) # 64 + self.all_head_size = self.num_attention_heads * self.attention_head_size # 768 + + self.query = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + self.value = nn.Linear(config.encoder_width, self.all_head_size) # nn.Linear(1408, 768) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + self.value = nn.Linear(config.hidden_size, self.all_head_size) # nn.Linear(768, 768) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) # torch.Size([1, 257, 12, 64]) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) # encoder_hidden_states:[bz,257,1408], torch.Size([1, 12, 257, 64]) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) # torch.Size([1, 12, 257, 64]) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) # torch.Size([1, 12, 41, 64]) + + past_key_value = (key_layer, value_layer) # torch.Size([1, 12, 41, 257]) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + # extended_attention_mask + + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs_dropped = self.dropout(attention_probs) # torch.Size([1, 12, 41, 257]) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) # torch.Size([1, 12, 41, 64]) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) # torch.Size([1, 41, 768]) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class BertSelfOutput(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): # Add & Norm + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # 1 + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + # Move LayerNorm & ResNet out of FFN After MoEFFN + hidden_states = self.LayerNorm(hidden_states + input_tensor) # 1 + return hidden_states + + +class FeedForward(nn.Module): + # remove LayerNorm + def __init__(self, config): + super().__init__() + self.dense1 = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + self.dense2 = nn.Linear(config.intermediate_size, config.hidden_size) + self.dropout = nn.Dropout(config.hidden_dropout_prob) # adjust dropout ratio 0.1->0.2 + # self.dropout = nn.Dropout(0.2) # adjust dropout ratio 0.1->0.2 + + def forward(self, hidden_states: Tensor): + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class UniFeedForward(nn.Module): + def __init__(self, config): + nn.Module.__init__(self) + # first layer + self.intermediate_query = BertIntermediate(config) + # second layer + self.output_query = BertOutput(config) + + def forward(self, hidden_states: Tensor): + input_tensor = hidden_states + intermediate_output = self.intermediate_query(hidden_states) + hidden_states = self.output_query(intermediate_output, input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = BertAttention(config) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config) + + # Add MoE FFN + self.use_experts = use_experts_route(layer_num) + self.layer_judge = moe_layer_judge(layer_num) + self.num_beams = config.moebert_num_beams + ffn = FeedForward(config) + if self.use_experts: + self.experts = RouteMoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = self.layer_judge, + route_method=config.route_method, + weight_type=config.moe_weight_type, + ) + else: + self.experts = ffn + self.expert_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.universal_expert = UniFeedForward(config) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + beam_scores=None, + expert_route=None, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + # import pdb; pdb.set_trace() # 0107test + + # adjust the dimension of hidden_states, attention_mask, encoder_attention_mask and encoder_hidden_states to be the same + if self.num_beams > 1: + if hidden_states.shape[0]== attention_mask.shape[0]*self.num_beams: + # attention_mask dimension to be bz*num_beams + attention_mask = self.adjust_attention_mask(attention_mask) + encoder_attention_mask = self.adjust_attention_mask(encoder_attention_mask) + + if hidden_states.shape[0]*self.num_beams == attention_mask.shape[0]: + # attention_mask dimension back to bz + batch_size = attention_mask.shape[0] + attention_mask = attention_mask[[ i for i in range(0, batch_size, self.num_beams)]] + + if hidden_states.shape[0] == encoder_hidden_states.shape[0]*self.num_beams: + batch_size, visual_tokens, vision_dim = encoder_hidden_states.shape + tmp = encoder_hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, visual_tokens, vision_dim ) + encoder_hidden_states = tmp.contiguous().view(batch_size* self.num_beams, visual_tokens, vision_dim) # torch.Size([bz, 257, 1408]) + + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + + if attention_output.shape[1] > query_length: # have text input in Qformer + + + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + + if self.layer_judge == 'first' and self.num_beams>1: + bz = query_attention_output.shape[0] + moe_ffn_attention_input = query_attention_output[:, :query_length, :] # [bz, 32, hidden_dim] + moe_ffn_attention_mask = attention_mask.squeeze(dim=1).squeeze(dim=1)[:, :query_length] + layer_output = self.feed_forward_query_moe(moe_ffn_attention_input, moe_ffn_attention_mask, beam_scores, expert_route) # [bz*num_beam, 32, hidden_dim] + expert_output_hidden = layer_output[0] + universal_output = self.universal_expert(moe_ffn_attention_input)# [bz, 32, hidden_dim] + lst = list() + for i in range(bz): + lst.append(universal_output[i,:,:]) + lst.append(expert_output_hidden[i*self.num_beams:(i+1)*self.num_beams,:,:]) + uni_exp_output = torch.stack(lst) # [bz*(num_beam+1), 32, 768] + layer_output_text = self.adjust_hidden_states_by_num_beams(layer_output_text, add_universal=True) #[bz*(num_beam+1), len, hidden_dim] + layer_output = (torch.cat([uni_exp_output, layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4]) + # layer_output = (hidden_states, beam_scores, expert_route, beam_idx, importance_loss) + + if self.layer_judge == 'mid' and self.num_beams > 1: + # layer_output_text [bz*num_beams, len, hidden_size] + bz = query_attention_output.shape[0] / (self.num_beams+1) + + + + + beam_idx = layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + if self.layer_judge == 'last' and self.num_beams>1: + # select top1 for each sample among beams + # layer_output = (hidden_states, beam_scores, expert_route) + # layer_output & layer_output_text dimen_0 from bz*num_beams to bz + layer_output, layer_output_text = self.route_moe_last_layer_top1(layer_output, layer_output_text) + + layer_output = (torch.cat([layer_output[0], layer_output_text], dim=1), layer_output[1], layer_output[2], layer_output[3],layer_output[4]) + # import pdb; pdb.set_trace() # 0107test + + else: + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + layer_output = (layer_output, None, None, None, 0.0) + + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def adjust_attention_mask(self, attention_mask): + batch_size = attention_mask.shape[0] + tmp = attention_mask.unsqueeze(1).expand(batch_size, self.num_beams, 1, 1, attention_mask.shape[3]) + attention_mask = tmp.contiguous().view(batch_size* self.num_beams, 1, 1, attention_mask.shape[3]) # torch.Size([bz*num_beams, 1, 1, 32+input_len]) + return attention_mask + + def adjust_hidden_states_by_num_beams(self, hidden_states, add_universal=False): + batch_size, text_length, hidden_size = hidden_states.shape + if add_universal: + tmp_text = hidden_states.unsqueeze(1).expand(batch_size, self.num_beams+1, text_length, hidden_size) + else: + tmp_text = hidden_states.unsqueeze(1).expand(batch_size, self.num_beams, text_length, hidden_size) + hidden_states = tmp_text.contiguous().view(-1, text_length, hidden_size) # [bz*num_beams, text_length ,768] + return hidden_states + + def route_moe_last_layer_top1(self, layer_output, layer_output_text): + batch_size = layer_output[0].shape[0] + raw_batch_size = int(batch_size / self.num_beams) + hidden_states, beam_scores, expert_route, beam_idx = layer_output[0], layer_output[1], layer_output[2], layer_output[3] + layer_output_text = layer_output_text[beam_idx] + + scores = beam_scores.view(raw_batch_size, self.num_beams) + _, gate = torch.topk(scores, 1, dim=1) + selects = [ (bz_idx * self.num_beams + gate[bz_idx].item()) for bz_idx in range(raw_batch_size)] + + layer_output_text = layer_output_text[selects] + hidden_states_new = hidden_states[selects] + beam_scores_new = beam_scores[selects] + expert_route_new = expert_route[selects] + + return (hidden_states_new, beam_scores_new, expert_route_new, layer_output[3], layer_output[4]), layer_output_text + + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_query_moe(self, attention_output, expert_attention_mask, beam_scores, expert_route): + if not self.use_experts: + hidden_states = self.experts(attention_output) + layer_output = self.expert_ln(hidden_states + attention_output) + return layer_output, None, None, None, 0.0 + + input_tensor = attention_output + hidden_states, beam_scores, expert_route, beam_idx, importance_loss = self.experts( + attention_output, expert_attention_mask, beam_scores, expert_route + ) + if hidden_states.shape[0]==attention_output.shape[0]*self.num_beams and self.num_beams>1: + attention_output = self.adjust_hidden_states_by_num_beams(attention_output) + experts_output = self.expert_ln(hidden_states + attention_output) # [bz*num_beams, 32, hiddem_dim] + + universal_output = self.universal_expert(input_tensor) + layer_output = torch.concat((universal_output, experts_output)) + + return layer_output, beam_scores, expert_route, beam_idx, importance_loss + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + # @get_local('all_cross_attentions') + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + beam_scores=None + expert_route=None + importance_loss = 0 + for i in range(self.config.num_hidden_layers): + + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length, beam_scores, expert_route + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, #torch.Size([bz, 32+input_len, 768]) + attention_mask, # torch.Size([bz, 1, 1, 32+input_len]) + layer_head_mask, # None + encoder_hidden_states, # torch.Size([bz, 257, 1408]) + encoder_attention_mask, + past_key_value, + output_attentions, # False + query_length, # 32 + beam_scores, # None + expert_route, # None + ) + hidden_states = layer_outputs[0][0] + beam_scores = beam_scores if layer_outputs[0][1] == None else layer_outputs[0][1] + expert_route = expert_route if layer_outputs[0][2] == None else layer_outputs[0][2] + importance_loss += layer_outputs[0][4] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + + return MoEModelOutput( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + beam_scores=beam_scores, + expert_route=expert_route, + gate_loss=importance_loss, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. + """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return MoEModelOutputWithPooling( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + beam_scores=encoder_outputs.beam_scores, + expert_route=encoder_outputs.expert_route, + gate_loss=encoder_outputs.gate_loss + ) + + +class BertMoERouteLMHeadModelUniverse(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + # gate_loss = outputs.gate_loss + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss, total_loss = None, None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + total_loss = lm_loss + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=total_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + diff --git a/minigpt4/models/__init__.py b/minigpt4/models/__init__.py index bc01b56..6ba2006 100644 --- a/minigpt4/models/__init__.py +++ b/minigpt4/models/__init__.py @@ -14,6 +14,10 @@ from minigpt4.models.base_model import BaseModel from minigpt4.models.minigpt_base import MiniGPTBase from minigpt4.models.minigpt4 import MiniGPT4 from minigpt4.models.minigpt_v2 import MiniGPTv2 +from minigpt4.models.blip2_qformer import Blip2Qformer +from minigpt4.models.blip2_t5_instruct_pro_moe import Blip2T5InstructPromptMOE +from minigpt4.models.blip2_t5_instruct import Blip2T5InstructQformerMoE +from minigpt4.models.blip2_vicuna_instruct import Blip2VicunaInstruct from minigpt4.processors.base_processor import BaseProcessor @@ -22,7 +26,11 @@ __all__ = [ "BaseModel", "MiniGPTBase", "MiniGPT4", - "MiniGPTv2" + "MiniGPTv2", + "Blip2Qformer", + "Blip2T5InstructPromptMOE", + "Blip2T5InstructQformerMoE", + "Blip2VicunaInstruct" ] diff --git a/minigpt4/models/base_model.py b/minigpt4/models/base_model.py index d70ca18..0901603 100644 --- a/minigpt4/models/base_model.py +++ b/minigpt4/models/base_model.py @@ -242,7 +242,51 @@ class LayerNorm(nn.LayerNorm): ret = super().forward(x.type(torch.float32)) return ret.type(orig_type) - +def all_gather_with_grad(tensors): + """ + Performs all_gather operation on the provided tensors. + Graph remains connected for backward grad computation. + """ + # Queue the gathered tensors + world_size = torch.distributed.get_world_size() + # There is no need for reduction in the single-proc case + if world_size == 1: + return tensors + + # tensor_all = GatherLayer.apply(tensors) + tensor_all = GatherLayer.apply(tensors) + + return torch.cat(tensor_all, dim=0) + + +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + # if use distributed training + if not is_dist_avail_and_initialized(): + return tensor + + tensors_gather = [ + torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) + ] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output + + +def tile(x, dim, n_tile): + init_dim = x.size(dim) + repeat_idx = [1] * x.dim() + repeat_idx[dim] = n_tile + x = x.repeat(*(repeat_idx)) + order_index = torch.LongTensor( + np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]) + ) + return torch.index_select(x, dim, order_index.to(x.device)) diff --git a/minigpt4/models/blip2.py b/minigpt4/models/blip2.py new file mode 100644 index 0000000..51aa1ee --- /dev/null +++ b/minigpt4/models/blip2.py @@ -0,0 +1,483 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import contextlib +import logging +import os +import time +import datetime + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.nn.functional as F + +import minigpt4.common.dist_utils as dist_utils +from minigpt4.common.dist_utils import download_cached_file +from minigpt4.common.utils import is_url +from minigpt4.common.logger import MetricLogger +from minigpt4.models.base_model import BaseModel +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel +from minigpt4.models.QformerMoE import BertMoELMHeadModel +from minigpt4.models.QformerMoELN import BertMoELMHeadModelLNIn +from minigpt4.models.QformerRouteMoE import BertMoERouteLMHeadModel +from minigpt4.models.QformerRouteMoELN import BertMoERouteLMHeadModelLNIn +from minigpt4.models.QformerRouteMoELNUni import BertMoERouteLMHeadModelLNInUniversal +from minigpt4.models.QformerRouteMoEUni import BertMoERouteLMHeadModelUniversal +from minigpt4.models.eva_vit import create_eva_vit_g +from transformers import BertTokenizer +from peft import ( + LoraConfig, + get_peft_model, + prepare_model_for_int8_training, +) + +class Blip2Base(BaseModel): + @classmethod + def init_tokenizer(cls, truncation_side="right"): + tokenizer = BertTokenizer.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", truncation_side=truncation_side) + tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + return tokenizer + + def maybe_autocast(self, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = self.device != torch.device("cpu") + + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2): + encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = cross_attention_freq + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=encoder_config + ) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + + @classmethod + def init_RouteMoEQformerUni(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"): + moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + + moe_encoder_config.encoder_width = vision_width + moe_encoder_config.add_cross_attention = True + moe_encoder_config.cross_attention_freq = cross_attention_freq + moe_encoder_config.query_length = num_query_token + + moe_encoder_config.moebert_expert_num = moebert_expert_num + moe_encoder_config.moebert_num_beams = moebert_num_beams + moe_encoder_config.route_method = route_method + moe_encoder_config.moe_weight_type = moe_weight_type + + if ln_position == "out": + RouteMoEQformer = BertMoERouteLMHeadModelUniversal.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + elif ln_position == "in": + RouteMoEQformer = BertMoERouteLMHeadModelLNInUniversal.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, moe_encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=moe_encoder_config.initializer_range) + + return RouteMoEQformer, query_tokens + + + @classmethod + def init_RouteMoEQformer(cls, num_query_token, vision_width, moebert_expert_num, moebert_num_beams, route_method, moe_weight_type, cross_attention_freq=2, ln_position="out"): + moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + + moe_encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + moe_encoder_config.add_cross_attention = True + moe_encoder_config.cross_attention_freq = cross_attention_freq + moe_encoder_config.query_length = num_query_token + + moe_encoder_config.moebert_expert_num = moebert_expert_num + moe_encoder_config.moebert_num_beams = moebert_num_beams + moe_encoder_config.route_method = route_method + moe_encoder_config.moe_weight_type = moe_weight_type + + if ln_position == "out": + RouteMoEQformer = BertMoERouteLMHeadModel.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + elif ln_position == "in": + RouteMoEQformer = BertMoERouteLMHeadModelLNIn.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, moe_encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=moe_encoder_config.initializer_range) + + return RouteMoEQformer, query_tokens + + + @classmethod + def init_QformerMoE(cls, num_query_token, vision_width, moebert_expert_num, moebert_route_method, moebert_load_balance, moe_topk=1, use_balance_loss=True, moe_weight_type='l2_norm', cross_attention_freq=2,ln_position="out"): + moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + + moe_encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + moe_encoder_config.add_cross_attention = True + moe_encoder_config.cross_attention_freq = cross_attention_freq + moe_encoder_config.query_length = num_query_token + + moe_encoder_config.moebert_expert_num = moebert_expert_num + moe_encoder_config.moebert_route_method = moebert_route_method + moe_encoder_config.moebert_load_balance = moebert_load_balance + moe_encoder_config.moe_topk = moe_topk + moe_encoder_config.use_balance_loss = use_balance_loss + moe_encoder_config.moe_weight_type = moe_weight_type + + if ln_position == "out": + MoEQformer = BertMoELMHeadModel.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + elif ln_position == "in": + MoEQformer = BertMoELMHeadModelLNIn.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config + ) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, moe_encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=moe_encoder_config.initializer_range) + return MoEQformer, query_tokens + + def init_llm(cls, llama_model_path, freeze_llm=True, lora_r=0, + lora_target_modules=["q_proj","v_proj"], **lora_kargs): + logging.info('Loading LLAMA') + from transformers import LlamaTokenizer + from minigpt4.models.modeling_llama import LlamaForCausalLM + + llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model_path, use_fast=False) + # llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) + # llama_tokenizer.add_special_tokens({'bos_token': ''}) + # llama_tokenizer.add_special_tokens({'eos_token': ''}) + # llama_tokenizer.add_special_tokens({'unk_token': ''}) + + llama_tokenizer.pad_token = llama_tokenizer.unk_token + + llama_model = LlamaForCausalLM.from_pretrained( + llama_model_path, + torch_dtype=torch.float16, + ) + llama_model.resize_token_embeddings(len(llama_tokenizer)) + # self.eos_token_id = self.llm_tokenizer( + # self.llm_tokenizer.eos_token, add_special_tokens=False + # ).input_ids[0] + + if freeze_llm==False and lora_r > 0: + llama_model = prepare_model_for_int8_training(llama_model) + loraconfig = LoraConfig( + r=lora_r, + bias="none", + task_type="CAUSAL_LM", + target_modules=lora_target_modules, + **lora_kargs + ) + llama_model = get_peft_model(llama_model, loraconfig) + + llama_model.print_trainable_parameters() + + else: + for name, param in llama_model.named_parameters(): + param.requires_grad = False + logging.info('Loading LLAMA Done') + return llama_model, llama_tokenizer + + + def init_vision_encoder( + self, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision, freeze_vit=True + ): + assert model_name in [ + "eva_clip_g", + "eva2_clip_L", + "clip_L", + ], "vit model must be eva_clip_g, eva2_clip_L or clip_L" + if model_name == "eva_clip_g": + visual_encoder = create_eva_vit_g( + img_size, drop_path_rate, use_grad_checkpoint, precision + ) +# elif model_name == "eva2_clip_L": +# visual_encoder = create_eva2_vit_L( +# img_size, drop_path_rate, use_grad_checkpoint, precision +# ) + elif model_name == "clip_L": + from minigpt4.models.clip_vit import create_clip_vit_L + visual_encoder = create_clip_vit_L(img_size, use_grad_checkpoint, precision) + ln_vision = LayerNorm(visual_encoder.num_features) + self.vit_name = model_name + pytorch_total_params = sum(p.numel() for p in visual_encoder.parameters()) + print(f'{model_name} clip vit params:') + print(f"{pytorch_total_params * 1e-9:.2} B") + + if freeze_vit: + for name, param in visual_encoder.named_parameters(): + param.requires_grad = False + visual_encoder = visual_encoder.eval() + visual_encoder.train = disabled_train + # freeze ln vision + # for name, param in ln_vision.named_parameters(): + # param.requires_grad = False + # ln_vision = ln_vision.eval() + # ln_vision.train = disabled_train + logging.info("freeze vision encoder but not ln_vision") + + return visual_encoder, ln_vision + + def mean_pool_adjust_query_tokens(self, state_dict, num_query_token): + group = 32 // num_query_token + query_tokens = state_dict['query_tokens'].view(1,num_query_token,group,768) + state_dict['query_tokens'] = torch.mean(query_tokens, dim=2) + return state_dict + + def load_from_pretrained(self, url_or_filename, num_query_token=32): + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location="cpu") + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location="cpu") + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + # state_dict = self.mean_pool_adjust_query_tokens(state_dict, num_query_token) + + msg = self.load_state_dict(state_dict, strict=False) + + # logging.info("Missing keys {}".format(msg.missing_keys)) + logging.info("load checkpoint from %s" % url_or_filename) + + return msg + + def get_optimizer_params(self, weight_decay, lr_scale=1): + + vit_num_layers = self.visual_encoder.get_num_layer() + lr_scales = list(lr_scale ** (vit_num_layers + 1 - i) for i in range(vit_num_layers + 2)) + + parameter_group_names = {} + parameter_group_vars = {} + + for name, param in self.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias"): + group_name = "no_decay" + this_weight_decay = 0. + else: + group_name = "decay" + this_weight_decay = weight_decay + if 'visual_encoder' in name: + layer_id = self.visual_encoder.get_num_layer(name.replace('visual_encoder.','')) + group_name = "vit_layer_%d_%s" % (layer_id, group_name) + else: + layer_id = None + + if group_name not in parameter_group_names: + if layer_id is not None: + scale = lr_scales[layer_id] + else: + scale = 1 + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + # import json + # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + optim_params = list(parameter_group_vars.values()) + return optim_params + + def _lemmatize(self, answers): + def apply(answer): + doc = self.lemmatizer(answer) + + words = [] + for token in doc: + if token.pos_ in ["NOUN", "VERB"]: + words.append(token.lemma_) + else: + words.append(token.text) + answer = " ".join(words) + + return answer + + return [apply(answer) for answer in answers] + + @property + def lemmatizer(self): + if self._lemmatizer is None: + try: + import spacy + + self._lemmatizer = spacy.load("en_core_web_sm") + except ImportError: + logging.error( + """ + Please install spacy and en_core_web_sm model to apply lemmatization. + python -m spacy download en_core_web_sm + OR + import spacy.cli + spacy.cli.download("en_core_web_sm") + """ + ) + exit(1) + + return self._lemmatizer + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +def compute_sim_matrix(model, data_loader, **kwargs): + k_test = kwargs.pop("k_test") + + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation:" + + logging.info("Computing features for evaluation...") + start_time = time.time() + + texts = data_loader.dataset.text + num_text = len(texts) + text_bs = 256 + text_ids = [] + text_embeds = [] + text_atts = [] + for i in range(0, num_text, text_bs): + text = texts[i : min(num_text, i + text_bs)] + text_input = model.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=35, + return_tensors="pt", + ).to(model.device) + text_feat = model.forward_text(text_input) + text_embed = F.normalize(model.text_proj(text_feat)) + text_embeds.append(text_embed) + text_ids.append(text_input.input_ids) + text_atts.append(text_input.attention_mask) + + text_embeds = torch.cat(text_embeds, dim=0) + text_ids = torch.cat(text_ids, dim=0) + text_atts = torch.cat(text_atts, dim=0) + + vit_feats = [] + image_embeds = [] + for samples in data_loader: + image = samples["image"] + + image = image.to(model.device) + image_feat, vit_feat = model.forward_image(image) + image_embed = model.vision_proj(image_feat) + image_embed = F.normalize(image_embed, dim=-1) + + vit_feats.append(vit_feat.cpu()) + image_embeds.append(image_embed) + + vit_feats = torch.cat(vit_feats, dim=0) + image_embeds = torch.cat(image_embeds, dim=0) + + sims_matrix = [] + for image_embed in image_embeds: + sim_q2t = image_embed @ text_embeds.t() + sim_i2t, _ = sim_q2t.max(0) + sims_matrix.append(sim_i2t) + sims_matrix = torch.stack(sims_matrix, dim=0) + + score_matrix_i2t = torch.full( + (len(data_loader.dataset.image), len(texts)), -100.0 + ).to(model.device) + + num_tasks = dist_utils.get_world_size() + rank = dist_utils.get_rank() + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[topk_idx], + text_atts=text_atts[topk_idx], + ).float() + score_matrix_i2t[start + i, topk_idx] = score + topk_sim + + sims_matrix = sims_matrix.t() + score_matrix_t2i = torch.full( + (len(texts), len(data_loader.dataset.image)), -100.0 + ).to(model.device) + + step = sims_matrix.size(0) // num_tasks + 1 + start = rank * step + end = min(sims_matrix.size(0), start + step) + + for i, sims in enumerate( + metric_logger.log_every(sims_matrix[start:end], 50, header) + ): + topk_sim, topk_idx = sims.topk(k=k_test, dim=0) + image_inputs = vit_feats[topk_idx.cpu()].to(model.device) + score = model.compute_itm( + image_inputs=image_inputs, + text_ids=text_ids[start + i].repeat(k_test, 1), + text_atts=text_atts[start + i].repeat(k_test, 1), + ).float() + score_matrix_t2i[start + i, topk_idx] = score + topk_sim + + if dist_utils.is_dist_avail_and_initialized(): + dist.barrier() + torch.distributed.all_reduce( + score_matrix_i2t, op=torch.distributed.ReduceOp.SUM + ) + torch.distributed.all_reduce( + score_matrix_t2i, op=torch.distributed.ReduceOp.SUM + ) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Evaluation time {}".format(total_time_str)) + + return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy() diff --git a/minigpt4/models/blip2_outputs.py b/minigpt4/models/blip2_outputs.py new file mode 100644 index 0000000..9d18ddc --- /dev/null +++ b/minigpt4/models/blip2_outputs.py @@ -0,0 +1,116 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from dataclasses import dataclass +from typing import Optional + +import torch +from transformers.modeling_outputs import ( + ModelOutput, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) + + +@dataclass +class BlipSimilarity(ModelOutput): + sim_i2t: torch.FloatTensor = None + sim_t2i: torch.FloatTensor = None + + sim_i2t_m: Optional[torch.FloatTensor] = None + sim_t2i_m: Optional[torch.FloatTensor] = None + + sim_i2t_targets: Optional[torch.FloatTensor] = None + sim_t2i_targets: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipIntermediateOutput(ModelOutput): + """ + Data class for intermediate outputs of BLIP models. + + image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim). + text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim). + + image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim). + text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim). + + encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder. + encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs. + + decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder. + decoder_labels (torch.LongTensor): labels for the captioning loss. + + itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2). + itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,) + + """ + + # uni-modal features + image_embeds: torch.FloatTensor = None + text_embeds: Optional[torch.FloatTensor] = None + + image_embeds_m: Optional[torch.FloatTensor] = None + text_embeds_m: Optional[torch.FloatTensor] = None + + # intermediate outputs of multimodal encoder + encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None + + itm_logits: Optional[torch.FloatTensor] = None + itm_labels: Optional[torch.LongTensor] = None + + # intermediate outputs of multimodal decoder + decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None + decoder_labels: Optional[torch.LongTensor] = None + + +@dataclass +class BlipOutput(ModelOutput): + # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional. + sims: Optional[BlipSimilarity] = None + + intermediate_output: BlipIntermediateOutput = None + + loss: Optional[torch.FloatTensor] = None + + loss_itc: Optional[torch.FloatTensor] = None + + loss_itm: Optional[torch.FloatTensor] = None + + loss_lm: Optional[torch.FloatTensor] = None + + +@dataclass +class BlipOutputWithLogits(BlipOutput): + logits: torch.FloatTensor = None + logits_m: torch.FloatTensor = None + + +@dataclass +class BlipOutputFeatures(ModelOutput): + """ + Data class of features from BlipFeatureExtractor. + + Args: + image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional + image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional + text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional + text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional + + The first embedding or feature is for the [CLS] token. + + Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space. + """ + + image_embeds: Optional[torch.FloatTensor] = None + image_embeds_proj: Optional[torch.FloatTensor] = None + + text_embeds: Optional[torch.FloatTensor] = None + text_embeds_proj: Optional[torch.FloatTensor] = None + + multimodal_embeds: Optional[torch.FloatTensor] = None diff --git a/minigpt4/models/blip2_qformer.py b/minigpt4/models/blip2_qformer.py new file mode 100644 index 0000000..6abb26c --- /dev/null +++ b/minigpt4/models/blip2_qformer.py @@ -0,0 +1,538 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import logging + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.cuda.amp import autocast as autocast +from torch.nn import functional as F + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import all_gather_with_grad, concat_all_gather +from minigpt4.models.blip2 import ( + Blip2Base, + compute_sim_matrix, + disabled_train, +) +from minigpt4.models.blip2_outputs import BlipOutput, BlipOutputFeatures + + +@registry.register_model("blip2") +@registry.register_model("blip2_feature_extractor") +class Blip2Qformer(Blip2Base): + """ + BLIP2 first-stage model with Q-former and ViT. + Supported model types: + - pretrained: pretrained model with vit-g + - pretrain_vitL: pretrained model with vit-large + - coco: fintuned model on coco + Usage: + >>> from minigpt4.models import load_model + >>> model = load_model("blip2", "pretrain") + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain": "configs/models/blip2/blip2_pretrain.yaml", + "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml", + "coco": "configs/models/blip2/blip2_coco.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + num_query_token=32, + cross_attention_freq=2, + embed_dim=256, + max_txt_len=32, + ): + super().__init__() + + self.tokenizer = self.init_tokenizer() + + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + logging.info("freeze vision encoder") + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features, cross_attention_freq + ) + self.Qformer.resize_token_embeddings(len(self.tokenizer)) + state_dict = self.Qformer.state_dict() + for name, param in self.Qformer.named_parameters(): + if "_query" in name: + # bert.encoder.layer.10.intermediate_query.dense.weight / bias + # bert.encoder.layer.10.output_query.dense.weight / bias + # bert.encoder.layer.10.output_query.LayerNorm.weight / bias + key_orig = name.replace("_query", "") + param.data.copy_(state_dict[key_orig]) # copy state_dict[key_orig] to param + + self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) + self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim) + + self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2) + + self.temp = nn.Parameter(0.07 * torch.ones([])) + + self.max_txt_len = max_txt_len + + def forward(self, samples): + image = samples["image"] + text = samples["text_input"] + + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + use_cache=True, + return_dict=True, + ) + + image_feats = F.normalize( + self.vision_proj(query_output.last_hidden_state), dim=-1 + ) + + text_tokens = self.tokenizer( + text, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + text_output = self.Qformer.bert( + text_tokens.input_ids, + attention_mask=text_tokens.attention_mask, + return_dict=True, + ) + text_feat = F.normalize( + self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1 + ) + + ###============== Image-text Contrastive ===================### + image_feats_all = concat_all_gather( + image_feats + ) # [batch_size*num_gpu, num_query_tokens, embed_dim] + text_feat_all = concat_all_gather(text_feat) # [batch_size*num_gpu, embed_dim] + + sim_q2t = torch.matmul( + image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1) + ).squeeze() + # [batch_size, batch_size*num_gpu, num_query_tokens] + + # image-text similarity: aggregate across all query tokens + sim_i2t, _ = sim_q2t.max(-1) + sim_i2t = sim_i2t / self.temp + + # text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens] + sim_t2q = torch.matmul( + text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1) + ).squeeze() + + # text-image similarity: aggregate across all query tokens + sim_t2i, _ = sim_t2q.max(-1) + sim_t2i = sim_t2i / self.temp # [batch_size, batch_size*num_gpu] + + rank = dist.get_rank() + bs = image.size(0) + targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to( + image.device + ) + + if "image_id" in samples.keys(): #coco retrieval finetuning + image_ids = samples["image_id"].view(-1,1) + image_ids_all = concat_all_gather(image_ids) + pos_idx = torch.eq(image_ids, image_ids_all.t()).float() + sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) + sim_targets = 0.9 * sim_targets + 0.1 * torch.ones_like(sim_targets) / sim_targets.size(1) + + loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean() + loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean() + loss_itc = (loss_t2i+loss_i2t)/2 + else: + loss_itc = ( + F.cross_entropy(sim_i2t, targets, label_smoothing=0.1) + + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1) + ) / 2 + + ###============== Image-text Matching ===================### + text_input_ids_world = concat_all_gather(text_tokens.input_ids) + text_attention_mask_world = concat_all_gather(text_tokens.attention_mask) + image_embeds_world = all_gather_with_grad(image_embeds) + with torch.no_grad(): + if "image_id" in samples.keys(): + mask = torch.eq(image_ids, image_ids_all.t()) + sim_t2i.masked_fill_(mask, -10000) + sim_i2t.masked_fill_(mask, -10000) + else: + sim_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000) + sim_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000) + + weights_t2i = F.softmax(sim_t2i, dim=1) + weights_i2t = F.softmax(sim_i2t, dim=1) + + # select a negative image for each text + image_embeds_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_t2i[b], 1).item() + image_embeds_neg.append(image_embeds_world[neg_idx]) + image_embeds_neg = torch.stack(image_embeds_neg, dim=0) + + # select a negative text for each image + text_ids_neg = [] + text_atts_neg = [] + for b in range(bs): + neg_idx = torch.multinomial(weights_i2t[b], 1).item() + text_ids_neg.append(text_input_ids_world[neg_idx]) + text_atts_neg.append(text_attention_mask_world[neg_idx]) + + text_ids_neg = torch.stack(text_ids_neg, dim=0) + text_atts_neg = torch.stack(text_atts_neg, dim=0) + + text_ids_all = torch.cat( + [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0 + ) # pos, pos, neg + text_atts_all = torch.cat( + [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg], + dim=0, + ) + + query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1) + query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to( + image.device + ) + attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1) + + image_embeds_all = torch.cat( + [image_embeds, image_embeds_neg, image_embeds], dim=0 + ) # pos, neg, pos + image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to( + image.device + ) + + output_itm = self.Qformer.bert( + text_ids_all, + query_embeds=query_tokens_itm, + attention_mask=attention_mask_all, + encoder_hidden_states=image_embeds_all, + encoder_attention_mask=image_atts_all, + return_dict=True, + ) + + vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :] + vl_output = self.itm_head(vl_embeddings) + logits = vl_output.mean(dim=1) + + itm_labels = torch.cat( + [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)], + dim=0, + ).to(image.device) + loss_itm = F.cross_entropy(logits, itm_labels) + + ##================= Image Captioning ========================## + decoder_input_ids = text_tokens.input_ids.clone() + decoder_input_ids[:, 0] = self.tokenizer.bos_token_id + labels = decoder_input_ids.masked_fill( + decoder_input_ids == self.tokenizer.pad_token_id, -100 + ) + + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + image.device + ) + attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1) + lm_output = self.Qformer( + decoder_input_ids, + attention_mask=attention_mask, + past_key_values=query_output.past_key_values, + return_dict=True, + labels=labels, + ) + + loss_lm = lm_output.loss + + return BlipOutput( + loss=loss_itc + loss_itm + loss_lm, + loss_itc=loss_itc, + loss_itm=loss_itm, + loss_lm=loss_lm, + ) + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=3, + max_length=30, + min_length=10, + top_p=0.9, + repetition_penalty=1.0, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + use_nucleus_sampling (bool): Whether to use nucleus sampling. If False, use top-k sampling. + num_beams (int): Number of beams for beam search. 1 means no beam search. + max_length (int): The maximum length of the sequence to be generated. + min_length (int): The minimum length of the sequence to be generated. + top_p (float): The cumulative probability for nucleus sampling. + repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. + num_captions (int): Number of captions to be generated for each image. + Returns: + captions (list): A list of strings of length batch_size * num_captions. + """ + image = samples["image"] + image_embeds = self.ln_vision(self.visual_encoder(image)) + + if not use_nucleus_sampling: + image_embeds = image_embeds.repeat_interleave(num_beams, dim=0) + else: + num_beams = 1 + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + model_kwargs = { + "encoder_hidden_states": image_embeds, + "encoder_attention_mask": image_atts, + } + + input_ids = ( + torch.LongTensor(image.size(0), 1) + .fill_(self.tokenizer.bos_token_id) + .to(image.device) + ) + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + outputs = self.Qformer.generate( + input_ids=input_ids, + query_embeds=query_tokens, + max_length=max_length, + min_length=min_length, + num_beams=num_beams, + do_sample=use_nucleus_sampling, + top_p=top_p, + eos_token_id=self.tokenizer.sep_token_id, + pad_token_id=self.tokenizer.pad_token_id, + **model_kwargs + ) + captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) + return captions + + def forward_image(self, image): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + return query_output.last_hidden_state, image_embeds + + def forward_text(self, text_tokens): + text_output = self.Qformer.bert( + text_tokens.input_ids, + attention_mask=text_tokens.attention_mask, + return_dict=True, + ) + return text_output.last_hidden_state[:, 0, :] + + def compute_itm(self, image_inputs, text_ids, text_atts): + image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to( + image_inputs.device + ) + query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + image_inputs.device + ) + attention_mask = torch.cat([query_atts, text_atts], dim=1) + output_itm = self.Qformer.bert( + text_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_inputs, + encoder_attention_mask=image_atts, + return_dict=True, + ) + vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :] + itm_logit = self.itm_head(vl_embeddings) + itm_logit = itm_logit[:, :, 1].mean(dim=1) + return itm_logit + + @torch.no_grad() + def extract_features(self, samples, mode="multimodal"): + """ + Extract features for multimodal or unimodal samples. + Args: + samples (dict): A dictionary of samples, containing the following keys: + - image (torch.Tensor): A tensor of shape (B, C, H, W) containing the image. + Raw images should be preprocessed before being passed to feature extractor. + - text_input (list): A list of strings containing the text, length B. + mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image". + If "multimodal", return image features and multimodal features; + if "text", return text features; + if "image", return image features. + Default: "multimodal". + Returns: + BlipOutputFeatures: A BlipOutputFeatures object containing the features. + See lavis/models/blip_models/blip_outputs.py for more details. + """ + image = samples.get("image") + caption = samples.get("text_input") + + # assert mode is one of "image", "text", "multimodal" + assert mode in [ + "image", + "text", + "multimodal", + ], "mode must be one of 'image', 'text', 'multimodal'" + + # initalize output + image_embeds, text_embeds, multimodal_embeds = None, None, None + image_features, text_features = None, None + + if mode == "image": + assert ( + image is not None + ), "Image is not provided for mode 'image' or 'multimodal'" + # return query features + with self.maybe_autocast(): + image_embeds_frozen = self.ln_vision(self.visual_encoder(image)) + image_embeds_frozen = image_embeds_frozen.float() + image_atts = torch.ones( + image_embeds_frozen.size()[:-1], dtype=torch.long + ).to(self.device) + query_tokens = self.query_tokens.expand( + image_embeds_frozen.shape[0], -1, -1 + ) + + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds_frozen, + encoder_attention_mask=image_atts, + return_dict=True, + ) + image_embeds = query_output.last_hidden_state + image_features = F.normalize(self.vision_proj(image_embeds), dim=-1) + + elif mode == "text": + assert ( + caption is not None + ), "text input is None for mode 'text' or 'multimodal'" + + # return text features + text = self.tokenizer(caption, return_tensors="pt", padding=True).to( + self.device + ) + + text_output = self.Qformer.bert( + text.input_ids, + attention_mask=text.attention_mask, + return_dict=True, + ) + text_embeds = text_output.last_hidden_state + text_features = self.text_proj(text_embeds) + text_features = F.normalize(text_features, dim=-1) + + elif mode == "multimodal": + # return multimodel query features + with self.maybe_autocast(): + image_embeds_frozen = self.ln_vision(self.visual_encoder(image)) + image_embeds_frozen = image_embeds_frozen.float() + image_atts = torch.ones( + image_embeds_frozen.size()[:-1], dtype=torch.long + ).to(self.device) + query_tokens = self.query_tokens.expand( + image_embeds_frozen.shape[0], -1, -1 + ) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to( + self.device + ) + + text = self.tokenizer(caption, return_tensors="pt", padding=True).to( + self.device + ) + attention_mask = torch.cat([query_atts, text.attention_mask], dim=1) + + output = self.Qformer.bert( + text.input_ids, + query_embeds=query_tokens, + attention_mask=attention_mask, + encoder_hidden_states=image_embeds_frozen, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + multimodal_embeds = output.last_hidden_state[:, : query_tokens.size(1), :] + + return BlipOutputFeatures( + image_embeds=image_embeds, + image_embeds_proj=image_features, + text_embeds=text_embeds, + text_embeds_proj=text_features, + multimodal_embeds=multimodal_embeds, + ) + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + cross_attention_freq = cfg.get("cross_attention_freq", 2) + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + + max_txt_len = cfg.get("max_txt_len", 32) + + model = cls( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + num_query_token=num_query_token, + cross_attention_freq=cross_attention_freq, + max_txt_len=max_txt_len, + ) + model.load_checkpoint_from_config(cfg) + + return model + + def compute_sim_matrix(self, data_loader, task_cfg): + """ + Compute similarity i2t, t2i matrix for the given data loader. + """ + k_test = task_cfg.k_test + + return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test) diff --git a/minigpt4/models/blip2_t5_instruct.py b/minigpt4/models/blip2_t5_instruct.py new file mode 100644 index 0000000..abd4702 --- /dev/null +++ b/minigpt4/models/blip2_t5_instruct.py @@ -0,0 +1,503 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import logging +import string +import random +import copy +import json +import os +import numpy as np +import re + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast as autocast +from transformers import T5TokenizerFast + +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import Blip2Base, disabled_train +from minigpt4.models.modeling_t5 import T5Config, T5ForConditionalGeneration +from transformers.modeling_outputs import BaseModelOutput + +@registry.register_model("blip2_t5_qformer_moe") +class Blip2T5InstructQformerMoE(Blip2Base): + """ + BLIP2 Instruct T5 model Qformer MoE + Supported model types: + - flant5xxl + Usage: + >>> from minigpt4.models import load_model + >>> import torch + >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + >>> model = load_model("blip2_t5_qformer_moe", "flant5xxl", device=device) + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "flant5xxl": "configs/models/blip2/blip2_instruct_flant5xxl_qformer_moe.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + freeze_llm=True, + freeze_qformer=False, + freeze_t5_proj=False, + num_query_token=32, + t5_model="google/flan-t5-xl", + prompt="", + max_txt_len=128, + max_output_txt_len=256, + apply_lemmatizer=False, + qformer_text_input=True, + moebert_expert_num=5, + moebert_route_method="gate-sentence", + moebert_load_balance = 0.1, + moe_topk = 1, + use_balance_loss=True, + ): + """ + apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas. + """ + super().__init__() + + self.tokenizer = self.init_tokenizer(truncation_side="left") + + print("Init BLIP2 Instruct Flant5xxl Prompt MoE") + + print('Initing & Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + # freeze vit + # if freeze_vit: + # for name, param in self.visual_encoder.named_parameters(): + # param.requires_grad = False + # self.visual_encoder = self.visual_encoder.eval() + # self.visual_encoder.train = disabled_train + # # freeze ln vision + # for name, param in self.ln_vision.named_parameters(): + # param.requires_grad = False + # self.ln_vision = self.ln_vision.eval() + # self.ln_vision.train = disabled_train + # logging.info("freeze vision encoder") + # print('Loading VIT Done') + + print('Initing MoE Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + + if not qformer_text_input: + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + else: + self.Qformer.resize_token_embeddings(len(self.tokenizer)) + self.Qformer.cls = None + + print('Loading T5') + self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='left') + self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='right') + t5_config = T5Config.from_pretrained(t5_model) + t5_config.dense_act_fn = "gelu" + self.t5_model = T5ForConditionalGeneration.from_pretrained( + t5_model, config=t5_config, use_safetensors=False + ) + # freeze t5 llm + if freeze_llm: + for name, param in self.t5_model.named_parameters(): + param.requires_grad = False + param.data = param.data.bfloat16() + print('Loading T5 Done') + + + print("Initing t5 linear projection") + self.t5_proj = nn.Linear( + self.Qformer.config.hidden_size, self.t5_model.config.hidden_size + ) + + # load BLIP2 Pretrain + print("Loading BLIP2 Parameters from :", q_former_model) + self.load_from_pretrained(url_or_filename=q_former_model) + + # freeze qformer + if freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + logging.info("freeze Qformer") + + # After loading, freeze t5_proj + if freeze_t5_proj: + for name, param in self.t5_proj.named_parameters(): + param.requires_grad = False + self.t5_proj = self.t5_proj.eval() + self.t5_proj.train = disabled_train + + self.max_txt_len = max_txt_len + self.max_output_txt_len = max_output_txt_len + self.prompt = prompt + + self._apply_lemmatizer = apply_lemmatizer + self._lemmatizer = None + + self.qformer_text_input = qformer_text_input + + def forward(self, samples): + + image = samples["image"] + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + bz = image_embeds.shape[0] + query_tokens = self.query_tokens.expand(bz, -1, -1) + + ## Q-former Forward with one query tokens + if self.qformer_text_input: + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) + + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:] + # gate_loss = query_output.gate_loss + inputs_t5 = self.t5_proj(query_output_to_linear) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + with self.maybe_autocast(dtype=torch.bfloat16): + input_tokens = self.t5_tokenizer( + samples["llm_input"], + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + output_tokens = self.t5_output_tokenizer( + samples["text_output"], + padding="longest", + truncation=True, + max_length=self.max_output_txt_len, + return_tensors="pt", + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + targets = output_tokens.input_ids.masked_fill( + output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100 + ) + + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + decoder_attention_mask=output_tokens.attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + # final_loss = loss + self.moebert_load_balance * gate_loss + return {"loss": loss} + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=5, + max_length=256, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1.0, + num_captions=1, + temperature=1, + ): + if "prompt" in samples.keys(): + prompt = samples["prompt"] + else: + prompt = self.prompt + + image = samples["image"] + + bs = image.size(0) + + if isinstance(prompt, str): + prompt = [prompt] * bs + else: + assert len(prompt) == bs, "The number of prompts must be equal to the batch size." + + # For TextCaps + if "ocr_tokens" in samples.keys() and "{}" in prompt[0]: + prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)] + + # image embed + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + query_tokens = self.query_tokens.expand(bs, -1, -1) + + if self.qformer_text_input: + # remove ocr tokens in q_former (for eval textvqa) + # qformer_prompt = prompt + # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt] + + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) + + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:] + + inputs_t5 = self.t5_proj(query_output_to_linear) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + input_tokens = self.t5_tokenizer( + prompt, + padding="longest", + return_tensors="pt" + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + with self.maybe_autocast(dtype=torch.bfloat16): + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_new_tokens=max_length, + min_length=min_length, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + output_text = self.t5_tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + + return output_text + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=-1, + **kwargs + ): + if isinstance(samples["llm_input"], str): + samples["llm_input"] = [samples["llm_input"]] + + if prompt: + if prompt.count("{}") == 2: + if 'ocr_tokens' in samples: + text_input = [ + prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["llm_input"][i]) + for i in range(len(samples["llm_input"]))] + elif 'choices' in samples: + text_input = [] + for i in range(len(samples["llm_input"])): + this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])] + this_choices = " ".join(this_choices) + text_input.append(prompt.format(samples["llm_input"][i], this_choices)) + else: + text_input = [prompt.format(question) for question in samples["llm_input"]] + else: + text_input = samples["llm_input"] + + samples["prompt"] = text_input + + output_text = self.generate( + samples, + num_beams=num_beams, + max_length=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + + if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]): + output_text = self._lemmatize(output_text) + + return output_text + + def _lemmatize(self, answers): + def apply(answer): + doc = self.lemmatizer(answer) + + words = [] + for token in doc: + if token.pos_ in ["NOUN", "VERB"]: + words.append(token.lemma_) + else: + words.append(token.text) + answer = " ".join(words) + + return answer + + return [apply(answer) for answer in answers] + + @property + def lemmatizer(self): + if self._lemmatizer is None: + try: + import spacy + + self._lemmatizer = spacy.load("en_core_web_sm") + except ImportError: + logging.error( + """ + Please install spacy and en_core_web_sm model to apply lemmatization. + python -m spacy download en_core_web_sm + OR + import spacy.cli + spacy.cli.download("en_core_web_sm") + """ + ) + exit(1) + + return self._lemmatizer + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + t5_model = cfg.get("t5_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_llm = cfg.get("freeze_llm", True) + freeze_qformer = cfg.get("freeze_qformer", False) + freeze_t5_proj = cfg.get("freeze_t5_proj", False) + + prompt = cfg.get("prompt", "") + max_txt_len = cfg.get("max_txt_len", 128) + max_output_txt_len = cfg.get("max_output_txt_len", 256) + apply_lemmatizer = cfg.get("apply_lemmatizer", False) + + qformer_text_input = cfg.get("qformer_text_input", True) + + moebert_expert_num = cfg.get("moebert_expert_num", 5) + moebert_route_method = cfg.get("moebert_route_method", "gate-sentence") + moebert_load_balance = cfg.get("moebert_load_balance", 0.1) + moe_topk = cfg.get("moe_topk", 1) + use_balance_loss = cfg.get("use_balance_loss", True) + + model = cls( + vit_model=vit_model, + img_size=img_size, + q_former_model=q_former_model, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + freeze_llm=freeze_llm, + freeze_qformer=freeze_qformer, + freeze_t5_proj=freeze_t5_proj, + num_query_token=num_query_token, + t5_model=t5_model, + prompt=prompt, + max_txt_len=max_txt_len, + max_output_txt_len=max_output_txt_len, + apply_lemmatizer=apply_lemmatizer, + qformer_text_input=qformer_text_input, + moebert_expert_num=moebert_expert_num, + moebert_route_method=moebert_route_method, + moebert_load_balance=moebert_load_balance, + moe_topk=moe_topk, + use_balance_loss=use_balance_loss, + ) + + if qformer_text_input: + # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal) + model.load_from_pretrained( + url_or_filename="/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_pretrained/blip2_pretrained.pth" + ) + + model.load_checkpoint_from_config(cfg) + + # check update params + print("Updating following parameters:") + for name, param in model.named_parameters(): + if param.requires_grad == True: + print(name) + + # layer self attention: 2,363,904 + # layer pure ffn : 4,723,968 + # layer expert ffn : 4,723,968 + # layer cross attention: 3,346,944 + + return model diff --git a/minigpt4/models/blip2_t5_instruct_pro_moe.py b/minigpt4/models/blip2_t5_instruct_pro_moe.py new file mode 100644 index 0000000..e7c4798 --- /dev/null +++ b/minigpt4/models/blip2_t5_instruct_pro_moe.py @@ -0,0 +1,1127 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import logging +import string +import random +import copy +import json +import os +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast as autocast +from transformers import T5TokenizerFast + +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import Blip2Base, disabled_train +from minigpt4.models.modeling_t5 import T5Config, T5ForConditionalGeneration +from transformers.modeling_outputs import BaseModelOutput + +from minigpt4.models.moe.prompt_moe import init_query_token_candidates, PrePromptMoE, PostPromptMoE + +@registry.register_model("blip2_t5_instruct_pro_moe") +class Blip2T5InstructPromptMOE(Blip2Base): + """ + BLIP2 Instruct T5 model Prompt MoE + Supported model types: + - flant5xxl + Usage: + >>> from minigpt4.models import load_model + >>> import torch + >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + >>> model = load_model("blip2_t5_instruct_pro_moe", "flant5xxl", device=device) + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "flant5xxl": "configs/models/blip2/blip2_instruct_flant5xxl_prompt_moe.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + freeze_llm=True, + freeze_qformer=False, + freeze_t5_proj=False, + num_query_token=32, + t5_model="google/flan-t5-xl", + prompt="", + max_txt_len=128, + max_output_txt_len=256, + apply_lemmatizer=False, + num_few_shot_examples=0, + few_shot_prob=0, + qformer_text_input=True, + repeat_to_init_qt_candidates=True, + num_qt_candidates=5, + moe_topk=2, + moe_position="pre", + embed_extract="t5", + eval_gate_save=False, + train_gate_save=False, + gate_save_path="", + ): + """ + apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas. + """ + super().__init__() + + self.tokenizer = self.init_tokenizer(truncation_side="left") + + print("Init BLIP2 Instruct Flant5xxl Prompt MoE") + + print('Initing & Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + # freeze vit + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + # freeze ln vision + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + self.ln_vision.train = disabled_train + logging.info("freeze vision encoder") + print('Loading VIT Done') + + + print('Initing Q-Former') + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + if not qformer_text_input: + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + else: + self.Qformer.resize_token_embeddings(len(self.tokenizer)) + self.Qformer.cls = None + + print('Loading T5') + self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='left') + self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='right') + t5_config = T5Config.from_pretrained(t5_model) + t5_config.dense_act_fn = "gelu" + self.t5_model = T5ForConditionalGeneration.from_pretrained( + t5_model, config=t5_config, use_safetensors=False + ) + # freeze t5 llm + if freeze_llm: + for name, param in self.t5_model.named_parameters(): + param.requires_grad = False + param.data = param.data.bfloat16() + print('Loading T5 Done') + + + print("Initing t5 linear projection") + self.t5_proj = nn.Linear( + self.Qformer.config.hidden_size, self.t5_model.config.hidden_size + ) + + # load BLIP2 Pretrain + print("Loading BLIP2 Parameters from :", q_former_model) + self.load_from_pretrained(url_or_filename=q_former_model) + + + print('Init query token candidates') + self.moe_position = moe_position + if num_qt_candidates > 1: + + self.query_token_candidates = init_query_token_candidates(num_query_token, num_qt_candidates) # shape:[num_qt_candidates, num_query_token, q_former_hidden_size] + if repeat_to_init_qt_candidates: + self.query_token_candidates = torch.nn.Parameter(self.query_tokens.repeat(num_qt_candidates, 1, 1)) + self.query_tokens.requires_grad = False + print(self.query_token_candidates.shape) + + if self.moe_position == "pre": # PromptMoE + Qformer + self.embed_extract = embed_extract + if self.embed_extract == "t5": + self.text_embed_size = self.t5_model.config.hidden_size + + elif self.embed_extract == "blip2_pretrain": + from minigpt4.models import load_model + self.embed_extractor = load_model( + "blip2", + "pretrain", + is_eval=True, + ) # BLIP2 first-stage model with Q-former and ViT. + for name, param in self.embed_extractor.named_parameters(): + param.requires_grad = False + # self.text_embed_size = self.Qformer.config.hidden_size + self.text_embed_size = self.embed_extractor.text_proj.out_features + + elif self.embed_extract == "random": + self.text_embed_size = self.Qformer.config.hidden_size + self.PromptMoE = PrePromptMoE(self.text_embed_size, num_qt_candidates, self.query_token_candidates, route_method="gate-single-token", topk=moe_topk) + + elif moe_position == "post": # Qformer + PromptMoE + self.text_embed_size = self.Qformer.config.hidden_size + self.PromptMoE = PostPromptMoE(self.text_embed_size, num_qt_candidates, topk=moe_topk) + + + # freeze qformer + if freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + logging.info("freeze Qformer") + + # After loading, freeze t5_proj + if freeze_t5_proj: + for name, param in self.t5_proj.named_parameters(): + param.requires_grad = False + self.t5_proj = self.t5_proj.eval() + self.t5_proj.train = disabled_train + + self.max_txt_len = max_txt_len + self.max_output_txt_len = max_output_txt_len + self.prompt = prompt + + self._apply_lemmatizer = apply_lemmatizer + self._lemmatizer = None + + self.num_few_shot_examples = num_few_shot_examples + self.few_shot_prob = few_shot_prob + + self.qformer_text_input = qformer_text_input + + self.num_qt_candidates = num_qt_candidates + self.gate_save_path = gate_save_path + self.train_gate_save = train_gate_save + self.eval_gate_save = eval_gate_save + if gate_save_path!="" and (not os.path.exists(gate_save_path)): + print(gate_save_path) + os.mkdir(gate_save_path) + + def forward(self, samples): + # print('-----------------') + # print(samples["text_input"]) + # print(samples.keys()) + # print(samples) + # print(samples["text_output"]) + # print('-----------------') + import torch + samples = { + 'text_input':["What is around the open window?", # n23181 + "Is the ground blue or brown?", # n168412 + "What color are the pants?", # n446242 + "What is the airplane flying above?"], # n414992 + 'text_output':["drapes", + "brown", + "red", + "ocean" + ], + 'image': torch.randn(4, 3, 224, 224).half().to(device) + } + + image = samples["image"] + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + bz = image_embeds.shape[0] + + if self.moe_position == "pre": + if self.num_qt_candidates > 1: + ## extract text_embeds + with self.maybe_autocast(dtype=torch.bfloat16): + if self.embed_extract == "t5": + text_embeds = self._extract_text_embed_by_t5(samples['q_input'], samples['text_output'], image.device) + elif self.embed_extract == "blip2_pretrain": + text_embeds = self._extract_text_embed_by_qformer_pretrain_s1(samples['q_input'], image.device) + elif self.embed_extract == "random": + text_embeds = torch.randn(bz, 1, self.text_embed_size ) + ## select proper query_tokens by prompt moe + select_query_tokens, balance_loss, importance_loss, gate_load, gate = self.PromptMoE._forward_gate_single_token(text_embeds) + query_tokens = select_query_tokens # torch.Size([bz, 32, 768]) + else: + query_tokens = self.query_tokens.expand(bz, -1, -1) + balance_loss, importance_loss = 0, 0 + + ## Q-former Forward with one query tokens + if self.qformer_text_input: + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) + + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:] + + + elif self.moe_position == "post": + # self.query_token_candidates : size[num_qt_candidates, 32, 768] + candi_query_tokens = self.query_token_candidates.expand(bz, -1, -1, -1).reshape(-1, self.query_token_candidates.shape[1], self.query_token_candidates.shape[2]) # size[num_qt_candidates*bz, 32, 768] + + image_embeds_repeat = image_embeds.repeat_interleave(self.num_qt_candidates, dim=0) + image_atts_repeat = image_atts.repeat_interleave(self.num_qt_candidates, dim=0) + + ## Q-former Forward with candidates query tokens + if self.qformer_text_input: + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + + text_Qformer_input_ids_repeat = text_Qformer.input_ids.repeat_interleave(self.num_qt_candidates, dim=0) # [bz*num_qt_candidates, batch_seq_len] + text_Qformer_attn_mask_repeat = text_Qformer.attention_mask.repeat_interleave(self.num_qt_candidates, dim=0) # [bz*num_qt_candidates, batch_seq_len] + + query_atts = torch.ones(candi_query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer_attn_mask_repeat],dim=1) + + query_output = self.Qformer.bert( + text_Qformer_input_ids_repeat, + attention_mask=Qformer_atts, + query_embeds=candi_query_tokens, + encoder_hidden_states=image_embeds_repeat, + encoder_attention_mask=image_atts_repeat, + return_dict=True, + ) # query_output.last_hidden_state size [torch.Size([bz*num_qt_candidates, 32+batch_seq_len, 768])] + query_output_to_linear = query_output.last_hidden_state[:,:self.query_token_candidates.size(1),:] + else: + query_output = self.Qformer.bert( + query_embeds=candi_query_tokens, + encoder_hidden_states=image_embeds_repeat, + encoder_attention_mask=image_atts_repeat, + return_dict=True, + ) # query_output.last_hidden_state size [torch.Size([bz*num_qt_candidates, 32, 768])] + # [(sample1, query1), (sample1, query2),..., (sample2, query1),(sample2, query2), ... , (sample_bz, query1),..., (sample_bz, queryn)] + text_cls = query_output.last_hidden_state[:,self.query_token_candidates.size(1),:] # torch.Size([bz*num_qt_candidates, 768]) + text_cls_split = text_cls.view(bz, self.num_qt_candidates, -1) # torch.Size([bz, num_qt_candidates, 768]) + query_tokens_output = query_output.last_hidden_state[:, :self.query_token_candidates.size(1), :] # torch.Size([bz*num_qt_candidates, 32, 768]) + query_output_to_linear, balance_loss, importance_loss, gate_load, gate = self.PromptMoE._forward_gate_text_single_token(text_cls_split, query_tokens_output) + + inputs_t5 = self.t5_proj(query_output_to_linear) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + fs_embeds, fs_atts = None, None + if self.few_shot_prob > 0 and "few_shot_samples" in samples.keys(): + fs_embeds, fs_atts = self.prepare_few_shot_embeds(samples['few_shot_samples']) + + with self.maybe_autocast(dtype=torch.bfloat16): + input_tokens = self.t5_tokenizer( + samples["llm_input"], + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + output_tokens = self.t5_output_tokenizer( + samples["text_output"], + padding="longest", + truncation=True, + max_length=self.max_output_txt_len, + return_tensors="pt", + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + targets = output_tokens.input_ids.masked_fill( + output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100 + ) + + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + if fs_embeds is not None: + inputs_embeds = torch.cat([fs_embeds, inputs_embeds], dim=1) + encoder_atts = torch.cat([fs_atts, encoder_atts], dim=1) + + outputs = self.t5_model( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + decoder_attention_mask=output_tokens.attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + if self.train_gate_save: + self._save_gate( + samples['q_input'], + samples['text_output'], + gate, + samples['image_id'], + gate_load, + os.path.join(self.gate_save_path, "train_gate.txt") + ) + + final_loss = loss + balance_loss + importance_loss + return {"loss": final_loss} + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=5, + max_length=256, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1.0, + num_captions=1, + temperature=1, + ): + if "prompt" in samples.keys(): + prompt = samples["prompt"] + else: + prompt = self.prompt + + image = samples["image"] + + bs = image.size(0) + + if isinstance(prompt, str): + prompt = [prompt] * bs + else: + assert len(prompt) == bs, "The number of prompts must be equal to the batch size." + + # For TextCaps + if "ocr_tokens" in samples.keys() and "{}" in prompt[0]: + prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)] + + # image embed + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + + if self.moe_position == "pre": + if self.num_qt_candidates > 1: + + with self.maybe_autocast(dtype=torch.bfloat16): + if self.embed_extract == "t5": + text_embeds = self._extract_text_embed_by_t5(samples["q_input"], samples['text_output'], image.device) + elif self.embed_extract == "blip2_pretrain": + text_embeds = self._extract_text_embed_by_qformer_pretrain_s1(samples["q_input"], image.device) + elif self.embed_extract == "random": + text_embeds = torch.randn(bs, 1, self.text_embed_size ) + select_query_tokens, _, _, gate_load, gate = self.PromptMoE._forward_gate_single_token(text_embeds) + query_tokens = select_query_tokens # torch.Size([bz, 32, 768]) + else: # back to one query token + query_tokens = self.query_tokens.expand(bs, -1, -1) + + if self.qformer_text_input: + # remove ocr tokens in q_former (for eval textvqa) + # qformer_prompt = prompt + # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt] + + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) + + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:] + + elif self.moe_position == "post": + # self.query_token_candidates : size[num_qt_candidates, 32, 768] + candi_query_tokens = self.query_token_candidates.expand(bs, -1, -1, -1).reshape(-1, self.query_token_candidates.shape[1], self.query_token_candidates.shape[2]) # size[num_qt_candidates*bz, 32, 768] + image_embeds_repeat = image_embeds.repeat_interleave(self.num_qt_candidates, dim=0) + image_atts_repeat = image_atts.repeat_interleave(self.num_qt_candidates, dim=0) + + ## Q-former Forward with candidates query tokens + if self.qformer_text_input: + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + text_Qformer_input_ids_repeat = text_Qformer.input_ids.repeat_interleave(self.num_qt_candidates, dim=0) # [bz*num_qt_candidates, batch_seq_len] + text_Qformer_attn_mask_repeat = text_Qformer.attention_mask.repeat_interleave(self.num_qt_candidates, dim=0) # [bz*num_qt_candidates, batch_seq_len] + + query_atts = torch.ones(candi_query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer_attn_mask_repeat],dim=1) + + query_output = self.Qformer.bert( + text_Qformer_input_ids_repeat, + attention_mask=Qformer_atts, + query_embeds=candi_query_tokens, + encoder_hidden_states=image_embeds_repeat, + encoder_attention_mask=image_atts_repeat, + return_dict=True, + ) # query_output.last_hidden_state size [torch.Size([bz*num_qt_candidates, 32+batch_seq_len, 768])] + query_output_to_linear = query_output.last_hidden_state[:,:self.query_token_candidates.size(1),:] + else: + query_output = self.Qformer.bert( + query_embeds=candi_query_tokens, + encoder_hidden_states=image_embeds_repeat, + encoder_attention_mask=image_atts_repeat, + return_dict=True, + ) # query_output.last_hidden_state size [torch.Size([bz*num_qt_candidates, 32, 768])] + # [(sample1, query1), (sample1, query2),..., (sample2, query1),(sample2, query2), ... , (sample_bz, query1),..., (sample_bz, queryn)] + text_cls = query_output.last_hidden_state[:,self.query_token_candidates.size(1),:] # torch.Size([bz*num_qt_candidates, 768]) + text_cls_split = text_cls.view(bs, self.num_qt_candidates, -1) # torch.Size([bz, num_qt_candidates, 768]) + query_tokens_output = query_output.last_hidden_state[:, :self.query_token_candidates.size(1), :] # torch.Size([bz*num_qt_candidates, 32, 768]) + query_output_to_linear, _, _, gate_load, gate = self.PromptMoE._forward_gate_text_single_token(text_cls_split, query_tokens_output) + + # For video data deleted : TODO + + inputs_t5 = self.t5_proj(query_output_to_linear) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + input_tokens = self.t5_tokenizer( + prompt, + padding="longest", + return_tensors="pt" + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + with self.maybe_autocast(dtype=torch.bfloat16): + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_new_tokens=max_length, + min_length=min_length, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + output_text = self.t5_tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + + if self.eval_gate_save: + if "image_name" in samples.keys(): + id_lst = samples['image_name'] + elif "image_id" in samples.keys(): + id_lst = samples['image_id'] + try: + self._save_gate( + samples['q_input'], + output_text, + gate, + id_lst, + gate_load, + os.path.join(self.gate_save_path, "eval_gate.txt") + ) + except Exception as e: + print("Evaluate save gate Error:", e) + # : TODO Evaluate save gate Error: local variable 'id_lst' referenced before assignment + + return output_text + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=-1, + **kwargs + ): + if isinstance(samples["llm_input"], str): + samples["llm_input"] = [samples["llm_input"]] + + if prompt: + if prompt.count("{}") == 2: + if 'ocr_tokens' in samples: + text_input = [ + prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["llm_input"][i]) + for i in range(len(samples["llm_input"]))] + elif 'choices' in samples: + text_input = [] + for i in range(len(samples["llm_input"])): + this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])] + this_choices = " ".join(this_choices) + text_input.append(prompt.format(samples["llm_input"][i], this_choices)) + else: + text_input = [prompt.format(question) for question in samples["llm_input"]] + else: + text_input = samples["llm_input"] + + samples["prompt"] = text_input + + output_text = self.generate( + samples, + num_beams=num_beams, + max_length=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + + if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]): + output_text = self._lemmatize(output_text) + + return output_text + + def predict_class( + self, + samples, + candidates, + n_segments=1, + ): + # If candidates is a list of lists, each sample has its candidates, then we need to iterate one by one + if type(candidates[0]) == list: + results = [] + + for i in range(samples["image"].size(0)): + this_sample = { + "image": samples["image"][i].unsqueeze(0), + "prompt": samples["prompt"], + } + + if "text_input" in samples.keys(): + this_sample["text_input"] = [samples["text_input"][i]] + + if 'context' in samples.keys(): + this_sample['context'] = [samples["context"][i]] + + if 'history' in samples.keys(): + this_sample['history'] = [samples["history"][i]] + + if 'caption' in samples.keys(): + this_sample['caption'] = [samples["caption"][i]] + + this_result = self._predict_class(this_sample, candidates[i], n_segments) + results.append(this_result) + + try: + results = torch.cat(results, dim=0) + except: + results = [res.tolist()[0] for res in results] + + return results + + return self._predict_class(samples, candidates, n_segments) + + def _predict_class( + self, + samples, + candidates, + n_segments=1, + ): + """ + Args: + samples (dict): A dictionary containing the following keys: + - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) + - prompt: the instruction + candidates: + (list): A list of candidate class names; + n_segments: + (int): Split the candidates into n_segments and predict one by one. This is useful when the number of candidates is too large. + Returns: + output_class: predicted class index + """ + + image = samples["image"] + prompt = samples["prompt"] + + bs = image.size(0) + + if isinstance(prompt, str): + prompt = [prompt] * bs + else: + assert len(prompt) == bs, "The number of prompts must be equal to the batch size." + + if "text_input" in samples.keys(): + if type(samples["text_input"][0]) == list: + prompt = [prompt[i].format(*samples["text_input"][i]) for i in range(len(prompt))] + else: + prompt = [prompt[i].format(samples["text_input"][i]) for i in range(len(prompt))] + + # scienceqa + if 'context' in samples.keys() and samples['context'] != '': + prompt = [f'context: {samples["context"][i]}. {prompt[i]}' for i in range(len(prompt))] + + # visual dialog + if 'history' in samples.keys() and samples['history'][0] != '': + prompt = [f'dialog history: {samples["history"][i]}\n{prompt[i]}' for i in range(len(prompt))] + + if 'caption' in samples.keys() and samples['caption'][0] != '': + prompt = [f'This image has the caption "{samples["caption"][i]}". {prompt[i]}' for i in range(len(prompt))] + + query_tokens = self.query_tokens.expand(bs, -1, -1) + if self.qformer_text_input: + text_Qformer = self.tokenizer( + prompt, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt" + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask], dim=1) + + if image.dim() == 5: + inputs_t5, atts_t5 = [], [] + for j in range(image.size(2)): + this_frame = image[:,:,j,:,:] + with self.maybe_autocast(): + frame_embeds = self.ln_vision(self.visual_encoder(this_frame)) + frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device) + + if self.qformer_text_input: + frame_query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=frame_embeds, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + else: + frame_query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=frame_embeds, + encoder_attention_mask=frame_atts, + return_dict=True, + ) + + frame_inputs_t5 = self.t5_proj(frame_query_output.last_hidden_state[:,:query_tokens.size(1),:]) + frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + inputs_t5.append(frame_inputs_t5) + atts_t5.append(frame_atts_t5) + inputs_t5 = torch.cat(inputs_t5, dim=1) + atts_t5 = torch.cat(atts_t5, dim=1) + else: + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + if self.qformer_text_input: + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:]) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + input_tokens = self.t5_tokenizer( + prompt, padding="longest", return_tensors="pt" + ).to(image.device) + output_tokens = self.t5_tokenizer( + candidates, padding="longest", return_tensors="pt" + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + n_cands = len(candidates) + + with self.maybe_autocast(dtype=torch.bfloat16): + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + encoder_outputs = self.t5_model.encoder( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + ) + + all_losses = [] + for n in range(n_segments): + seg_len = n_cands // n_segments + if n == (n_segments - 1): + seg_len = n_cands - seg_len * (n_segments - 1) + + # this_encoder_outputs = copy.deepcopy(encoder_outputs) + this_encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0].clone(), + ) + + this_encoder_outputs['last_hidden_state'] = this_encoder_outputs[0].repeat_interleave(seg_len, dim=0) + this_encoder_atts = encoder_atts.repeat_interleave(seg_len, dim=0) + + start_i = n * (n_cands // n_segments) + end_i = start_i + seg_len + this_output_tokens_ids = output_tokens.input_ids[start_i:end_i].repeat(bs, 1) + this_output_tokens_atts = output_tokens.attention_mask[start_i:end_i].repeat(bs, 1) + + this_targets = this_output_tokens_ids.masked_fill(this_output_tokens_ids == self.t5_tokenizer.pad_token_id, -100) + + outputs = self.t5_model( + encoder_outputs=this_encoder_outputs, + attention_mask=this_encoder_atts, + decoder_attention_mask=this_output_tokens_atts, + return_dict=True, + labels=this_targets, + reduction="none", + ) + loss = outputs.loss + + loss = loss.reshape(bs, seg_len) + # output_class_ranks = torch.argsort(loss, dim=-1) + all_losses.append(loss) + + all_losses = torch.cat(all_losses, dim=-1) + output_class_ranks = torch.argsort(all_losses, dim=-1) + + # encoder_outputs['last_hidden_state'] = encoder_outputs[0].repeat_interleave(n_cands, dim=0) + # encoder_atts = encoder_atts.repeat_interleave(n_cands, dim=0) + # output_tokens.input_ids = output_tokens.input_ids.repeat(bs, 1) + # output_tokens.attention_mask = output_tokens.attention_mask.repeat(bs, 1) + + # # compute the LM loss for each candidate (sum logprob across all tokens) and select the highest + # targets = output_tokens.input_ids.masked_fill(output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100) + + # outputs = self.t5_model( + # encoder_outputs=encoder_outputs, + # attention_mask=encoder_atts, + # decoder_attention_mask=output_tokens.attention_mask, + # return_dict=True, + # labels=targets, + # reduction="none", + # ) + # loss = outputs.loss + + # loss = loss.reshape(bs, n_cands) + # output_class_ranks = torch.argsort(loss, dim=-1) # (bs, num_candidates) + + return output_class_ranks + + def prepare_few_shot_embeds(self, samples): + this_n_fs = random.choices( + list(range(self.num_few_shot_examples + 1)), + weights=[1 - self.few_shot_prob] + [self.few_shot_prob / self.num_few_shot_examples] * self.num_few_shot_examples + )[0] + + if this_n_fs == 0: + return None, None + + images = [] + text_input = [] + for sample in samples: + for n in range(this_n_fs): + images.append(sample['image'][n]) + text_input.append(sample['text_input'][n]) + images = torch.stack(images, dim=0) + + image = images + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to( + image.device + ) + + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + if self.qformer_text_input: + text_Qformer = self.tokenizer( + text_input, + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask = Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + inputs_t5 = self.t5_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:]) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + with self.maybe_autocast(dtype=torch.bfloat16): + input_tokens = self.t5_tokenizer( + text_input, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + if this_n_fs > 1: + encoder_atts = encoder_atts.reshape(encoder_atts.size(0) // this_n_fs, encoder_atts.size(1) * this_n_fs) + inputs_embeds = inputs_embeds.reshape(inputs_embeds.size(0) // this_n_fs, inputs_embeds.size(1) * this_n_fs, inputs_embeds.size(2)) + + return inputs_embeds, encoder_atts + + + + def _extract_text_embed_by_qformer_pretrain_s1( + self, + text_input, + device + ): + text_inputs = self.embed_extractor.tokenizer( + text_input, + padding="max_length", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(device) + text_feats = self.embed_extractor.forward_text(text_inputs) + # return text_feats.unsqueeze(1) # torch.Size([bz, 1, 768]) + + text_embeds = F.normalize(self.embed_extractor.text_proj(text_feats)) + return text_embeds.unsqueeze(1) # torch.Size([bz, 1, 256]) + + + def _extract_text_embed_by_t5( + self, + text_input, + text_output, + device + ): + bz = len(text_input) + + input_tokens = self.t5_tokenizer( + text_input, + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(device) + output_tokens = self.t5_output_tokenizer( + text_output, + padding="longest", + truncation=True, + max_length=self.max_output_txt_len, + return_tensors="pt", + ).to(device) + + targets = output_tokens.input_ids.masked_fill( + output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100 + ) + + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + + text_outputs = self.t5_model( + inputs_embeds=inputs_embeds, + attention_mask=input_tokens.attention_mask, + decoder_attention_mask=output_tokens.attention_mask, + return_dict=True, + labels=targets, + ) + last_token_embeds = list() + for i in range(bz): + seq_pos = (torch.nonzero(input_tokens.attention_mask[i]).squeeze())[-1].item() # 取最后位置 + last_token_embed = text_outputs.encoder_last_hidden_state[i][seq_pos] + last_token_embeds.append(last_token_embed.unsqueeze(0)) + text_embeds = torch.concat(last_token_embeds, dim=0).unsqueeze(1) # torch.Size([bz, 1, 4096]) + + return text_embeds + + def _save_gate(self, input_text, output_text, gate, id_lst, gate_load, gate_save_file): + tt = list() + for tinput, toutput, g, id_ in zip(input_text, output_text, gate, id_lst): + tt.append({ + 'text_input': tinput, + 'text_output': toutput, + 'gate': g.tolist(), + 'image': id_, + 'batch_gate_load': gate_load.tolist() + }) + with open(gate_save_file, "a") as f: + f.write(f"{json.dumps(tt)}\n") + + def _lemmatize(self, answers): + def apply(answer): + doc = self.lemmatizer(answer) + + words = [] + for token in doc: + if token.pos_ in ["NOUN", "VERB"]: + words.append(token.lemma_) + else: + words.append(token.text) + answer = " ".join(words) + + return answer + + return [apply(answer) for answer in answers] + + @property + def lemmatizer(self): + if self._lemmatizer is None: + try: + import spacy + + self._lemmatizer = spacy.load("en_core_web_sm") + except ImportError: + logging.error( + """ + Please install spacy and en_core_web_sm model to apply lemmatization. + python -m spacy download en_core_web_sm + OR + import spacy.cli + spacy.cli.download("en_core_web_sm") + """ + ) + exit(1) + + return self._lemmatizer + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + t5_model = cfg.get("t5_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_llm = cfg.get("freeze_llm", True) + freeze_qformer = cfg.get("freeze_qformer", False) + freeze_t5_proj = cfg.get("freeze_t5_proj", False) + + prompt = cfg.get("prompt", "") + max_txt_len = cfg.get("max_txt_len", 128) + max_output_txt_len = cfg.get("max_output_txt_len", 256) + + apply_lemmatizer = cfg.get("apply_lemmatizer", False) + + num_few_shot_examples = cfg.get("num_few_shot_examples", 0) + few_shot_prob = cfg.get("few_shot_prob", 0.0) + + qformer_text_input = cfg.get("qformer_text_input", True) + + repeat_to_init_qt_candidates= cfg.get("repeat_to_init_qt_candidates", True) + num_qt_candidates = cfg.get("num_qt_candidates", 5) + moe_topk = cfg.get("moe_topk", 2) + moe_position = cfg.get("moe_position", "pre") + embed_extract = cfg.get("embed_extract", "t5") + train_gate_save = cfg.get("train_gate_save", False) + eval_gate_save = cfg.get("eval_gate_save", False) + gate_save_path = cfg.get("gate_save_path", "") + + model = cls( + vit_model=vit_model, + img_size=img_size, + q_former_model=q_former_model, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + freeze_llm=freeze_llm, + freeze_qformer=freeze_qformer, + freeze_t5_proj=freeze_t5_proj, + num_query_token=num_query_token, + t5_model=t5_model, + prompt=prompt, + max_txt_len=max_txt_len, + max_output_txt_len=max_output_txt_len, + apply_lemmatizer=apply_lemmatizer, + num_few_shot_examples=num_few_shot_examples, + few_shot_prob=few_shot_prob, + qformer_text_input=qformer_text_input, + repeat_to_init_qt_candidates=repeat_to_init_qt_candidates, + num_qt_candidates=num_qt_candidates, + moe_topk=moe_topk, + moe_position=moe_position, + embed_extract=embed_extract, + eval_gate_save=eval_gate_save, + train_gate_save=train_gate_save, + gate_save_path=gate_save_path, + ) + + # if qformer_text_input: + # # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal) + # model.load_from_pretrained( + # url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" + # ) + + model.load_checkpoint_from_config(cfg) + + # check update params + print("Updating following parameters:") + for name, param in model.named_parameters(): + if param.requires_grad == True: + print(name) + + return model diff --git a/minigpt4/models/blip2_t5_qformer_moe.py b/minigpt4/models/blip2_t5_qformer_moe.py new file mode 100644 index 0000000..ebbc42e --- /dev/null +++ b/minigpt4/models/blip2_t5_qformer_moe.py @@ -0,0 +1,554 @@ +""" + Copyright (c) 2023, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" +import logging +import string +import random +import copy +import json +import os +import numpy as np +import re + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.cuda.amp import autocast as autocast +from transformers import T5TokenizerFast + +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import Blip2Base, disabled_train +from minigpt4.models.modeling_t5 import T5Config, T5ForConditionalGeneration +from transformers.modeling_outputs import BaseModelOutput + +@registry.register_model("blip2_t5_qformer_moe_test") +class Blip2T5InstructQformerMoETest(Blip2Base): + """ + BLIP2 Instruct T5 model Qformer MoE + Supported model types: + - flant5xxl + Usage: + >>> from minigpt4.models import load_model + >>> import torch + >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + >>> model = load_model("blip2_t5_qformer_moe_test", "flant5xxl", device=device) + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "flant5xxl": "configs/models/blip2/blip2_instruct_flant5xxl_qformer_moe.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + freeze_llm=True, + freeze_qformer=False, + freeze_t5_proj=False, + num_query_token=32, + t5_model="google/flan-t5-xl", + prompt="", + max_txt_len=128, + max_output_txt_len=256, + apply_lemmatizer=False, + qformer_text_input=True, + moebert_expert_num=5, + moebert_route_method="gate-sentence", + moebert_load_balance = 0.1, + moe_topk = 1, + use_balance_loss=True, + ): + """ + apply_lemmatizer: when set to True, postprocess predict_answers() result with lemmas. + """ + super().__init__() + + self.tokenizer = self.init_tokenizer(truncation_side="left") + + print("Init BLIP2 Instruct Flant5xxl Prompt MoE") + + print('Initing & Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision + ) + # freeze vit + if freeze_vit: + for name, param in self.visual_encoder.named_parameters(): + param.requires_grad = False + self.visual_encoder = self.visual_encoder.eval() + self.visual_encoder.train = disabled_train + # freeze ln vision + for name, param in self.ln_vision.named_parameters(): + param.requires_grad = False + self.ln_vision = self.ln_vision.eval() + self.ln_vision.train = disabled_train + logging.info("freeze vision encoder") + print('Loading VIT Done') + + print('Initing MoE Q-Former') + self.Qformer, self.query_tokens = self.init_QformerMoE( + num_query_token=num_query_token, + vision_width=self.visual_encoder.num_features, + moebert_expert_num=moebert_expert_num, + moebert_route_method=moebert_route_method, + moebert_load_balance=moebert_load_balance, + moe_topk=moe_topk, + use_balance_loss=use_balance_loss, + cross_attention_freq=2 + ) + if not qformer_text_input: + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + else: + self.Qformer.resize_token_embeddings(len(self.tokenizer)) + self.Qformer.cls = None + + print('Loading T5') + self.t5_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='left') + self.t5_output_tokenizer = T5TokenizerFast.from_pretrained(t5_model, truncation_side='right') + t5_config = T5Config.from_pretrained(t5_model) + t5_config.dense_act_fn = "gelu" + self.t5_model = T5ForConditionalGeneration.from_pretrained( + t5_model, config=t5_config, use_safetensors=False + ) + # freeze t5 llm + if freeze_llm: + for name, param in self.t5_model.named_parameters(): + param.requires_grad = False + param.data = param.data.bfloat16() + print('Loading T5 Done') + + + print("Initing t5 linear projection") + self.t5_proj = nn.Linear( + self.Qformer.config.hidden_size, self.t5_model.config.hidden_size + ) + + # load BLIP2 Pretrain + print("Loading BLIP2 Parameters from :", q_former_model) + self.load_from_pretrained(url_or_filename=q_former_model) + + # init MoE Layer(init moe ffn by blip2 query ffn) + state_dict = self.Qformer.state_dict() + for name, param in self.Qformer.named_parameters(): + if "_query" in name and "experts.experts" in name: + pattern = r'\.experts\.experts\.\d+' + key_orig = re.sub(pattern, '', name) + param.data.copy_(state_dict[key_orig]) # copy state_dict[key_orig] to param + if "experts.intermediate_query" in name or "experts.output_query" in name: + key_orig = re.sub(r'experts\.', '', name) + param.data.copy_(state_dict[key_orig]) # copy state_dict[key_orig] to param + if "_query" in name and "experts" not in name: # raw ffn_query not update + param.requires_grad = False + + # freeze qformer + if freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + logging.info("freeze Qformer") + + # After loading, freeze t5_proj + if freeze_t5_proj: + for name, param in self.t5_proj.named_parameters(): + param.requires_grad = False + self.t5_proj = self.t5_proj.eval() + self.t5_proj.train = disabled_train + + self.max_txt_len = max_txt_len + self.max_output_txt_len = max_output_txt_len + self.prompt = prompt + + self._apply_lemmatizer = apply_lemmatizer + self._lemmatizer = None + + self.qformer_text_input = qformer_text_input + self.moebert_load_balance = moebert_load_balance + + def forward(self, samples): + # print('-----------------') + # print(samples["text_input"]) + # print(samples.keys()) + # print(samples) + # print(samples["text_output"]) + # print('-----------------') + # import torch + # samples = { + # 'text_input':["What is around the open window?", # n23181 + # "Is the ground blue or brown?", # n168412 + # "What color are the pants?", # n446242 + # "What is the airplane flying above?"], # n414992 + # 'llm_input':["What is around the open window?", # n23181 + # "Is the ground blue or brown?", # n168412 + # "What color are the pants?", # n446242 + # "What is the airplane flying above?"], # n414992 + # 'text_output':["drapes", + # "brown", + # "red", + # "ocean" + # ], + # 'image': torch.randn(4, 3, 224, 224).half().to(device) + # } + # model(samples) + + # samples = { + # 'text_input':["What is around the open window?"], # n414992 + # 'llm_input':["What is around the open window?"], # n414992 + # 'text_output':["drapes"], + # 'image': torch.randn(1, 3, 224, 224).to("cpu") + # } + + image = samples["image"] + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + bz = image_embeds.shape[0] + query_tokens = self.query_tokens.expand(bz, -1, -1) + + ## Q-former Forward with one query tokens + if self.qformer_text_input: + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) + + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:] + gate_loss = query_output.gate_loss + inputs_t5 = self.t5_proj(query_output_to_linear) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + with self.maybe_autocast(dtype=torch.bfloat16): + input_tokens = self.t5_tokenizer( + samples["llm_input"], + padding="longest", + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + output_tokens = self.t5_output_tokenizer( + samples["text_output"], + padding="longest", + truncation=True, + max_length=self.max_output_txt_len, + return_tensors="pt", + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + targets = output_tokens.input_ids.masked_fill( + output_tokens.input_ids == self.t5_tokenizer.pad_token_id, -100 + ) + + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + decoder_attention_mask=output_tokens.attention_mask, + return_dict=True, + labels=targets, + ) + loss = outputs.loss + + final_loss = loss + self.moebert_load_balance * gate_loss + return {"loss": final_loss} + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=5, + max_length=256, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1.0, + num_captions=1, + temperature=1, + ): + if "prompt" in samples.keys(): + prompt = samples["prompt"] + else: + prompt = self.prompt + + image = samples["image"] + + bs = image.size(0) + + if isinstance(prompt, str): + prompt = [prompt] * bs + else: + assert len(prompt) == bs, "The number of prompts must be equal to the batch size." + + # For TextCaps + if "ocr_tokens" in samples.keys() and "{}" in prompt[0]: + prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)] + + # image embed + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + query_tokens = self.query_tokens.expand(bs, -1, -1) + + if self.qformer_text_input: + # remove ocr tokens in q_former (for eval textvqa) + # qformer_prompt = prompt + # qformer_prompt = ['Question: ' + qp.split(' Question: ')[1] for qp in qformer_prompt] + + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1) + + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:] + + inputs_t5 = self.t5_proj(query_output_to_linear) + atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device) + + input_tokens = self.t5_tokenizer( + prompt, + padding="longest", + return_tensors="pt" + ).to(image.device) + + encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1) + + with self.maybe_autocast(dtype=torch.bfloat16): + inputs_embeds = self.t5_model.encoder.embed_tokens(input_tokens.input_ids) + inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1) + + outputs = self.t5_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=encoder_atts, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_new_tokens=max_length, + min_length=min_length, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + output_text = self.t5_tokenizer.batch_decode( + outputs, skip_special_tokens=True + ) + + return output_text + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=-1, + **kwargs + ): + if isinstance(samples["llm_input"], str): + samples["llm_input"] = [samples["llm_input"]] + + if prompt: + if prompt.count("{}") == 2: + if 'ocr_tokens' in samples: + text_input = [ + prompt.format(', '.join(samples['ocr_tokens'][i][:30]), samples["llm_input"][i]) + for i in range(len(samples["llm_input"]))] + elif 'choices' in samples: + text_input = [] + for i in range(len(samples["llm_input"])): + this_choices = [f"({string.ascii_lowercase[j]}) {ch}" for j, ch in enumerate(samples["choices"][i])] + this_choices = " ".join(this_choices) + text_input.append(prompt.format(samples["llm_input"][i], this_choices)) + else: + text_input = [prompt.format(question) for question in samples["llm_input"]] + else: + text_input = samples["llm_input"] + + samples["prompt"] = text_input + + output_text = self.generate( + samples, + num_beams=num_beams, + max_length=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + + if self._apply_lemmatizer or ("apply_lemmatizer" in samples.keys() and samples["apply_lemmatizer"]): + output_text = self._lemmatize(output_text) + + return output_text + + def _lemmatize(self, answers): + def apply(answer): + doc = self.lemmatizer(answer) + + words = [] + for token in doc: + if token.pos_ in ["NOUN", "VERB"]: + words.append(token.lemma_) + else: + words.append(token.text) + answer = " ".join(words) + + return answer + + return [apply(answer) for answer in answers] + + @property + def lemmatizer(self): + if self._lemmatizer is None: + try: + import spacy + + self._lemmatizer = spacy.load("en_core_web_sm") + except ImportError: + logging.error( + """ + Please install spacy and en_core_web_sm model to apply lemmatization. + python -m spacy download en_core_web_sm + OR + import spacy.cli + spacy.cli.download("en_core_web_sm") + """ + ) + exit(1) + + return self._lemmatizer + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + t5_model = cfg.get("t5_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_llm = cfg.get("freeze_llm", True) + freeze_qformer = cfg.get("freeze_qformer", False) + freeze_t5_proj = cfg.get("freeze_t5_proj", False) + + prompt = cfg.get("prompt", "") + max_txt_len = cfg.get("max_txt_len", 128) + max_output_txt_len = cfg.get("max_output_txt_len", 256) + apply_lemmatizer = cfg.get("apply_lemmatizer", False) + + qformer_text_input = cfg.get("qformer_text_input", True) + + moebert_expert_num = cfg.get("moebert_expert_num", 5) + moebert_route_method = cfg.get("moebert_route_method", "gate-sentence") + moebert_load_balance = cfg.get("moebert_load_balance", 0.1) + moe_topk = cfg.get("moe_topk", 1) + use_balance_loss = cfg.get("use_balance_loss", True) + + model = cls( + vit_model=vit_model, + img_size=img_size, + q_former_model=q_former_model, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + freeze_llm=freeze_llm, + freeze_qformer=freeze_qformer, + freeze_t5_proj=freeze_t5_proj, + num_query_token=num_query_token, + t5_model=t5_model, + prompt=prompt, + max_txt_len=max_txt_len, + max_output_txt_len=max_output_txt_len, + apply_lemmatizer=apply_lemmatizer, + qformer_text_input=qformer_text_input, + moebert_expert_num=moebert_expert_num, + moebert_route_method=moebert_route_method, + moebert_load_balance=moebert_load_balance, + moe_topk=moe_topk, + use_balance_loss=use_balance_loss, + ) + + # if qformer_text_input: + # # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal) + # model.load_from_pretrained( + # url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" + # ) + + model.load_checkpoint_from_config(cfg) + + # check update params + print("Updating following parameters:") + for name, param in model.named_parameters(): + if param.requires_grad == True: + print(name) + + # layer self attention: 2,363,904 + # layer pure ffn : 4,723,968 + # layer expert ffn : 4,723,968 + # layer cross attention: 3,346,944 + + return model diff --git a/minigpt4/models/blip2_vicuna_instruct.py b/minigpt4/models/blip2_vicuna_instruct.py new file mode 100644 index 0000000..a2befec --- /dev/null +++ b/minigpt4/models/blip2_vicuna_instruct.py @@ -0,0 +1,719 @@ +""" +Requires Transformer 4.28 and above, implementation may change according the Llama implementation +""" +import logging +import string +from packaging import version +import re + +import torch +from torch.cuda.amp import autocast as autocast +import torch.nn as nn + +import transformers + +from minigpt4.common.registry import registry +from minigpt4.models.blip2 import Blip2Base, disabled_train + +@registry.register_model("blip2_vicuna_instruct") +class Blip2VicunaInstruct(Blip2Base): + """ + BLIP2 Vicuna model. + Supported model types: + - vicuna7b + - vicuna13b + Usage: + >>> from minigpt4.models import load_model + >>> import torch + >>> device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + >>> model = load_model("blip2_vicuna_instruct", "vicuna7b_qfmoe_route_uni", device=device) + """ + + PRETRAINED_MODEL_CONFIG_DICT = { + "vicuna7b_instruct": "configs/models/blip2/blip2_instruct_vicuna7b.yaml", + "vicuna7b_pretrain": "configs/models/blip2/blip2_pretrain_vicuna7b.yaml", + "vicuna7b_qfmoe_post": "configs/models/blip2/blip2_qformer_moe_post_vicuna7b.yaml", + "vicuna7b_qfmoe_route": "configs/models/blip2/blip2_pretrain_vicuna7b_route_moe.yaml", + "vicuna7b_qfmoe_route_uni": "configs/models/blip2/blip2_pretrain_vicuna7b_route_moe_universal.yaml", + } + + def __init__( + self, + vit_model="eva_clip_g", + q_former_model="", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=True, + freeze_llm=True, + freeze_qformer=False, + freeze_proj=False, + num_query_token=32, + llm_model="", + prompt="", + max_txt_len=128, + max_output_txt_len=256, + lora_r=0, # lora_r means lora is not used + lora_target_modules=["q_proj", "v_proj"], + lora_alpha=16, + lora_dropout=0.05, + qformer_text_input=True, + general_version='base', + moebert_num_beams=2, + moebert_expert_num=5, + moebert_route_method="gate-sentence", + moebert_load_balance = 0.1, + moe_topk = 1, + use_balance_loss = True, + moe_weight_type = "l2_norm", + gate_save_path = None, + bal_loss_decay_epoch = 3, + ln_position = "out", + ): + super().__init__() + transformers_version = version.parse(transformers.__version__) + assert transformers_version >= version.parse("4.28"), "BLIP-2 Vicuna requires transformers>=4.28" + from transformers import LlamaTokenizer + from minigpt4.models.modeling_llama import LlamaForCausalLM + + self.tokenizer = self.init_tokenizer(truncation_side="left") + + print('Initing & Loading VIT') + self.visual_encoder, self.ln_vision = self.init_vision_encoder( + vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit, + ) + + print('Initing & Loading Qformer') + if general_version in ['naive_moe', 'route_moe', 'uni_route_moe']: + if general_version == 'naive_moe': + self.Qformer, self.query_tokens = self.init_QformerMoE( + num_query_token=num_query_token, + vision_width=self.visual_encoder.num_features, + moebert_expert_num=moebert_expert_num, + moebert_route_method=moebert_route_method, + moebert_load_balance=moebert_load_balance, + moe_topk=moe_topk, + use_balance_loss=use_balance_loss, + moe_weight_type=moe_weight_type, + cross_attention_freq=2, + ln_position=ln_position, + ) + elif general_version == 'route_moe': + self.Qformer, self.query_tokens = self.init_RouteMoEQformer( + num_query_token=num_query_token, + vision_width=self.visual_encoder.num_features, + moebert_expert_num=moebert_expert_num, + moebert_num_beams=moebert_num_beams, + route_method=moebert_route_method, + moe_weight_type=moe_weight_type, + cross_attention_freq=2, + ln_position=ln_position, + ) + elif general_version == 'uni_route_moe': + self.Qformer, self.query_tokens = self.init_RouteMoEQformerUni( + num_query_token=num_query_token, + vision_width=self.visual_encoder.num_features, + moebert_expert_num=moebert_expert_num, + moebert_num_beams=moebert_num_beams, + route_method=moebert_route_method, + moe_weight_type=moe_weight_type, + cross_attention_freq=2, + ln_position=ln_position, + ) + + elif general_version == 'base': + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, self.visual_encoder.num_features + ) + + if not qformer_text_input: + self.Qformer.bert.embeddings.word_embeddings = None + self.Qformer.bert.embeddings.position_embeddings = None + for layer in self.Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + else: + self.Qformer.resize_token_embeddings(len(self.tokenizer)) + self.Qformer.cls = None + + print("Loading LLM") + self.llm_model, self.llm_tokenizer = self.init_llm( + llama_model_path=llm_model, + freeze_llm=freeze_llm, + lora_r=lora_r, + lora_target_modules=lora_target_modules, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + ) + + self.llm_proj = nn.Linear( + self.Qformer.config.hidden_size, self.llm_model.config.hidden_size + ) + + if qformer_text_input: + # Hard-coded to load from BLIP-2 stage-1 pre-trained model( to init ffn but not ideal) + self.load_from_pretrained( + url_or_filename="/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_pretrained/blip2_pretrained.pth", + num_query_token=num_query_token + ) + + if general_version not in ['base']: + # load blip2_vicuna_pretrain to init query_ffn + self.load_from_pretrained( + url_or_filename=q_former_model, + num_query_token=num_query_token + ) + # init MoE Layer(init moe ffn by blip2 query ffn) + self.adjust_param_qformer() + + + # freeze qformer + if freeze_qformer: + for name, param in self.Qformer.named_parameters(): + param.requires_grad = False + self.Qformer = self.Qformer.eval() + self.Qformer.train = disabled_train + logging.info("freeze Qformer") + + # After loading, freeze llm_proj + if freeze_proj: + for name, param in self.llm_proj.named_parameters(): + param.requires_grad = False + self.llm_proj = self.llm_proj.eval() + self.llm_proj.train = disabled_train + + self.max_txt_len = max_txt_len + self.max_output_txt_len = max_output_txt_len + self.prompt = prompt + prompt_tokens = self.llm_tokenizer(self.prompt, return_tensors="pt") + self.prompt_length = prompt_tokens.attention_mask.sum(1) + + self.qformer_text_input = qformer_text_input + self.general_version = general_version + self.moebert_load_balance = moebert_load_balance + self.moebert_num_beams = moebert_num_beams + + self.gate_save_path = gate_save_path + self.bal_loss_decay_epoch = bal_loss_decay_epoch + + def adjust_param_qformer(self): + # init MoE Layer(init moe ffn by blip2 query ffn) + state_dict = self.Qformer.state_dict() + for name, param in self.Qformer.named_parameters(): + if "_query" in name and "experts.experts" in name: + pattern = r'\.experts\.experts\.\d+' + key_orig = re.sub(pattern, '', name) + param.data.copy_(state_dict[key_orig]) # copy state_dict[key_orig] to param + if "experts.intermediate_query" in name or "experts.output_query" in name: + key_orig = re.sub(r'experts\.', '', name) + param.data.copy_(state_dict[key_orig]) # copy state_dict[key_orig] to param + if "_query" in name and "experts" not in name: # raw ffn_query not update + param.requires_grad = False + + ln_pattern = r"bert\.encoder\.layer\.\d+\.expert_ln\.(weight|bias)" + if re.match(ln_pattern, name): + key_orig = re.sub('expert_ln', 'output_query.LayerNorm', name) + param.data.copy_(state_dict[key_orig]) + d1_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense1\.(weight|bias)" + if re.match(d1_pattern, name): + key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense1', 'intermediate_query.dense', name) + param.data.copy_(state_dict[key_orig]) + d2_pattern = r"bert\.encoder\.layer\.(\d+)\.experts(\.|\.experts\.\d+\.)dense2\.(weight|bias)" + if re.match(d2_pattern, name): + key_orig = re.sub(r'experts(\.|\.experts\.\d+\.)dense2', 'output_query.dense', name) + param.data.copy_(state_dict[key_orig]) + + def concat_text_input_output(self, input_ids, input_atts, output_ids, output_atts): + input_part_targets_len = [] + llm_tokens = {"input_ids": [], "attention_mask": []} + for i in range(input_ids.size(0)): + this_input_ones = input_atts[i].sum() + input_part_targets_len.append(this_input_ones) + llm_tokens['input_ids'].append( + torch.cat([ + input_ids[i][:this_input_ones], + output_ids[i][1:], + input_ids[i][this_input_ones:] + ]) + ) + llm_tokens['attention_mask'].append( + torch.cat([ + input_atts[i][:this_input_ones], + output_atts[i][1:], + input_atts[i][this_input_ones:] + ]) + ) + llm_tokens['input_ids'] = torch.stack(llm_tokens['input_ids']) + llm_tokens['attention_mask'] = torch.stack(llm_tokens['attention_mask']) + return llm_tokens, input_part_targets_len + + def forward(self, samples): + # print('-----------------') + # print(samples["text_input"]) + # print(samples["text_output"]) + # print('-----------------') + # import pdb;pdb.set_trace() # 0107test + image = samples["image"] + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + bs = image.size(0) + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + + if self.qformer_text_input: + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask],dim=1) + + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + output_hidden_states=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + output_hidden_states=True, + ) + # import pdb; pdb.set_trace()# 0107test + query_output_to_linear = query_output.last_hidden_state[:,:query_tokens.size(1),:] + + if self.general_version not in ['base']: + gate_loss = query_output.gate_loss # only available in QformerMoE + + if self.gate_save_path != None: + all_hidden_states = query_output.hidden_states + # prob_gate_normalized = query_output.gate_loads + beam_scores = query_output.beam_scores + expert_route = query_output.expert_route + + gate_route = list() + import numpy as np + import json + import os + try: + for i in range(len(samples['image_id'])): + image_id = samples['image_id'][i] + gate_route.append({ + 'iters': samples['iters'], + 'image_id':image_id, + 'q_input': samples['q_input'][i], + 'text_output': samples['text_output'][i], + 'beam_scores': beam_scores[i].tolist(), + 'expert_route': expert_route[i].tolist(), + # 'gate_route_11': prob_gate_normalized[10][i].tolist(), + # 'gate_route_9': prob_gate_normalized[8][i].tolist(), + # 'gate_route_7': prob_gate_normalized[6][i].tolist(), + # 'gate_route_5': prob_gate_normalized[4][i].tolist(), + # 'gate_route_3': prob_gate_normalized[2][i].tolist(), + # 'gate_route_1': prob_gate_normalized[0][i].tolist(), + }) + # for layer in [6,8,10]: + # layer_data = all_hidden_states[layer] + # file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy') + # x = layer_data.data.cpu().numpy() + # np.save(file_path,x) + + with open(os.path.join(self.gate_save_path, 'train_save_beam.json'),'a+') as f: + f.write(f"{json.dumps(gate_route)}\n") + except Exception as e: + print("Gate Save Error....") + print(e) + + inputs_llm = self.llm_proj(query_output_to_linear) + atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device) + + self.llm_tokenizer.padding_side = "right" + self.llm_tokenizer.truncation_side = 'left' + text_input_tokens = self.llm_tokenizer( + samples['llm_input'], + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_txt_len, + ).to(image.device) + + self.llm_tokenizer.truncation_side = 'right' + text_output_tokens = self.llm_tokenizer( + [t + self.llm_tokenizer.eos_token for t in samples['text_output']], + return_tensors="pt", + padding="longest", + truncation=True, + max_length=self.max_output_txt_len, + ).to(image.device) + + llm_tokens, input_part_targets_len = self.concat_text_input_output( + text_input_tokens.input_ids, + text_input_tokens.attention_mask, + text_output_tokens.input_ids, + text_output_tokens.attention_mask, + ) + + # do not apply loss to the padding + targets = llm_tokens['input_ids'].masked_fill( + llm_tokens['input_ids'] == self.llm_tokenizer.pad_token_id, -100 + ) + + # do not apply loss to the text input (i.e., instruction) + for i, l in enumerate(input_part_targets_len): + targets[i][:l] = -100 + + # do not apply loss to the query tokens + empty_targets = ( + torch.ones(atts_llm.size(), dtype=torch.long).to(image.device).fill_(-100) + ) + targets = torch.cat([empty_targets, targets], dim=1) + + inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens['input_ids']) + inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1) + attention_mask = torch.cat([atts_llm, llm_tokens['attention_mask']], dim=1) + + with self.maybe_autocast(): + outputs = self.llm_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + return_dict=True, + labels=targets, + ) + + if self.general_version not in ['base']: + if samples['epoch'] > self.bal_loss_decay_epoch: + loss = outputs.loss + else: + loss = outputs.loss + self.moebert_load_balance * gate_loss + else: + loss = outputs.loss + + return {"loss": loss} + + @torch.no_grad() + def generate( + self, + samples, + use_nucleus_sampling=False, + num_beams=5, + max_length=256, + min_length=1, + top_p=0.9, + repetition_penalty=1.5, + length_penalty=1, + num_captions=1, + temperature=1, + ): + self.llm_tokenizer.padding_side = "left" + + image = samples["image"] + bs = image.size(0) + + query_tokens = self.query_tokens.expand(bs, -1, -1) + if self.qformer_text_input: + text_Qformer = self.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=self.max_txt_len, + return_tensors="pt", + ).to(image.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) + Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1) + + with self.maybe_autocast(): + image_embeds = self.ln_vision(self.visual_encoder(image)) + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + + if self.qformer_text_input: + query_output = self.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + output_hidden_states=True, + ) + else: + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + output_hidden_states=True, + ) + + # import pdb; pdb.set_trace() + + if self.gate_save_path != None: + if "qformer_moe_route" in self.gate_save_path: + self.gate_save(samples, query_output, mode="route") + else: + self.gate_save(samples, query_output, mode="naive") + + inputs_llm = self.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:]) + atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device) + + llm_tokens = self.llm_tokenizer( + samples['llm_input'], + padding="longest", + return_tensors="pt" + ).to(image.device) + + with self.maybe_autocast(): + inputs_embeds = self.llm_model.get_input_embeddings()(llm_tokens.input_ids) + + # if self.gate_save_path != None: + # self.save_embeddings(samples, inputs_llm) + + inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1) + attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1) + + outputs = self.llm_model.generate( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_length=max_length, + min_length=min_length, + # eos_token_id=self.eos_token_id, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, + ) + + # outputs[outputs == 0] = 2 # convert output id 0 to 2 (eos_token_id) + output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True) + output_text = [text.strip() for text in output_text] + + return output_text + + def predict_answers( + self, + samples, + num_beams=5, + inference_method="generate", + max_len=10, + min_len=1, + num_ans_candidates=128, + answer_list=None, + prompt="", + length_penalty=0, + **kwargs + ): + if isinstance(samples["llm_input"], str): + samples["llm_input"] = [samples["llm_input"]] + + output_text = self.generate( + samples, + num_beams=num_beams, + max_length=max_len, + min_length=min_len, + length_penalty=length_penalty + ) + return output_text + + def save_embeddings(self, samples, inputs_llm): + import numpy as np + import os + import json + try: + path = os.path.join(self.gate_save_path, "embedding") + for i in range(len(samples['image_id'])): + np.save(os.path.join(path, f"{samples['image_id'][i]}inputs_llm.npy"), inputs_llm[i].cpu().numpy) + np.save(os.path.join(path, "llm_embedding.npy"), self.llm_model.get_input_embeddings().weight.cpu().numpy) + samples_copy = samples.copy() + samples_copy.pop('image', None) + with open(os.path.join(path, '{}_test_samples.json'.format(samples['image_id'][0])),'a+') as f: + f.write(f"{json.dumps(samples_copy)}\n") + except Exception as e: + print("Embedding Save Error....") + print(e) + + def gate_save(self, samples, query_output, mode="naive"): + """ + mode: naive/route + """ + import numpy as np + import json + import os + + if mode == "naive": + all_hidden_states = query_output.hidden_states + prob_gate_normalized = query_output.gate_loads + + gate_route = list() + try: + for i in range(len(samples['image_id'])): + source = samples['source'][i] + if source in ['gqa']: + image_id = samples['image_id'][i].split('.')[0] + else: + image_id = samples['image_id'][i].split('/')[-1].split('.')[0] + gate_route.append({ + 'source': source, + 'image_id':image_id, + 'q_input': samples['q_input'][i], + 'gate_route_11': prob_gate_normalized[11][i].tolist(), + 'gate_route_10': prob_gate_normalized[10][i].tolist(), + 'gate_route_9': prob_gate_normalized[9][i].tolist(), + 'gate_route_8': prob_gate_normalized[8][i].tolist(), + 'gate_route_7': prob_gate_normalized[7][i].tolist(), + 'gate_route_6': prob_gate_normalized[6][i].tolist(), + }) + # Naive + for layer in [6,7,8,9,10,11]: + layer_data = all_hidden_states[layer][i, :, :] + file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy') + x = layer_data.data.cpu().numpy() + np.save(file_path,x) # 大功告成 + except Exception as e: + print("Naive Gate Save Error....") + print(e) + + elif mode == "route": + all_hidden_states = query_output.hidden_states + beam_scores = query_output.beam_scores + expert_route = query_output.expert_route + + gate_route = list() + try: + for i in range(len(samples['image_id'])): + source = samples['source'][i] + if source in ['gqa']: + image_id = samples['image_id'][i].split('.')[0] + else: + image_id = samples['image_id'][i].split('/')[-1].split('.')[0] + gate_route.append({ + 'source': source, + 'image_id':image_id, + 'q_input': samples['q_input'][i], + 'beam_scores': beam_scores[i].tolist(), + 'expert_route': expert_route[i].tolist(), + }) + if self.general_version=='route_moe': + # Route + for layer in [6,7,8,9,10,11]: + if layer in [6,11]: + layer_data = all_hidden_states[layer][i, :, :] + else: + layer_data = all_hidden_states[layer][i*self.moebert_num_beams, :, :] + file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy') + x = layer_data.data.cpu().numpy() + np.save(file_path,x) # 大功告成 + elif self.general_version=='uni_route_moe': + import pdb;pdb.set_trace() + for layer in [6,7,8,9,10,11]: + if layer in [6,11]: + layer_data = all_hidden_states[layer][i, :, :] + else: + layer_data = all_hidden_states[layer][i*self.moebert_num_beams, :, :] + file_path = os.path.join(self.gate_save_path, f'{image_id}_{str(layer)}.npy') + x = layer_data.data.cpu().numpy() + np.save(file_path,x) # 大功告成 + except Exception as e: + print("Route Gate Save Error....") + print(e) + + with open(os.path.join(self.gate_save_path, 'generate_save_beam.json'),'a+') as f: + f.write(f"{json.dumps(gate_route)}\n") + + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llm_model = cfg.get("llm_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + freeze_llm = cfg.get("freeze_llm", True) + freeze_qformer = cfg.get("freeze_qformer", False) + freeze_proj = cfg.get("freeze_proj", False) + + prompt = cfg.get("prompt", "") + max_txt_len = cfg.get("max_txt_len", 128) + max_output_txt_len = cfg.get("max_output_txt_len", 256) + + lora_r = cfg.get("lora_r", 64) + lora_alpha = cfg.get("lora_alpha", 16) + + qformer_text_input = cfg.get("qformer_text_input", True) + + general_version = cfg.get("general_version", False) + moebert_num_beams = cfg.get("moebert_num_beams", 2) + moebert_expert_num = cfg.get("moebert_expert_num", 5) + moebert_route_method = cfg.get("moebert_route_method", "gate-sentence") + moebert_load_balance = cfg.get("moebert_load_balance", 0.1) + moe_topk = cfg.get("moe_topk", 1) + use_balance_loss = cfg.get("use_balance_loss", True) + moe_weight_type = cfg.get("moe_weight_type",'l2_norm') + gate_save_path = cfg.get("gate_save_path", None) + bal_loss_decay_epoch = cfg.get("bal_loss_decay_epoch", 3) + ln_position = cfg.get("ln_position","out") + + model = cls( + vit_model=vit_model, + img_size=img_size, + q_former_model=q_former_model, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + freeze_llm=freeze_llm, + freeze_qformer=freeze_qformer, + freeze_proj=freeze_proj, + num_query_token=num_query_token, + llm_model=llm_model, + prompt=prompt, + max_txt_len=max_txt_len, + max_output_txt_len=max_output_txt_len, + lora_r=lora_r, # lora_r means lora is not used + lora_alpha=lora_alpha, + qformer_text_input=qformer_text_input, + general_version=general_version, + moebert_num_beams=moebert_num_beams, + moebert_expert_num=moebert_expert_num, + moebert_route_method=moebert_route_method, + moebert_load_balance=moebert_load_balance, + moe_topk=moe_topk, + use_balance_loss=use_balance_loss, + moe_weight_type=moe_weight_type, + gate_save_path=gate_save_path, + bal_loss_decay_epoch=bal_loss_decay_epoch, + ln_position=ln_position, + ) + + # if qformer_text_input: + # # Hard-coded to load from BLIP-2 stage-1 pre-trained model (not ideal) + # model.load_from_pretrained( + # url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained.pth" + # ) + + model.load_checkpoint_from_config(cfg) + + # check update params + print("Updating following parameters:") + for name, param in model.named_parameters(): + if param.requires_grad == True: + print(name) + # [name for name, param in model.named_parameters() if (param.requires_grad == False and 'Qformer' in name and 'intermediate_query' in name)] + # import pdb; pdb.set_trace()# 0107test + return model diff --git a/minigpt4/models/eva_vit.py b/minigpt4/models/eva_vit.py index 7fcc63a..a918f84 100644 --- a/minigpt4/models/eva_vit.py +++ b/minigpt4/models/eva_vit.py @@ -60,7 +60,7 @@ class Mlp(nn.Module): x = self.drop(x) return x - +# from visualizer import get_local # attention_visualization class Attention(nn.Module): def __init__( self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., @@ -115,6 +115,8 @@ class Attention(nn.Module): self.proj = nn.Linear(all_head_dim, dim) self.proj_drop = nn.Dropout(proj_drop) + # attention_visualization + # @get_local('attn') def forward(self, x, rel_pos_bias=None): B, N, C = x.shape qkv_bias = None @@ -369,7 +371,20 @@ class VisionTransformer(nn.Module): return features - + def get_num_layer(self, var_name=""): + if var_name in ("cls_token", "mask_token", "pos_embed"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("rel_pos_bias"): + return len(self.blocks) - 1 + elif var_name.startswith("blocks"): + layer_id = int(var_name.split('.')[1]) + return layer_id + 1 + else: + return len(self.blocks) + + def interpolate_pos_embed(model, checkpoint_model): if 'pos_embed' in checkpoint_model: pos_embed_checkpoint = checkpoint_model['pos_embed'].float() @@ -426,10 +441,14 @@ def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precis norm_layer=partial(nn.LayerNorm, eps=1e-6), use_checkpoint=use_checkpoint, ) - url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" - cached_file = download_cached_file( - url, check_hash=False, progress=True - ) + # url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" + # cached_file = download_cached_file( + # url, check_hash=False, progress=True + # ) + + # cached_file = '/mnt/pfs-guan-ssai/nlu/dingyifeng/checkpoints/BLIP2/eva_vit_g.pth' + cached_file = '/mnt/pfs-guan-ssai/nlu/wanghanzi/models/eva_vit/eva_vit_g.pth' + state_dict = torch.load(cached_file, map_location="cpu") interpolate_pos_embed(model,state_dict) diff --git a/minigpt4/models/minigpt4.py b/minigpt4/models/minigpt4.py index a2e4798..0484def 100644 --- a/minigpt4/models/minigpt4.py +++ b/minigpt4/models/minigpt4.py @@ -18,14 +18,14 @@ class MiniGPT4(MiniGPTBase): """ PRETRAINED_MODEL_CONFIG_DICT = { - "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml", - "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", + "pretrain_vicuna0": "configs/models/minigpt/minigpt4_vicuna0.yaml", + "pretrain_llama2": "configs/models/minigpt/minigpt4_llama2.yaml", } def __init__( self, vit_model="eva_clip_g", - q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + q_former_model="/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth", img_size=224, drop_path_rate=0, use_grad_checkpoint=False, @@ -86,7 +86,7 @@ class MiniGPT4(MiniGPTBase): @classmethod def init_Qformer(cls, num_query_token, vision_width, freeze): - encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") encoder_config.encoder_width = vision_width # insert cross-attention layer every other block encoder_config.add_cross_attention = True @@ -147,7 +147,7 @@ class MiniGPT4(MiniGPTBase): @classmethod def from_config(cls, cfg): vit_model = cfg.get("vit_model", "eva_clip_g") - q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + q_former_model = cfg.get("q_former_model", "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth") img_size = cfg.get("image_size") num_query_token = cfg.get("num_query_token") llama_model = cfg.get("llama_model") diff --git a/minigpt4/models/minigpt_base.py b/minigpt4/models/minigpt_base.py index cd051ec..4208b10 100644 --- a/minigpt4/models/minigpt_base.py +++ b/minigpt4/models/minigpt_base.py @@ -309,6 +309,7 @@ class MiniGPTBase(BaseModel): def embed_tokens(self, token_ids): if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model + # print(self.llama_model.base_model.device()) embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) else: embeds = self.llama_model.base_model.embed_tokens(token_ids) diff --git a/minigpt4/models/minigpt_v2.py b/minigpt4/models/minigpt_v2.py index a046b0b..ba59d43 100644 --- a/minigpt4/models/minigpt_v2.py +++ b/minigpt4/models/minigpt_v2.py @@ -18,7 +18,7 @@ class MiniGPTv2(MiniGPTBase): """ PRETRAINED_MODEL_CONFIG_DICT = { - "pretrain": "configs/models/minigpt_v2.yaml", + "pretrain": "configs/models/minigpt/minigpt_v2.yaml", } def __init__( diff --git a/minigpt4/models/modeling_llama.py b/minigpt4/models/modeling_llama.py index 5d59a53..1c83372 100644 --- a/minigpt4/models/modeling_llama.py +++ b/minigpt4/models/modeling_llama.py @@ -1,17 +1,641 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" import math from typing import List, Optional, Tuple, Union import torch -import torch.nn.functional as F -from torch.nn import CrossEntropyLoss +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss -from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC -from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.models.llama.configuration_llama import LlamaConfig -class LlamaForCausalLM(LlamaForCausalLMOrig): +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +class LlamaRotaryEmbedding(torch.nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Build here to make `torch.jit.trace` work. + self.max_seq_len_cached = max_position_embeddings + t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. + if seq_len > self.max_seq_len_cached: + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] + gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) + cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.max_position_embeddings + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttention(config=config) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) @@ -46,13 +670,13 @@ class LlamaForCausalLM(LlamaForCausalLMOrig): >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> prompt = "Hey, are you consciours? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -75,13 +699,7 @@ class LlamaForCausalLM(LlamaForCausalLMOrig): ) hidden_states = outputs[0] - if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1: - lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) - logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] - logits = torch.cat(logits, dim=-1) - else: - logits = self.lm_head(hidden_states) - logits = logits.float() + logits = self.lm_head(hidden_states) loss = None if labels is not None: @@ -96,6 +714,7 @@ class LlamaForCausalLM(LlamaForCausalLMOrig): shift_labels = shift_labels.to(shift_logits.device) loss = loss_fct(shift_logits, shift_labels) if reduction == "none": + # loss = loss.view(logits.size(0), -1).sum(1) loss = loss.view(logits.size(0), -1).mean(1) if not return_dict: @@ -109,3 +728,162 @@ class LlamaForCausalLM(LlamaForCausalLMOrig): hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) \ No newline at end of file diff --git a/minigpt4/models/modeling_llama_minigpt4.py b/minigpt4/models/modeling_llama_minigpt4.py new file mode 100644 index 0000000..5d59a53 --- /dev/null +++ b/minigpt4/models/modeling_llama_minigpt4.py @@ -0,0 +1,111 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss + +from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import LLAMA_INPUTS_DOCSTRING, _CONFIG_FOR_DOC +from transformers.models.llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLMOrig + + +class LlamaForCausalLM(LlamaForCausalLMOrig): + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if hasattr(self.config, 'pretraining_tp') and self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(reduction=reduction) + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + if reduction == "none": + loss = loss.view(logits.size(0), -1).mean(1) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/minigpt4/models/modeling_t5.py b/minigpt4/models/modeling_t5.py new file mode 100644 index 0000000..10e4d56 --- /dev/null +++ b/minigpt4/models/modeling_t5.py @@ -0,0 +1,2063 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch T5 model.""" + + +import copy +import math +import os +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithPastAndCrossAttentions, + Seq2SeqLMOutput, + Seq2SeqModelOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.utils import ( + DUMMY_INPUTS, + DUMMY_MASK, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_torch_fx_proxy, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.configuration_t5 import T5Config + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "T5Config" +_TOKENIZER_FOR_DOC = "T5Tokenizer" +_CHECKPOINT_FOR_DOC = "t5-small" + +#################################################### +# This dict contains ids and associated url +# for the pretrained weights provided with the models +#################################################### +T5_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "t5-small", + "t5-base", + "t5-large", + "t5-3b", + "t5-11b", + # See all T5 models at https://huggingface.co/models?filter=t5 +] + + +#################################################### +# This is a conversion method from TF 1.0 to PyTorch +# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28 +#################################################### +def load_tf_weights_in_t5(model, config, tf_checkpoint_path): + """Load tf checkpoints in a pytorch model.""" + try: + import re + + import numpy as np + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(tf_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + tf_weights = {} + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + tf_weights[name] = array + + for txt_name in names: + name = txt_name.split("/") + # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v + # which are not required for using pretrained model + if any( + n + in [ + "adam_v", + "adam_m", + "AdamWeightDecayOptimizer", + "AdamWeightDecayOptimizer_1", + "global_step", + ] + for n in name + ): + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + if "_slot_" in name[-1]: + logger.info(f"Skipping {'/'.join(name)}") + tf_weights.pop(txt_name, None) + continue + pointer = model + array = tf_weights[txt_name] + + for m_name in name: + if re.fullmatch(r"[A-Za-z]+_\d+", m_name): + scope_names = re.split(r"_(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + elif scope_names[0] == "self_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[0] + elif scope_names[0] == "enc_dec_attention": + pointer = getattr(pointer, "layer") + pointer = pointer[1] + elif scope_names[0] == "dense_relu_dense": + pointer = getattr(pointer, "layer") + pointer = pointer[2] + elif scope_names[0] == "rms_norm": + if hasattr(pointer, "layer_norm"): + pointer = getattr(pointer, "layer_norm") + elif hasattr(pointer, "final_layer_norm"): + pointer = getattr(pointer, "final_layer_norm") + elif scope_names[0] == "scale": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "output_bias" or scope_names[0] == "beta": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "squad": + pointer = getattr(pointer, "classifier") + elif scope_names[0] == "decoder" and name[1] == "logits": + continue + elif scope_names[0] == "logits": + pointer = getattr(pointer, "lm_head") + elif ( + scope_names[0] == "wi" + and len(scope_names) > 1 + and scope_names[1].isdigit() + ): + pointer = getattr(pointer, f"wi_{scope_names[1]}") + continue + else: + try: + pointer = getattr(pointer, scope_names[0]) + except AttributeError: + logger.info(f"Skipping {'/'.join(name)}") + continue + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + if scope_names[0] not in ["kernel", "scale", "embedding"]: + pointer = getattr(pointer, "weight") + if scope_names[0] != "embedding": + logger.info(f"Transposing numpy weight of shape {array.shape} for {name}") + array = np.transpose(array) + try: + assert ( + pointer.shape == array.shape + ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched" + except AssertionError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array.astype(np.float32)) + tf_weights.pop(txt_name, None) + + logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.") + return model + + +#################################################### +# PyTorch Models are constructed by sub-classing +# - torch.nn.Module for the layers and +# - PreTrainedModel for the models (it-self a sub-class of nn.Module) +#################################################### +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the t5 models have the + following number of attention modules: + + - t5-small: 6 + - t5-base: 12 + - t5-large: 24 + - t5-3b: 24 + - t5-11b: 24 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using t5-3b, which has a total of 24 attention modules: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with t5-3b: + model = T5ForConditionalGeneration.from_pretrained("t5-3b") + device_map = { + 0: [0, 1, 2], + 1: [3, 4, 5, 6, 7, 8, 9], + 2: [10, 11, 12, 13, 14, 15, 16], + 3: [17, 18, 19, 20, 21, 22, 23], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +class T5LayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Construct a layernorm module in the T5 style. No bias and no subtraction of mean. + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + + # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean + # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated + # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for + # half-precision inputs is done in fp32 + + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + + return self.weight * hidden_states + + +try: + from apex.normalization import FusedRMSNorm + + T5LayerNorm = FusedRMSNorm # noqa + + logger.info( + "Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm" + ) +except ImportError: + # using the normal T5LayerNorm + pass +except Exception: + logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm") + pass + +ALL_LAYERNORM_LAYERS.append(T5LayerNorm) + + +class T5DenseActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_states = self.wi(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5DenseGatedActDense(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False) + self.wo = nn.Linear(config.d_ff, config.d_model, bias=False) + self.dropout = nn.Dropout(config.dropout_rate) + self.act = ACT2FN[config.dense_act_fn] + + def forward(self, hidden_states): + hidden_gelu = self.act(self.wi_0(hidden_states)) + hidden_linear = self.wi_1(hidden_states) + hidden_states = hidden_gelu * hidden_linear + hidden_states = self.dropout(hidden_states) + hidden_states = self.wo(hidden_states) + return hidden_states + + +class T5LayerFF(nn.Module): + def __init__(self, config: T5Config): + super().__init__() + if config.is_gated_act: + self.DenseReluDense = T5DenseGatedActDense(config) + else: + self.DenseReluDense = T5DenseActDense(config) + + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = hidden_states + self.dropout(forwarded_states) + return hidden_states + + +class T5Attention(nn.Module): + def __init__(self, config: T5Config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding( + self.relative_attention_num_buckets, self.n_heads + ) + self.pruned_heads = set() + self.gradient_checkpointing = False + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += ( + past_key_value[0].shape[2] if query_length is None else query_length + ) + + key_length = ( + real_seq_length if key_value_states is None else key_value_states.shape[1] + ) + + def shape(states): + """projection""" + return states.view( + batch_size, -1, self.n_heads, self.key_value_proj_dim + ).transpose(1, 2) + + def unshape(states): + """reshape""" + return ( + states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + ) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape( + self.q(hidden_states) + ) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, + self.k, + key_value_states, + past_key_value[0] if past_key_value is not None else None, + ) + value_states = project( + hidden_states, + self.v, + key_value_states, + past_key_value[1] if past_key_value is not None else None, + ) + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), + device=scores.device, + dtype=scores.dtype, + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device + ) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = ( + position_bias + mask + ) # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape( + torch.matmul(attn_weights, value_states) + ) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + present_key_value_state = ( + (key_states, value_states) if (self.is_decoder and use_cache) else None + ) + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class T5LayerSelfAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5LayerCrossAttention(nn.Module): + def __init__(self, config): + super().__init__() + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[ + 1: + ] # add attentions if we output them + return outputs + + +class T5Block(nn.Module): + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + config, has_relative_attention_bias=has_relative_attention_bias + ) + ) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + self.layer.append(T5LayerFF(config)) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + + if past_key_value is not None: + if not self.is_decoder: + logger.warning( + "`past_key_values` is passed to the encoder. Please make sure this is intended." + ) + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states" + ) + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[ + 2: + ] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if ( + hidden_states.dtype == torch.float16 + and torch.isinf(hidden_states).any() + ): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = ( + present_key_value_state + cross_attention_outputs[1] + ) + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer + hidden_states = self.layer[-1](hidden_states) + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): + clamp_value = torch.finfo(hidden_states.dtype).max - 1000 + hidden_states = torch.clamp( + hidden_states, min=-clamp_value, max=clamp_value + ) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = T5Config + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["T5Block"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = ( + self.config.initializer_factor + ) # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedActDense): + module.wi_0.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_model) ** -0.5) + ) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_( + mean=0.0, std=factor * ((self.config.d_ff) ** -0.5) + ) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_( + mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5) + ) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_( + mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5) + ) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_( + mean=0.0, std=factor * ((d_model) ** -0.5) + ) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." + " See T5 docs for more information" + ) + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full( + input_ids.shape[:-1] + (1,), decoder_start_token_id + ) + shifted_input_ids = torch.cat( + [shifted_input_ids, input_ids[..., :-1]], dim=-1 + ) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert ( + pad_token_id is not None + ), "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5Stack(T5PreTrainedModel): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [ + T5Block(config, has_relative_attention_bias=bool(i == 0)) + for i in range(config.num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm( + config.d_model, eps=config.layer_norm_epsilon + ) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + self.device_map = ( + get_device_map(len(self.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.block)) + self.model_parallel = True + self.first_device = ( + "cpu" + if "cpu" in self.device_map.keys() + else "cuda:" + str(min(self.device_map.keys())) + ) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + # Load onto devices + for k, v in self.device_map.items(): + for layer in v: + cuda_device = "cuda:" + str(k) + self.block[layer] = self.block[layer].to(cuda_device) + + # Set embed_tokens to first layer + self.embed_tokens = self.embed_tokens.to(self.first_device) + # Set final layer norm to last device + self.final_layer_norm = self.final_layer_norm.to(self.last_device) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def deparallelize(self): + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + for i in range(len(self.block)): + self.block[i] = self.block[i].to("cpu") + self.embed_tokens = self.embed_tokens.to("cpu") + self.final_layer_norm = self.final_layer_norm.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, new_embeddings): + self.embed_tokens = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds" + ) + + if inputs_embeds is None: + assert ( + self.embed_tokens is not None + ), "You have to initialize the model with valid token embeddings" + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values[0][0].shape[2] + seq_length + if past_key_values is not None + else seq_length + ) + + if use_cache is True: + assert ( + self.is_decoder + ), f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones( + batch_size, mask_seq_length, device=inputs_embeds.device + ) + if ( + self.is_decoder + and encoder_attention_mask is None + and encoder_hidden_states is not None + ): + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long, + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones( + encoder_hidden_shape, device=inputs_embeds.device + ) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask( + cross_attn_head_mask, self.config.num_layers + ) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate( + zip(self.block, past_key_values) + ): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to( + hidden_states.device + ) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = ( + encoder_extended_attention_mask.to(hidden_states.device) + ) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to( + hidden_states.device + ) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to( + hidden_states.device + ) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[ + 4 if output_attentions else 3 + ] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + ( + present_key_value_state, + ) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +T5_START_DOCSTRING = r""" + + The T5 model was proposed in [Exploring the Limits of Transfer Learning with a Unified Text-to-Text + Transformer](https://arxiv.org/abs/1910.10683) by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan + Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder transformer pre-trained in a + text-to-text denoising generative setting. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`T5Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +T5_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + [What are input IDs?](../glossary#input-ids) + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values` + is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`). + + To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5 + Training](./t5#training). + decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also + be used by default. + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the encoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0, + 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in + `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*, `optional`: *attentions*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` is a sequence of hidden states at + the output of the last layer of the encoder. Used in the cross-attention of the decoder. + past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded + representation. If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be + input (see `past_key_values`). This is useful if you want more control over how to convert + `decoder_input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + + If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value + of `inputs_embeds`. + + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +T5_ENCODER_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you + should be able to pad the inputs on both the right and the left. + + Indices can be obtained using [`T5Tokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for detail. + + To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training). + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask +__HEAD_MASK_WARNING_MSG = """ +The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently, +`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions. +If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, +num_heads)`. +""" + + +@add_start_docstrings( + "The bare T5 Model transformer outputting raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5Model(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + decoder_inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, T5Model + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5Model.from_pretrained("t5-small") + + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1 + + >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model. + >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg. + >>> decoder_input_ids = model._shift_right(decoder_input_ids) + + >>> # forward pass + >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """T5 Model with a `language modeling` head on top.""", T5_START_DOCSTRING +) +class T5ForConditionalGeneration(T5PreTrainedModel): + _keys_to_ignore_on_load_missing = [ + r"encoder.embed_tokens.weight", + r"decoder.embed_tokens.weight", + r"lm_head.weight", + ] + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.model_dim = config.d_model + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, self.shared) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.decoder.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.decoder.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.decoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.decoder = self.decoder.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + reduction: Optional[str] = "mean", + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import T5Tokenizer, T5ForConditionalGeneration + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. + ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if ( + labels is not None + and decoder_input_ids is None + and decoder_inputs_embeds is None + ): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to( + self.decoder.first_device + ) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100, reduction=reduction) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + if reduction == "none": + loss = loss.view(lm_logits.size(0), -1).sum(1) + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning( + "You might want to consider setting `use_cache=True` to speed up decoding" + ) + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device) + ), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + ( + reordered_layer_past_states, + ) + return reordered_decoder_past + + +@add_start_docstrings( + "The bare T5 Model transformer outputting encoder's raw hidden-states without any specific head on top.", + T5_START_DOCSTRING, +) +class T5EncoderModel(T5PreTrainedModel): + authorized_missing_keys = [ + r"encoder.embed_tokens.weight", + ] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + self.device_map = ( + get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.encoder.block)) + self.encoder.parallelize(self.device_map) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + self.encoder.deparallelize() + self.encoder = self.encoder.to("cpu") + self.model_parallel = False + self.device_map = None + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + + def get_encoder(self): + return self.encoder + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads) + + @add_start_docstrings_to_model_forward(T5_ENCODER_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + r""" + Returns: + + Example: + + ```python + >>> from transformers import T5Tokenizer, T5EncoderModel + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-small") + >>> model = T5EncoderModel.from_pretrained("t5-small") + >>> input_ids = tokenizer( + ... "Studies have been shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model(input_ids=input_ids) + >>> last_hidden_states = outputs.last_hidden_state + ```""" + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return encoder_outputs diff --git a/minigpt4/models/moe/beam_search.py b/minigpt4/models/moe/beam_search.py new file mode 100644 index 0000000..c5c3a5a --- /dev/null +++ b/minigpt4/models/moe/beam_search.py @@ -0,0 +1,660 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MoELayer(nn.Module): + def __init__(self, hidden_size, expert, gate, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.route_method = route_method + self.topk = topk + self.use_balance_loss = use_balance_loss + self.weight_type = weight_type + + if route_method in ["gate-token", "gate-sentence"]: + self.gate = gate + else: + raise KeyError("Routing method not supported.") + + def _forward_gate_sentence(self, x, attention_mask): + """ + x: query_attention_output , torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + + ### Notice: + the raw version of expert_attention_mask is the extended_attention_mask, + which will be add to attention_score directly + the values of extended_attention_mask are -0.0 or -10000 + it should be adjust to 1/0 version to be processed by experts + """ + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768]) + logits_gate = self.gate(x_average) # torch.Size([bz, num_experts]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) + select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + + # 这里用l2 norm 去加权 + if self.weight_type == 'l2_norm': + # normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 Normalization torch.Size([bz, topk]) + normalized_tensor = select_prob_gate + + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + gate_load = num_sentences.clone() + + # forward experts + def forward_expert(input_x, expert_idx): + input_x = self.experts[expert_idx].forward(input_x) + return input_x + + result_lst = list() + for i in range(self.topk): + # top1、top2... 分别为一组,进行gate分组之后过expert,然后乘以概率后相加 + tmp_gate = gate[:,i] + tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1) + order = tmp_gate.argsort(0) + num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0) + x1 = x[order] # reorder according to expert number + x1 = x1.split(num_sentences_t.tolist(), dim=0) # a list of length self.num_experts + + result = [] + for i in range(self.num_experts): + if x1[i].size(0) > 0: + result.append(forward_expert(x1[i], i)) + result = torch.vstack(result) + result = result[order.argsort(0)] # restore original order + + # result_lst.append(result * tmp_prob) # result * prob + result_lst.append(result) # result * prob + + moe_result = sum(result_lst) + print('Layer Qformer MoE: \n',prob_gate) + return moe_result, select_prob_gate, gate + + def _forward_gate_sentence_post(self, x, attention_mask): + """ + x: query_attention_output; torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + bz = 4 + x = torch.randn(bz,32,768) + attention_mask = torch.ones([bz, 32]) + + """ + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + def forward_expert(input_x, expert_idx): + # input_x += torch.randn(4,32,768) + # return input_x + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = forward_expert(x_masked, expert_idx) + outputs.append(output_x.unsqueeze(0)) + + output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768]) + # gate_acore = self.gates[expert_idx](output_x_aver) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + + candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) + topk_values, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + gate_load = num_sentences.clone() + + # load balancing loss + if self.use_balance_loss: + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + + # importance loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + # output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768]) + # output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768]) + # logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1]) + + prob_gate_topk = torch.zeros_like(prob_gate) + prob_gate_topk.scatter_(1, gate, topk_values) + prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) # torch.Size([bz, num_expert]) + candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768]) + results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768]) + moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768]) + import pdb;pdb.set_trace() + + return moe_result, (balance_loss+importance_loss), prob_gate_normalized + + def forward(self, x, attention_mask): + if self.route_method == "gate-token": + x, balance_loss, gate_load = self._forward_gate_token(x) + elif self.route_method == "gate-sentence": + if x.size(0) == 1: + x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask) + else: + x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask) + elif self.route_method == "gate-sentence-post": + x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask) + else: + raise KeyError("Routing method not supported.") + + return x, balance_loss, gate_load + + +class RouteMoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.num_beams = num_beams + self.hidden_size = hidden_size + self.layer_judge = layer_judge + self.weight_type = weight_type + + self.route_method = route_method + if self.route_method == "pre-route": + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif self.route_method in ["post-route", "post-route-dp"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + self.gate = gate + # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + + def forward_gate(self, x): + """ + x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768]) + prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts]) + """ + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768]) + x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768]) + logits_gate = self.gate(x_average) # torch.Size([bz*num_beams, num_experts]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts]) + return prob_gate + + + def beam_search_backup(self, current_scores_log, beam_scores, expert_route, batch_size): + if self.layer_judge=='first' and self.route_method=='pre-route': + # current_scores_log torch.Size([bz, num_experts]) + assert beam_scores==None and expert_route==None + current_scores = torch.exp(current_scores_log) + topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_beams * batch_size)) + + else: + if self.layer_judge=='first' and self.route_method == 'post-route': + batch_size = batch_size + next_scores_raw1 = torch.exp(current_scores_log) # torch.Size([bz, num_beams*num_experts]) + else: + batch_size = int(batch_size // self.num_beams) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + next_scores_raw1 = next_scores_exp.view( + batch_size, self.num_beams * self.num_experts + ) # torch.Size([bz, num_beams*num_experts]) + + next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True) + # next_scores torch.Size([bz, num_beams]) + # next_tokens torch.Size([bz, num_beams]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + for rank, (expert_id, expert_score) in enumerate( + zip(next_experts[batch_idx], next_scores[batch_idx]) + ): + expert_id = expert_id.item() + beam_id = expert_id // self.num_experts + ex_id = expert_id % self.num_experts + effective_beam_id = batch_idx*self.num_beams + beam_id + + next_sent_beam.append((expert_score, ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + import pdb;pdb.set_trace() + + if self.layer_judge=='first' and self.route_method == 'post-route': + beam_scores = next_scores.view(self.num_beams * batch_size) # torch.Size([bz * num_beams]) + expert_route = next_experts.view(self.num_beams * batch_size) + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route.new([x[1] for x in next_batch_beam]).unsqueeze(-1) + beam_idx = expert_route.new([int(x[2]/self.num_beams) for x in next_batch_beam]) + expert_route = beam_experts + else: + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + return beam_scores, expert_route, beam_idx + + def dp_search(self, current_scores_log, beam_scores, expert_route, batch_size): + if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route', 'post-route-dp']: + # current_scores_log torch.Size([bz, num_experts]) + assert beam_scores==None and expert_route==None + current_scores = torch.exp(current_scores_log) + topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_beams * batch_size)) + + else: + batch_size = int(batch_size // self.num_beams) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + import pdb;pdb.set_trace() + + next_scores_raw, next_experts_raw = torch.topk(next_scores_exp, 1, dim=1, largest=True, sorted=True) + next_scores = next_scores_raw.view(batch_size, self.num_beams) + next_experts = next_experts_raw.view(batch_size, self.num_beams) + # next_scores, next_experts = torch.topk(current_scores_log, 1, dim=1, largest=True, sorted=True) # equal 等价 + # next_scores torch.Size([bz * num_beams, 1]) + # next_tokens torch.Size([bz * num_beams, 1]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + expert_id = next_experts[batch_idx] + expert_score = next_scores[batch_idx] + values, index = torch.topk(expert_score, self.num_beams, dim=0, largest=True, sorted=True) + for i in range(self.num_beams): + beam_id = index[i].item() + ex_id = expert_id[beam_id].item() + effective_beam_id = batch_idx*self.num_beams + beam_id + next_sent_beam.append((values[i], ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + import pdb;pdb.set_trace() + + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + return beam_scores, expert_route, beam_idx + + + def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size): + if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']: + # current_scores_log torch.Size([bz, num_experts]) + assert beam_scores==None and expert_route==None + current_scores = torch.exp(current_scores_log) + topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_beams * batch_size)) + import pdb;pdb.set_trace() + + else: + batch_size = int(batch_size // self.num_beams) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + import pdb;pdb.set_trace() + + next_scores_raw1 = next_scores_exp.view( + batch_size, self.num_beams * self.num_experts + ) # torch.Size([bz, num_beams*num_experts]) + + next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True) + # next_scores torch.Size([bz, num_beams]) + # next_tokens torch.Size([bz, num_beams]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + for rank, (expert_id, expert_score) in enumerate( + zip(next_experts[batch_idx], next_scores[batch_idx]) + ): + expert_id = expert_id.item() + beam_id = expert_id // self.num_experts + ex_id = expert_id % self.num_experts + effective_beam_id = batch_idx*self.num_beams + beam_id + + next_sent_beam.append((expert_score, ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + import pdb;pdb.set_trace() + + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + print("next_scores_raw1:\n",next_scores_raw1) + + return beam_scores, expert_route, beam_idx + + def forward_expert_ffn(self, x, expert_select, current_scores): + """ + x_repeat : [bz*num_beams, 32,768] + expert_select : [bz*num_beams] + current_scores : [bz*num_beams, num_experts] / [bz, num_experts] + """ + # add_1228 l2_normalization + # normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk]) + # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1) + import pdb;pdb.set_trace() + outputs = list() + for i in range(self.num_experts): + output_x = self.experts[i].forward(x) + outputs.append(output_x.unsqueeze(1)) + candidate_output = torch.cat(outputs, dim=1) + expert_select_matrix = F.one_hot(expert_select, self.num_experts) + + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores * expert_select_matrix + candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + import pdb;pdb.set_trace() + output = torch.sum(candidate_output, dim=1) + + return output # torch.Size([bz*num_beams, 32, 768]) + + def forward_pre_route(self, x, beam_scores, expert_route, use_log=True): + import pdb;pdb.set_trace() + current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams] + + importance_loss = self._importance_auxiliary_loss(current_scores) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + + batch_size, num_tokens = x.shape[0], x.shape[1] + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + + current_expert_select = expert_route[:,-1] + + import pdb;pdb.set_trace() + + if self.layer_judge=='first': # expand first dim to batch_size * num_beams + replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size) + x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + input_x = x[beam_idx] + candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768] + import pdb;pdb.set_trace() + + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward_post_route(self, x, beam_scores, expert_route, use_log=True): + + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + def forward_expert(input_x, expert_idx): + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = forward_expert(x_masked, expert_idx) + # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768]) + output_x_aver = torch.mean(output_x, dim=1) + # gate_score = self.gates[expert_idx](output_x_aver) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + outputs.append(output_x.unsqueeze(0)) + + candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert]) + current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts]) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + + # importance loss + importance_loss = self._importance_auxiliary_loss(current_scores) + + batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam + import pdb; pdb.set_trace() + + if self.route_method == 'post-route': + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + elif self.route_method == 'post-route-dp': + beam_scores, expert_route, beam_idx = self.dp_search(current_scores_log, beam_scores, expert_route, batch_size) + + # beam_scores torch.Size([bz*num_beam]) + # expert_route torch.Size([bz*num_beam, layer_n]) + current_select_expert = expert_route[:,-1] + # current_select_expert torch.Size([bz*num_beam, 1]) + + # import pdb; pdb.set_trace() + + if self.layer_judge == 'first': + replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size) + candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768]) + expert_select_matrix = F.one_hot(current_select_expert, self.num_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores[beam_idx] * expert_select_matrix + output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + final_output = torch.sum(output, dim=1) + + import pdb; pdb.set_trace() + print("current_scores:\n",current_scores) + + return final_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True): + """ + if first_layer: x [bz, 32, 768] + else: x [bz*num_beams, 32, 768] + + """ + if self.route_method == 'pre-route': + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True) + elif self.route_method in ['post-route', 'post-route-dp']: + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True) + + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + +if __name__ == '__main__': + + import sys + sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + from minigpt4.models.QformerRouteMoE import BertConfig + from minigpt4.models.QformerRouteMoE import FeedForward + + from minigpt4.models.moe.utils import ( + use_experts, + moe_layer_judge, + ) + vision_width = 1408 + cross_attention_freq = 2 + num_query_token = 32 + # init_QformerMoE + config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + config.encoder_width = vision_width + # insert cross-attention layer every other block + config.add_cross_attention = True + config.cross_attention_freq = cross_attention_freq + config.query_length = num_query_token + config.moebert_expert_num = 2 + config.moebert_num_beams = 2 + config.moebert_route_method = 'gate-sentence' + config.moe_topk = 2 + config.use_balance_loss = False + config.moe_weight_type = 'l2_norm' + + batch_size = 4 + x = torch.randn(batch_size, 32, 768) + beam_scores, expert_route = None, None + x1 = x + x2 = x + x3 = x + beam_scores1, expert_route1 = None, None + beam_scores2, expert_route2 = None, None + + for layer_num in [6, 8, 10]: + layer_judge = moe_layer_judge(layer_num) + ffn = FeedForward(config) + + # experts = RouteMoELayer( + # hidden_size=768, + # expert=ffn, + # num_experts=config.moebert_expert_num, + # num_beams=config.moebert_num_beams, + # layer_judge = layer_judge, + # route_method = "pre-route", + # weight_type="no_ffn_prob" + # ) + # layer_output = experts(x, None, beam_scores, expert_route) + # hidden_states1, beam_scores, expert_route, beam_idx, importance_loss = layer_output + + # print(beam_scores) + # print(expert_route) + # print(beam_idx) + # print(importance_loss) + # x = hidden_states1 + + # experts_post = RouteMoELayer( + # hidden_size=768, + # expert=ffn, + # num_experts=config.moebert_expert_num, + # num_beams=config.moebert_num_beams, + # layer_judge = layer_judge, + # route_method = "post-route", + # weight_type="ffn_prob" + # ) + # layer_output = experts_post(x1, None, beam_scores1, expert_route1, False) + # hidden_states2, beam_scores1, expert_route1, beam_idx, importance_loss = layer_output + + # print(beam_scores1) + # print(expert_route1) + # print(beam_idx) + # print(importance_loss) + # x1 = hidden_states2 + + experts_post = RouteMoELayer( + hidden_size=768, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = layer_judge, + route_method = "post-route-dp", + weight_type="ffn_prob" + ) + layer_output = experts_post(x2, None, beam_scores2, expert_route2, False) + hidden_states3, beam_scores2, expert_route2, beam_idx2, importance_loss2 = layer_output + + print(beam_scores2) + print(expert_route2) + print(beam_idx2) + print(importance_loss2) + x2 = hidden_states3 + + # gate = nn.Linear(768, config.moebert_expert_num, bias=False).float() + # experts_moe = MoELayer( + # hidden_size=config.hidden_size, + # expert=ffn, + # gate=gate, + # num_experts=config.moebert_expert_num, + # route_method=config.moebert_route_method, + # topk=config.moe_topk, + # use_balance_loss=config.use_balance_loss, + # weight_type=config.moe_weight_type, + # ) + # attn_mask = torch.ones([batch_size, 32]) + # layer_output = experts_moe(x3, attn_mask) + # hidden_states4, select_prob_gate, gate_load,_ = layer_output + + # print(select_prob_gate) + # print(gate_load) + # x3 = hidden_states4 + + print("------------------------------------") + import pdb; pdb.set_trace() + + + + def forward_post_route_backup(self, x, beam_scores, expert_route, use_log=True): + + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + def forward_expert(input_x, expert_idx): + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = forward_expert(x_masked, expert_idx) + outputs.append(output_x.unsqueeze(0)) + # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768]) + # gate_score = self.gates[expert_idx](output_x_aver) + output_x_aver = torch.mean(output_x, dim=1) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert]) + current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts]) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + + # importance loss + importance_loss = self._importance_auxiliary_loss(current_scores) + + batch_size = x.shape[0] # bz*num_beam + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + # beam_scores torch.Size([bz*num_beam]) + # expert_route torch.Size([bz*num_beam, layer_n]) + current_select_expert = expert_route[:,-1] + # current_select_expert torch.Size([bz*num_beam, 1]) + + output = list() + for i in range(beam_idx.shape[0]): + b_idx = beam_idx[i] + ex_idx = current_select_expert[i] + ex_out = candidate_output[ex_idx, b_idx, :,:] + if self.weight_type == 'ffn_prob': + prob = current_scores[b_idx, ex_idx] + ex_out = ex_out*prob + output.append(ex_out.unsqueeze(0)) + + final_output = torch.concat(output, dim=0) + # import pdb;pdb.set_trace() + return final_output, beam_scores, expert_route, beam_idx, importance_loss + diff --git a/minigpt4/models/moe/beam_search_universal.py b/minigpt4/models/moe/beam_search_universal.py new file mode 100644 index 0000000..b435709 --- /dev/null +++ b/minigpt4/models/moe/beam_search_universal.py @@ -0,0 +1,321 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UniRouteMoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts #(1+other) + self.num_route_experts = num_experts-1 + self.num_beams = num_beams + self.num_route_beam = num_beams-1 + + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.hidden_size = hidden_size + self.layer_judge = layer_judge + self.weight_type = weight_type + + self.route_method = route_method + if self.route_method == "pre-route-uni": + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif self.route_method in ["post-route-uni"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + self.gate = gate + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + + def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size): + if self.layer_judge=='first' and self.route_method in ['pre-route-uni', 'post-route-uni']: + # current_scores_log torch.Size([bz, num_experts-1]) + assert beam_scores==None and expert_route==None + current_scores = torch.exp(current_scores_log) + topk_values, gate = torch.topk(current_scores, self.num_route_beam, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + beam_scores = topk_values.view(self.num_route_beam * batch_size) # torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_route_beam * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_route_beam * batch_size)) + + else: + batch_size = int(batch_size // self.num_route_beam) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + + next_scores_raw1 = next_scores_exp.view( + batch_size, self.num_route_beam * self.num_route_experts + ) # torch.Size([bz, num_route_beam*num_route_experts]) + + next_scores, next_experts = torch.topk(next_scores_raw1, self.num_route_beam, dim=1, largest=True, sorted=True) + # next_tokens torch.Size([bz, num_route_beam]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + for rank, (expert_id, expert_score) in enumerate( + zip(next_experts[batch_idx], next_scores[batch_idx]) + ): + expert_id = expert_id.item() + beam_id = expert_id // self.num_route_experts + ex_id = expert_id % self.num_route_experts + effective_beam_id = batch_idx*self.num_route_beam + beam_id + + next_sent_beam.append((expert_score, ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + return beam_scores, expert_route, beam_idx + + + def forward_gate(self, x): + """ + TODO: Pre forward gate + x : torch.Size([bz*(num_beams-1), 32, 768]) or torch.Size([bz, 32, 768]) + prob_gate : torch.Size([bz*(num_beams-1), num_experts]) or torch.Size([bz, num_experts]) + """ + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*(num_beams-1), 32, 768]) + x_average = torch.mean(x_masked, dim=1) # torch.Size([bz*(num_beams-1), 768]) + logits_gate = self.gate(x_average) # torch.Size([bz*(num_beams-1), num_experts]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*(num_beams-1), num_experts]) + return prob_gate + + def forward_expert_ffn(self, x, expert_select, current_scores): + """ + x_repeat : [bz*num_beams, 32,768] + expert_select : [bz*num_beams] + current_scores : [bz*num_beams, num_experts] / [bz, num_experts] + """ + # import pdb;pdb.set_trace() + outputs = list() + for i in range(self.num_experts-1): + output_x = self.experts[i].forward(x) + outputs.append(output_x.unsqueeze(1)) + candidate_output = torch.cat(outputs, dim=1) + expert_select_matrix = F.one_hot(expert_select, self.num_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores * expert_select_matrix + candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + output = torch.sum(candidate_output, dim=1) + # import pdb;pdb.set_trace() + return output # torch.Size([bz*(num_beams-1), 32, 768]) + + def forward_pre_route(self, x, beam_scores, expert_route, use_log=True): + + current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams] + + importance_loss = self._importance_auxiliary_loss(current_scores) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + # import pdb;pdb.set_trace() + batch_size, num_tokens = x.shape[0], x.shape[1] + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + current_expert_select = expert_route[:,-1] + + if self.layer_judge=='first': # expand first dim to batch_size * num_beams + replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size) + x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + input_x = x[beam_idx] + candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768] + # import pdb;pdb.set_trace() + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward_post_route_uni(self, x, beam_scores, expert_route, use_log=True): + + if beam_scores == None: + batch_size = x.shape[0] + x_masked, x_uniexpert = x, x # torch.Size([bz, 32, 768]) + elif x.shape[0]/self.num_beams == beam_scores.shape[0]/self.num_route_beam: + batch_size = int(x.shape[0]/self.num_beams) + select_universal = [i*self.num_beams+self.num_route_beam for i in range(batch_size)] + select_expert = [ x for x in range(batch_size*self.num_beams) if x not in select_universal] + x_masked, x_uniexpert = x[select_expert],x[select_universal] + num_tokens = x.shape[1] + + import pdb; pdb.set_trace() + + def forward_expert(input_x, expert_idx): + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + #################### + ### route expert + #################### + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_route_experts): # num_expert-1 + output_x = forward_expert(x_masked, expert_idx) + output_x_aver = torch.mean(output_x, dim=1) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + outputs.append(output_x.unsqueeze(0)) + + candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert-1, bz*(num_beam-1), 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*(num_beam-1), num_expert-1]) + current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*(num_beam-1), num_expert-1]) + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 torch.Size([bz*(num_beam-1), num_expert-1]) + else: + current_scores_log = current_scores + + importance_loss = self._importance_auxiliary_loss(current_scores) + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, current_scores_log.shape[0]) + # beam_scores torch.Size([bz*(num_beam-1)]), expert_route torch.Size([bz*(num_beam-1), layer_n]) + current_select_expert = expert_route[:,-1] # torch.Size([bz*(num_beam-1)]) + + import pdb; pdb.set_trace() + if self.layer_judge == 'first': + replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_route_experts, batch_size, self.num_route_beam, num_tokens, self.hidden_size) + candidate_output_raw = replicated_tensor.contiguous().view(self.num_route_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_route_beam, self.num_route_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_route_experts) # [bz*(num_beams-1), num_experts-1] + + import pdb; pdb.set_trace() + candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768]) + expert_select_matrix = F.one_hot(current_select_expert, self.num_route_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores[beam_idx] * expert_select_matrix + output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + experts_output = torch.sum(output, dim=1) # [bz*num_beams-1, 32, 768] + + import pdb; pdb.set_trace() + + #################### + ### universal expert + #################### + uni_output = forward_expert(x_uniexpert, self.num_experts-1) # [bz, 32, 768] + + #################### + ### Combine expert + #################### + output = list() + for i in range(batch_size): + expert_tmp = experts_output[i*self.num_route_beam: i*self.num_route_beam+self.num_route_beam,:,:] + combine_tmp = torch.cat((expert_tmp, uni_output[i].unsqueeze(0))) + output.append(combine_tmp) + final_output = torch.cat(output) # [bz*num_beam, 32 ,768] + + import pdb; pdb.set_trace() + + return final_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True): + """ + if first_layer: x [bz, 32, 768] + else: x [bz*num_beams, 32, 768] + """ + if self.route_method == 'pre-route-uni': + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True) + elif self.route_method in ['post-route-uni']: + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route_uni(x, beam_scores, expert_route, use_log=True) + + import pdb;pdb.set_trace() + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + + + +if __name__ == '__main__': + + import sys + sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + from minigpt4.models.QformerRouteMoE import BertConfig + from minigpt4.models.QformerRouteMoE import FeedForward + + from minigpt4.models.moe.utils import ( + use_experts, + moe_layer_judge, + ) + vision_width = 1408 + cross_attention_freq = 2 + num_query_token = 32 + # init_QformerMoE + config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + config.encoder_width = vision_width + # insert cross-attention layer every other block + config.add_cross_attention = True + config.cross_attention_freq = cross_attention_freq + config.query_length = num_query_token + config.moebert_expert_num = 3 + config.moebert_num_beams = 3 + config.moebert_route_method = 'gate-sentence' + config.moe_topk = 3 + config.use_balance_loss = False + config.moe_weight_type = 'l2_norm' + + batch_size = 4 + x = torch.randn(batch_size, 32, 768) + beam_scores, expert_route = None, None + x1 = x + x2 = x + x3 = x + beam_scores1, expert_route1 = None, None + beam_scores2, expert_route2 = None, None + + for layer_num in [6, 8, 10]: + layer_judge = moe_layer_judge(layer_num) + ffn = FeedForward(config) + + # experts_post = RouteMoELayer( + # hidden_size=768, + # expert=ffn, + # num_experts=config.moebert_expert_num, + # num_beams=config.moebert_num_beams, + # layer_judge = layer_judge, + # route_method = "post-route", + # weight_type="ffn_prob" + # ) + # layer_output = experts_post(x1, None, beam_scores1, expert_route1, False) + # hidden_states2, beam_scores1, expert_route1, beam_idx, importance_loss = layer_output + + # print(beam_scores1) + # print(expert_route1) + # print(beam_idx) + # print(importance_loss) + # x1 = hidden_states2 + + experts_post = UniRouteMoELayer( + hidden_size=768, + expert=ffn, + num_experts=config.moebert_expert_num, + num_beams=config.moebert_num_beams, + layer_judge = layer_judge, + route_method = "post-route-uni", + weight_type="ffn_prob" + ) + layer_output = experts_post(x2, None, beam_scores2, expert_route2, False) + hidden_states3, beam_scores2, expert_route2, beam_idx2, importance_loss2 = layer_output + + print(beam_scores2) + print(expert_route2) + print(beam_idx2) + print(importance_loss2) + x2 = hidden_states3 + + print("------------------------------------") + import pdb; pdb.set_trace() diff --git a/minigpt4/models/moe/moe_layer.py b/minigpt4/models/moe/moe_layer.py new file mode 100644 index 0000000..1730962 --- /dev/null +++ b/minigpt4/models/moe/moe_layer.py @@ -0,0 +1,258 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob'): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.route_method = route_method + self.topk = topk + self.use_balance_loss = use_balance_loss + self.weight_type = weight_type + + if route_method in ["gate-token", "gate-sentence", "gate-sentence-cls"]: + gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif route_method in ["gate-sentence-post"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) + elif route_method in ["gate-sentence-post-cosine"]: + gate = nn.Parameter(torch.rand(hidden_size)).float() + else: + raise KeyError("Routing method not supported.") + self.gate = gate + + def _balancing_loss(self, prob_gate, num_tokens): + # From MOEBERT + # compute the load balancing loss + # prob_gate,是 [bz, num_expert],每个样本被分配给每个expert的概率 + # 等价于 VMOE 中 _gshard_auxiliary_loss + P = prob_gate.mean(0) # torch.Size([num_expert]) 每个expert被分配到样本的平均概率 + temp = num_tokens.float() + f = temp / temp.sum(0, keepdim=True) # 每个expert被分配的sample比例 + balance_loss = self.num_experts * torch.sum(P * f) + return balance_loss + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + def _forward_gate_token(self, x): + bsz, seq_len, dim = x.size() + + x = x.view(-1, dim) + logits_gate = self.gate(x) + prob_gate = F.softmax(logits_gate, dim=-1) + gate = torch.argmax(prob_gate, dim=-1) + + order = gate.argsort(0) + num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0) + gate_load = num_tokens.clone() + x = x[order] # reorder according to expert number + x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts + + # compute the load balancing loss + P = prob_gate.mean(0) + temp = num_tokens.float() + f = temp / temp.sum(0, keepdim=True) + balance_loss = self.num_experts * torch.sum(P * f) + + prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1)) + prob_gate = prob_gate[order] + prob_gate = prob_gate.split(num_tokens.tolist(), dim=0) + + def forward_expert(input_x, prob_x, expert_idx): + input_x = self.experts[expert_idx].forward(input_x) + input_x = input_x * prob_x + return input_x + + x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)] + x = torch.vstack(x) + x = x[order.argsort(0)] # restore original order + x = x.view(bsz, seq_len, dim) + + return x, balance_loss, gate_load + + def _forward_gate_sentence_post(self, x, attention_mask): + """ + x: query_attention_output; torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + bz = 4 + x = torch.randn(bz,32,768) + attention_mask = torch.ones([bz, 32]) + + """ + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + def forward_expert(input_x, expert_idx): + # input_x += torch.randn(4,32,768) + # return input_x + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = forward_expert(x_masked, expert_idx) + outputs.append(output_x.unsqueeze(0)) + output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768]) + # gate_acore = self.gates[expert_idx](output_x_aver) + if self.route_method=="gate-sentence-post-cosine": + # gate_score = F.cosine_similarity(self.gate.weight, output_x_aver,dim=1).unsqueeze(1) + gate_score = F.cosine_similarity(self.gate, output_x_aver,dim=1).unsqueeze(1) + else: + gate_score = self.gate(output_x_aver) + + logits_gate_lst.append(gate_score) + + candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) + topk_values, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + gate_load = num_sentences.clone() + + # load balancing loss + if self.use_balance_loss: + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + + # importance loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + prob_gate_topk = torch.zeros_like(prob_gate) + prob_gate_topk.scatter_(1, gate, topk_values) + + if self.weight_type == 'average': + # torch.Size([bz, num_expert]) 未选中的expert prob_gate_norm为0 + prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) + elif self.weight_type == 'raw_prob': + prob_gate_normalized = prob_gate_topk + elif self.weight_type == 'softmax_norm': + prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert]) + + candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768]) + results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768]) + moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768]) + # import pdb;pdb.set_trace() + + return moe_result, (balance_loss+importance_loss), prob_gate_normalized + + def router(self, x, attention_mask): + # Prepare input x + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + x_average = torch.mean(x_masked, dim=1) # torch.Size([bz, 768]) + + # Forward Gate + # logits_gate: [bz, num_experts] + logits_gate = self.gate(x_average) + + # Probabilities for each sample of what expert it should be sent to. + # prob_gate: [bz, num_experts] + prob_gate = F.softmax(logits_gate, dim=-1) + + # Get Top-K experts for each sample + # gate: [bz, topk] + # select_prob_gate: [bz, topk] + select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) + + # Reshap Prob_gate & Gate + # expert_mask: [batch_size, topk, num_experts] + # expert_gate: [batch_size, topk, num_experts] + # combine_tensor: [batch_size, num_experts] + expert_mask = F.one_hot(gate, self.num_experts) + expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask + combine_tensor = torch.sum(expert_gate, dim=1) + + # Calculate Balancing Loss + if self.use_balance_loss: + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + + # Calculate Importance Loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + return expert_mask, combine_tensor, balance_loss, importance_loss + + def cls_router(self, cls_hidden=None): + + logits_gate = self.gate(cls_hidden.squeeze(1)) + prob_gate = F.softmax(logits_gate, dim=-1) + select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) + expert_mask = F.one_hot(gate, self.num_experts) + expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask + combine_tensor = torch.sum(expert_gate, dim=1) + + if self.use_balance_loss: + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + + importance_loss = self._importance_auxiliary_loss(prob_gate) + return expert_mask, combine_tensor, balance_loss, importance_loss + + def _forward_gate_sentence(self, x, attention_mask, cls_hidden=None): + """ + x: query_attention_output , torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + + ### Notice: + the raw version of expert_attention_mask is the extended_attention_mask, + which will be add to attention_score directly + the values of extended_attention_mask are -0.0 or -10000 + it should be adjust to 1/0 version to be processed by experts + """ + # Forward Router + if self.route_method=="gate-sentence-cls": + expert_mask, combine_tensor, balance_loss, importance_loss = self.cls_router(cls_hidden) + else: + expert_mask, combine_tensor, balance_loss, importance_loss = self.router(x, attention_mask) + + # Forward Expert FFN + result = [] + for expert_idx in range(self.num_experts): + output_x = self.experts[expert_idx].forward(x) + result.append(output_x.unsqueeze(0)) + expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states]) + + # multiply outputs of experts by the routing probability + if self.weight_type == 'raw_prob': + expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states]) + elif self.weight_type == 'no_prob': + combine_index = torch.sum(expert_mask, dim=1) + expert_outputs_combined = expert_output * combine_index.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states]) + + outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states]) + + return outputs, (balance_loss+importance_loss), combine_tensor + + def forward(self, x, attention_mask, cls_hidden=None): + # import pdb; pdb.set_trace() + + if self.route_method == "gate-token": + x, balance_loss, gate_load = self._forward_gate_token(x) + elif self.route_method == "gate-sentence": + x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask) + elif self.route_method in ["gate-sentence-post", "gate-sentence-post-cosine"]: + x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask) + elif self.route_method == "gate-sentence-cls": + x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask, cls_hidden) + else: + raise KeyError("Routing method not supported.") + return x, balance_loss, gate_load diff --git a/minigpt4/models/moe/moe_layer_backup.py b/minigpt4/models/moe/moe_layer_backup.py new file mode 100644 index 0000000..25f2e59 --- /dev/null +++ b/minigpt4/models/moe/moe_layer_backup.py @@ -0,0 +1,330 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='l2_norm'): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.route_method = route_method + self.topk = topk + self.use_balance_loss = use_balance_loss + self.weight_type = weight_type + + if route_method in ["gate-token", "gate-sentence"]: + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif route_method in ["gate-sentence-post"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) + self.gate = gate + else: + raise KeyError("Routing method not supported.") + + def _balancing_loss(self, prob_gate, num_tokens): + # From MOEBERT + # compute the load balancing loss + # prob_gate,是 [bz, num_expert],每个样本被分配给每个expert的概率 + # 等价于 VMOE 中 _gshard_auxiliary_loss + P = prob_gate.mean(0) # torch.Size([num_expert]) 每个expert被分配到样本的平均概率 + temp = num_tokens.float() + f = temp / temp.sum(0, keepdim=True) # 每个expert被分配的sample比例 + balance_loss = self.num_experts * torch.sum(P * f) + return balance_loss + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + def _forward_gate_token(self, x): + bsz, seq_len, dim = x.size() + + x = x.view(-1, dim) + logits_gate = self.gate(x) + prob_gate = F.softmax(logits_gate, dim=-1) + gate = torch.argmax(prob_gate, dim=-1) + + order = gate.argsort(0) + num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0) + gate_load = num_tokens.clone() + x = x[order] # reorder according to expert number + x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts + + # compute the load balancing loss + P = prob_gate.mean(0) + temp = num_tokens.float() + f = temp / temp.sum(0, keepdim=True) + balance_loss = self.num_experts * torch.sum(P * f) + + prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1)) + prob_gate = prob_gate[order] + prob_gate = prob_gate.split(num_tokens.tolist(), dim=0) + + def forward_expert(input_x, prob_x, expert_idx): + input_x = self.experts[expert_idx].forward(input_x) + input_x = input_x * prob_x + return input_x + + x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)] + x = torch.vstack(x) + x = x[order.argsort(0)] # restore original order + x = x.view(bsz, seq_len, dim) + + return x, balance_loss, gate_load + + def _forward_gate_sentence_top1_raw(self, x, attention_mask): + """ + x: query_attention_output , torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + + ### Notice: + the raw version of expert_attention_mask is the extended_attention_mask, + which will be add to attention_score directly + the values of extended_attention_mask are -0.0 or -10000 + it should be adjust to 1/0 version to be processed by experts + """ + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768]) + logits_gate = self.gate(x_average) # torch.Size([bz, num_experts]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) + gate = torch.argmax(prob_gate, dim=-1) # torch.Size([bz]) + + order = gate.argsort(0) + num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0) + gate_load = num_sentences.clone() + x = x[order] # reorder according to expert number + x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts + + # compute the load balancing loss + P = prob_gate.mean(0) + temp = num_sentences.float() + f = temp / temp.sum(0, keepdim=True) + balance_loss = self.num_experts * torch.sum(P * f) + + prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1)) + prob_gate = prob_gate[order] + prob_gate = prob_gate.split(num_sentences.tolist(), dim=0) + + def forward_expert(input_x, prob_x, expert_idx): + input_x = self.experts[expert_idx].forward(input_x) + input_x = input_x * prob_x.unsqueeze(-1) + return input_x + + result = [] + for i in range(self.num_experts): + if x[i].size(0) > 0: + result.append(forward_expert(x[i], prob_gate[i], i)) + result = torch.vstack(result) + result = result[order.argsort(0)] # restore original order + + return result, balance_loss, gate_load + + def _forward_gate_sentence_post(self, x, attention_mask): + """ + x: query_attention_output; torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + bz = 4 + x = torch.randn(bz,32,768) + attention_mask = torch.ones([bz, 32]) + + """ + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + def forward_expert(input_x, expert_idx): + # input_x += torch.randn(4,32,768) + # return input_x + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = forward_expert(x_masked, expert_idx) + outputs.append(output_x.unsqueeze(0)) + + output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768]) + # gate_acore = self.gates[expert_idx](output_x_aver) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + + candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) + topk_values, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + gate_load = num_sentences.clone() + + # load balancing loss + if self.use_balance_loss: + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + + # importance loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + # output_average = candidate_output.sum(2) / candidate_attn_mask.unsqueeze(-1).sum(2) # torch.Size([num_expert, bz, 768]) + # output_average = torch.permute(output_average, (1, 0, 2)) # torch.Size([bz, num_expert, 768]) + # logits_gate = self.gate(output_average) # torch.Size([bz, num_experts, 1]) + + prob_gate_topk = torch.zeros_like(prob_gate) + prob_gate_topk.scatter_(1, gate, topk_values) + + if self.weight_type == 'average': + # torch.Size([bz, num_expert]) 未选中的expert prob_gate_norm为0 + prob_gate_normalized = prob_gate_topk / prob_gate_topk.sum(dim=1, keepdim=True) + elif self.weight_type == 'raw_prob': + prob_gate_normalized = prob_gate_topk + elif self.weight_type == 'softmax_norm': + prob_gate_normalized = F.softmax(prob_gate_topk, dim=-1) # torch.Size([bz, num_expert]) + + candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768]) + results = prob_gate_normalized.unsqueeze(-1).unsqueeze(-1) * candidate_output_ad # torch.Size([bz, num_expert, 32, 768]) + moe_result = torch.sum(results, dim=1) # torch.Size([bz, 32, 768]) + # import pdb;pdb.set_trace() + + return moe_result, (balance_loss+importance_loss), prob_gate_normalized + + # def _forward_gate_sentence(self, x, attention_mask): + + # attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + # x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + # x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) + # logits_gate = self.gate(x_average) + # prob_gate = F.softmax(logits_gate, dim=-1) + # gate = torch.argmax(prob_gate, dim=-1) + + # order = gate.argsort(0) + # num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0) + # gate_load = num_sentences.clone() + # x = x[order] # reorder according to expert number + # x = x.split(num_sentences.tolist(), dim=0) # a list of length self.num_experts + + # # compute the load balancing loss + # P = prob_gate.mean(0) + # temp = num_sentences.float() + # f = temp / temp.sum(0, keepdim=True) + # balance_loss = self.num_experts * torch.sum(P * f) + + # prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1)) + # prob_gate = prob_gate[order] + # prob_gate = prob_gate.split(num_sentences.tolist(), dim=0) + + # def forward_expert(input_x, prob_x, expert_idx): + # input_x = self.experts[expert_idx].forward(input_x) + # input_x = input_x * prob_x.unsqueeze(-1) + # return input_x + + # result = [] + # for i in range(self.num_experts): + # if x[i].size(0) > 0: + # result.append(forward_expert(x[i], prob_gate[i], i)) + # result = torch.vstack(result) + # result = result[order.argsort(0)] # restore original order + + # return result, balance_loss, gate_load + + def _forward_gate_sentence(self, x, attention_mask): + """ + x: query_attention_output , torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + + ### Notice: + the raw version of expert_attention_mask is the extended_attention_mask, + which will be add to attention_score directly + the values of extended_attention_mask are -0.0 or -10000 + it should be adjust to 1/0 version to be processed by experts + """ + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz, 768]) + logits_gate = self.gate(x_average) # torch.Size([bz, num_experts]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) + select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + + # 这里用l2 norm 去加权 + if self.weight_type == 'l2_norm': + # actually neigther dim=0 nor dim=1 is right + normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=1) # L2 Normalization torch.Size([bz, topk]) + elif self.weight_type == 'l2_norm_0': + normalized_tensor = torch.nn.functional.normalize(select_prob_gate, p=2, dim=0) # L2 Normalization torch.Size([bz, topk]) + elif self.weight_type == 'average': + normalized_tensor = select_prob_gate / select_prob_gate.sum(dim=1, keepdim=True) + elif self.weight_type == 'raw_prob': + normalized_tensor = select_prob_gate + + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + gate_load = num_sentences.clone() + + # load balancing loss + if self.use_balance_loss: + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + + # importance loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + # forward experts + def forward_expert(input_x, expert_idx): + input_x = self.experts[expert_idx].forward(input_x) + return input_x + + result_lst = list() + for i in range(self.topk): + # top1、top2... 分别为一组,进行gate分组之后过expert,然后乘以概率后相加 + tmp_gate = gate[:,i] + tmp_prob = normalized_tensor[:,i].unsqueeze(-1).unsqueeze(-1) + order = tmp_gate.argsort(0) + num_sentences_t = F.one_hot(tmp_gate, self.num_experts).gt(0).sum(0) + x1 = x[order] # reorder according to expert number + x1 = x1.split(num_sentences_t.tolist(), dim=0) # a list of length self.num_experts + + result = [] + for i in range(self.num_experts): + if x1[i].size(0) > 0: + result.append(forward_expert(x1[i], i)) + result = torch.vstack(result) + result = result[order.argsort(0)] # restore original order + result_lst.append(result * tmp_prob) # result * prob + # result_lst.append(result) # result * prob # add_1212 + + moe_result = sum(result_lst) + return moe_result, (balance_loss+importance_loss), gate + + def _forward_sentence_single_expert(self, x, attention_mask): + x_masked = x * attention_mask.unsqueeze(-1) + x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) + logits_gate = self.gate(x_average) + prob_gate = F.softmax(logits_gate, dim=-1) + gate = torch.argmax(prob_gate, dim=-1) + + gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0) + x = self.experts[gate.cpu().item()].forward(x) + return x, 0.0, gate_load + + def forward(self, x, attention_mask): + if self.route_method == "gate-token": + x, balance_loss, gate_load = self._forward_gate_token(x) + elif self.route_method == "gate-sentence": + if x.size(0) == 1: + x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask) + else: + x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask) + elif self.route_method == "gate-sentence-post": + x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask) + else: + raise KeyError("Routing method not supported.") + # import pdb; pdb.set_trace() + return x, balance_loss, gate_load diff --git a/minigpt4/models/moe/prompt_moe.py b/minigpt4/models/moe/prompt_moe.py new file mode 100644 index 0000000..8ea4cea --- /dev/null +++ b/minigpt4/models/moe/prompt_moe.py @@ -0,0 +1,166 @@ +import math +import os +import copy +import pickle +import torch +from torch import nn +from torch import nn +import torch.nn.functional as F +from minigpt4.models.Qformer import BertConfig + +def init_query_token_candidates(num_query_token, num_cand): + encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + query_token_candidates = nn.Parameter( + torch.zeros(num_cand, num_query_token, encoder_config.hidden_size) + ) + query_token_candidates.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return query_token_candidates + +class PromptMoEBase(nn.Module): + def __init__(self, hidden_size, num_experts): + super(PromptMoEBase, self).__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + + def _balancing_loss(self, prob_gate, num_tokens): + # From MOEBERT + # compute the load balancing loss + # prob_gate,是 [bz, num_expert],每个样本被分配给每个expert的概率 + # 等价于 VMOE 中 _gshard_auxiliary_loss + P = prob_gate.mean(0) # torch.Size([num_expert]) 每个expert被分配到样本的平均概率 + temp = num_tokens.float() + f = temp / temp.sum(0, keepdim=True) # 每个expert被分配的sample比例 + balance_loss = self.num_experts * torch.sum(P * f) + return balance_loss + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + def _weighted_select_expert(self, expert_ids, prob_gate_i, query_token_candidates): + # expert_ids: torch.Size([topk]) 为sample选出的topk个expert的idx + # prob_gate_i: torch.Size([topk]) 该sample对应expert的概率值 + # query_token_candidates: torch.Size([num_expert, 32, 768]) + # 先对 prob_gate 归一化,加权平均 expert_qt 的值 + weight = [prob_gate_i[expert_id].item() for expert_id in expert_ids] + weight_norm = torch.tensor(weight) / torch.tensor(weight).sum() + select_qts = [query_token_candidates[expert_id] for expert_id in expert_ids] + weighted_qt = [select_qts[i] * weight_norm[i] for i in range(weight_norm.shape[0])] + select = sum(weighted_qt).unsqueeze(0) + return select + +class PostPromptMoE(PromptMoEBase): + def __init__(self, hidden_size, num_experts, topk=1): + super(PostPromptMoE, self).__init__(hidden_size, num_experts) + self.gate = nn.Linear(hidden_size, 1, bias=False).float() + self.topk = topk + + def _forward_gate_text_single_token(self, text_embeds, candi_query_tokens): + # text embedding output from the blip2: torch.Size([bz, num_qt_candidates, 768]) + # candidate query tokens to be selected : torch.Size([bz*num_qt_candidates, 32, 768]) + logits_gate = self.gate(text_embeds).squeeze(2) # torch.Size([bz, num_qt_candidates]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_qt_candidates]) + + _, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + num_tokens = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + gate_load = num_tokens.clone() + + # load balancing loss + # balance_loss = self._balancing_loss(prob_gate, num_tokens) + balance_loss = 0.0 + + # importance loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + # select expert(query_token) for each sample + out = [self._weighted_select_expert(gate[i], prob_gate[i], candi_query_tokens[i*self.num_experts:(i+1)*self.num_experts]) for i in range(gate.shape[0])] + out = torch.vstack(out) # [bz, 32, 768] + return out, balance_loss, importance_loss, gate_load, gate + + +class PrePromptMoE(PromptMoEBase): + def __init__(self, hidden_size, num_experts, query_token_candidates, route_method, topk=1): + super(PrePromptMoE, self).__init__(hidden_size, num_experts) + self.query_token_candidates = query_token_candidates + self.route_method = route_method + self.topk = topk + if route_method in ["gate-token", "gate-single-token", "gate-sentence"]: + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + else: + raise KeyError("Routing method not supported.") + + def _forward_gate_single_token(self, x): + bsz, seq_len, dim = x.size() + + x = x.view(-1, dim) + logits_gate = self.gate(x) # torch.Size([bz, num_expert]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_expert]) + + _, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + num_tokens = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + + # gate = torch.argmax(prob_gate, dim=-1) # 每个样本被分配的expert + # num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0) + gate_load = num_tokens.clone() + + # load balancing loss + balance_loss = self._balancing_loss(prob_gate, num_tokens) + + # importance loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + # select expert(query_token) for each sample + + out = [self._weighted_select_expert(gate[i], prob_gate[i], self.query_token_candidates) for i in range(gate.shape[0])] + + out = torch.vstack(out) # [bz, 32, 768] + + return out, balance_loss, importance_loss, gate_load, gate + + def _forward_gate_token(self, x): + bsz, seq_len, dim = x.size() + + x = x.view(-1, dim) + logits_gate = self.gate(x) # torch.Size([bz, num_expert]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_expert]) + gate = torch.argmax(prob_gate, dim=-1) # 每个样本被分配的expert + + order = gate.argsort(0) # index of sorted gate(ascending) + num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + gate_load = num_tokens.clone() + x = x[order] # reorder according to expert number + x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts + + # load balancing loss + balance_loss = self._balancing_loss(prob_gate, num_tokens) + + prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1)) + prob_gate = prob_gate[order] + prob_gate = prob_gate.split(num_tokens.tolist(), dim=0) # prob_gate tuple,根据expert分组 + + def select_expert(prob_x, expert_idx): + input_x = self.query_token_candidates[expert_idx] # [1, 32, 768] + # input_x = input_x * prob_x + input_x = input_x.expand(prob_x.shape[0], -1, -1) + return input_x + + out = [select_expert(prob_gate[i], i) for i in range(self.num_experts)] + out = torch.vstack(out) + out = out[order.argsort(0)] # restore original order + + return out, balance_loss, gate_load, gate + + def _forward_gate_sentence(self, x, attention_mask): + ### TODO: refer MOEBERT + return None + + def _forward_sentence_single_expert(self, x, attention_mask): + ### TODO: refer MOEBERT + return None + diff --git a/minigpt4/models/moe/route_moe_layer.py b/minigpt4/models/moe/route_moe_layer.py new file mode 100644 index 0000000..39ecb18 --- /dev/null +++ b/minigpt4/models/moe/route_moe_layer.py @@ -0,0 +1,265 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + +class RouteMoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.num_beams = num_beams + self.hidden_size = hidden_size + self.layer_judge = layer_judge + self.weight_type = weight_type + + self.route_method = route_method + if self.route_method == "pre-route": + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif self.route_method in ["post-route", "post-route-dp"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + self.gate = gate + # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + + def forward_gate(self, x): + """ + x : torch.Size([bz*num_beams, 32, 768]) or torch.Size([bz, 32, 768]) + prob_gate : torch.Size([bz*num_beams, num_experts]) or torch.Size([bz, num_experts]) + """ + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*num_beams, 32, 768]) + # x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beams, 768]) + x_average = torch.mean(x_masked, dim=1) # torch.Size([bz*num_beams, 768]) + logits_gate = self.gate(x_average) # torch.Size([bz*num_beams, num_experts]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beams, num_experts]) + return prob_gate + + def dp_search(self, current_scores_log, beam_scores, expert_route, batch_size): + if self.layer_judge=='first' and self.route_method in ['post-route-dp']: + # current_scores_log torch.Size([bz, num_experts]) + assert beam_scores==None and expert_route==None + current_scores = torch.exp(current_scores_log) + topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_beams * batch_size)) + + else: + batch_size = int(batch_size // self.num_beams) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + + next_scores_raw, next_experts_raw = torch.topk(next_scores_exp, 1, dim=1, largest=True, sorted=True) + next_scores = next_scores_raw.view(batch_size, self.num_beams) + next_experts = next_experts_raw.view(batch_size, self.num_beams) + # next_scores, next_experts = torch.topk(current_scores_log, 1, dim=1, largest=True, sorted=True) # equal 等价 + # next_scores torch.Size([bz * num_beams, 1]) + # next_tokens torch.Size([bz * num_beams, 1]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + expert_id = next_experts[batch_idx] + expert_score = next_scores[batch_idx] + values, index = torch.topk(expert_score, self.num_beams, dim=0, largest=True, sorted=True) + for i in range(self.num_beams): + beam_id = index[i].item() + ex_id = expert_id[beam_id].item() + effective_beam_id = batch_idx*self.num_beams + beam_id + next_sent_beam.append((values[i], ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + return beam_scores, expert_route, beam_idx + + def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size): + if self.layer_judge=='first' and self.route_method in ['pre-route', 'post-route']: + # current_scores_log torch.Size([bz, num_experts]) + assert beam_scores==None and expert_route==None + current_scores = torch.exp(current_scores_log) + topk_values, gate = torch.topk(current_scores, self.num_beams, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + beam_scores = topk_values.view(self.num_beams * batch_size) # torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_beams * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_beams * batch_size)) + + else: + batch_size = int(batch_size // self.num_beams) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + + next_scores_raw1 = next_scores_exp.view( + batch_size, self.num_beams * self.num_experts + ) # torch.Size([bz, num_beams*num_experts]) + + next_scores, next_experts = torch.topk(next_scores_raw1, self.num_beams, dim=1, largest=True, sorted=True) + # next_scores torch.Size([bz, num_beams]) + # next_tokens torch.Size([bz, num_beams]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + for rank, (expert_id, expert_score) in enumerate( + zip(next_experts[batch_idx], next_scores[batch_idx]) + ): + expert_id = expert_id.item() + beam_id = expert_id // self.num_experts + ex_id = expert_id % self.num_experts + effective_beam_id = batch_idx*self.num_beams + beam_id + + next_sent_beam.append((expert_score, ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + return beam_scores, expert_route, beam_idx + + def forward_expert_ffn(self, x, expert_select, current_scores): + """ + x_repeat : [bz*num_beams, 32,768] + expert_select : [bz*num_beams] + current_scores : [bz*num_beams, num_experts] / [bz, num_experts] + """ + # add_1228 l2_normalization + # normalized_tensor = torch.nn.functional.normalize(current_scores, p=2, dim=0) # L2 Normalization torch.Size([bz, topk]) + # tmp_prob = normalized_tensor.unsqueeze(-1).unsqueeze(-1) + # import pdb;pdb.set_trace() + outputs = list() + for i in range(self.num_experts): + output_x = self.experts[i].forward(x) + outputs.append(output_x.unsqueeze(1)) + candidate_output = torch.cat(outputs, dim=1) + expert_select_matrix = F.one_hot(expert_select, self.num_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores * expert_select_matrix + candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + output = torch.sum(candidate_output, dim=1) + # import pdb;pdb.set_trace() + return output # torch.Size([bz*num_beams, 32, 768]) + + def forward_pre_route(self, x, beam_scores, expert_route, use_log=True): + + current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams] + + importance_loss = self._importance_auxiliary_loss(current_scores) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + # import pdb;pdb.set_trace() + batch_size, num_tokens = x.shape[0], x.shape[1] + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + current_expert_select = expert_route[:,-1] + + if self.layer_judge=='first': # expand first dim to batch_size * num_beams + replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size) + x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + input_x = x[beam_idx] + candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768] + # import pdb;pdb.set_trace() + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward_post_route(self, x, beam_scores, expert_route, use_log=True): + + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + def forward_expert(input_x, expert_idx): + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = forward_expert(x_masked, expert_idx) + # output_x_aver = output_x.sum(1) / attention_mask.unsqueeze(-1).sum(1) # torch.Size([bz*num_beam, 768]) + output_x_aver = torch.mean(output_x, dim=1) + # gate_score = self.gates[expert_idx](output_x_aver) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + outputs.append(output_x.unsqueeze(0)) + + candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert, bz*num_beam, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*num_beam, num_expert]) + current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*num_beam, num_experts]) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + + # importance loss + importance_loss = self._importance_auxiliary_loss(current_scores) + + batch_size, num_tokens = x.shape[0], x.shape[1] # bz*num_beam + + if self.route_method == 'post-route': + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + elif self.route_method == 'post-route-dp': + beam_scores, expert_route, beam_idx = self.dp_search(current_scores_log, beam_scores, expert_route, batch_size) + + # beam_scores torch.Size([bz*num_beam]) + # expert_route torch.Size([bz*num_beam, layer_n]) + current_select_expert = expert_route[:,-1] + # current_select_expert torch.Size([bz*num_beam, 1]) + + if self.layer_judge == 'first': + replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_experts, batch_size, self.num_beams, num_tokens, self.hidden_size) + candidate_output_raw = replicated_tensor.contiguous().view(self.num_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768]) + expert_select_matrix = F.one_hot(current_select_expert, self.num_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores[beam_idx] * expert_select_matrix + output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + final_output = torch.sum(output, dim=1) + + return final_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True): + """ + if first_layer: x [bz, 32, 768] + else: x [bz*num_beams, 32, 768] + + """ + if self.route_method == 'pre-route': + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True) + elif self.route_method in ['post-route', 'post-route-dp']: + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route(x, beam_scores, expert_route, use_log=True) + else: + assert("route method should in pre-route, post-route, post-route-dp") + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + + diff --git a/minigpt4/models/moe/test_moe_layer.py b/minigpt4/models/moe/test_moe_layer.py new file mode 100644 index 0000000..5253340 --- /dev/null +++ b/minigpt4/models/moe/test_moe_layer.py @@ -0,0 +1,294 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + +class MoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, route_method, topk=1, use_balance_loss=True, weight_type='raw_prob, topk(softmax)'): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.route_method = route_method + self.topk = topk + self.use_balance_loss = use_balance_loss + self.weight_type = weight_type + + if route_method in ["gate-token", "gate-sentence"]: + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif route_method in ["gate-sentence-post"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + # self.gates = nn.ModuleList([copy.deepcopy(gate) for i in range(num_experts)]) + self.gate = gate + else: + raise KeyError("Routing method not supported.") + + def _balancing_loss(self, prob_gate, num_tokens): + # From MOEBERT + # compute the load balancing loss + # prob_gate,是 [bz, num_expert],每个样本被分配给每个expert的概率 + # 等价于 VMOE 中 _gshard_auxiliary_loss + P = prob_gate.mean(0) # torch.Size([num_expert]) 每个expert被分配到样本的平均概率 + temp = num_tokens.float() + f = temp / temp.sum(0, keepdim=True) # 每个expert被分配的sample比例 + balance_loss = self.num_experts * torch.sum(P * f) + return balance_loss + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + def _forward_gate_token(self, x): + bsz, seq_len, dim = x.size() + + x = x.view(-1, dim) + logits_gate = self.gate(x) + prob_gate = F.softmax(logits_gate, dim=-1) + gate = torch.argmax(prob_gate, dim=-1) + + order = gate.argsort(0) + num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0) + gate_load = num_tokens.clone() + x = x[order] # reorder according to expert number + x = x.split(num_tokens.tolist(), dim=0) # a list of length self.num_experts + + # compute the load balancing loss + P = prob_gate.mean(0) + temp = num_tokens.float() + f = temp / temp.sum(0, keepdim=True) + balance_loss = self.num_experts * torch.sum(P * f) + + prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1)) + prob_gate = prob_gate[order] + prob_gate = prob_gate.split(num_tokens.tolist(), dim=0) + + def forward_expert(input_x, prob_x, expert_idx): + input_x = self.experts[expert_idx].forward(input_x) + input_x = input_x * prob_x + return input_x + + x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)] + x = torch.vstack(x) + x = x[order.argsort(0)] # restore original order + x = x.view(bsz, seq_len, dim) + + return x, balance_loss, gate_load + + def _forward_gate_sentence_post(self, x, attention_mask): + """ + x: query_attention_output; torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + bz = 4 + x = torch.randn(bz,32,768) + attention_mask = torch.ones([bz, 32]) + + """ + # Prepare Input x + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + + # FeedForward(x) & Forward Gate + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_experts): + output_x = self.experts[expert_idx].forward(x_masked) + outputs.append(output_x.unsqueeze(0)) + + output_x_aver = torch.mean(output_x, dim=1) + # gate_acore = self.gates[expert_idx](output_x_aver) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + candidate_output = torch.cat(outputs) # torch.Size([num_expert, bz, 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz, num_expert]) + + # Probabilities for each sample of what expert it should be sent to. + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz, num_experts]) + if 'softmax(topk)' in self.weight_type: + prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1) + select_prob_gate = F.softmax(prob_gate1, dim=-1) + else: + select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + + # Calculate Balancing Loss + if self.use_balance_loss: + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + # Calculate Importance Loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + # Reshap Prob_gate & Gate + # expert_mask: [batch_size, topk, num_experts] + # expert_gate: [batch_size, topk, num_experts] + # combine_tensor: [batch_size, num_experts] + expert_mask = F.one_hot(gate, self.num_experts) + expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask + combine_tensor = torch.sum(expert_gate, dim=1) + # combine_tensor = torch.zeros_like(prob_gate) + # combine_tensor.scatter_(1, gate, select_prob_gate) # 等价操作,但可能不可导 + + candidate_output_ad = torch.permute(candidate_output, (1, 0, 2, 3)) # torch.Size([bz, num_expert, 32, 768]) + results = candidate_output_ad * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([bz, num_expert, 32, 768]) + outputs = torch.sum(results, dim=1) # torch.Size([bz, 32, 768]) + import pdb;pdb.set_trace() + + return outputs, (balance_loss+importance_loss), combine_tensor + + def pre_router(self, x, attention_mask): + # Prepare input x + attention_mask = torch.ones(attention_mask.shape[0], attention_mask.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz, 32, 768]) + x_average = torch.mean(x_masked, dim=1) # torch.Size([bz, 768]) + + # Forward Gate + # logits_gate: [bz, num_experts] + logits_gate = self.gate(x_average) + + # Probabilities for each sample of what expert it should be sent to. + # prob_gate: [bz, num_experts] + prob_gate = F.softmax(logits_gate, dim=-1) + + if 'softmax(topk)' in self.weight_type: + prob_gate1, gate = torch.topk(logits_gate, self.topk, dim=1) + select_prob_gate = F.softmax(prob_gate1, dim=-1) + else: + # topk(softmax) + # Get Top-K experts for each sample + # gate: [bz, topk] + # select_prob_gate: [bz, topk] + select_prob_gate, gate = torch.topk(prob_gate, self.topk, dim=1) + + # Reshap Prob_gate & Gate + # expert_mask: [batch_size, topk, num_experts] + # expert_gate: [batch_size, topk, num_experts] + # combine_tensor: [batch_size, num_experts] + expert_mask = F.one_hot(gate, self.num_experts) + expert_gate = select_prob_gate.unsqueeze(-1) * expert_mask + combine_tensor = torch.sum(expert_gate, dim=1) + + # Calculate Balancing Loss + if self.use_balance_loss: + num_sentences = F.one_hot(gate, self.num_experts).sum(1).gt(0).sum(0) # 每个expert被分配的样本数 torch.Size([num_expert]) + balance_loss = self._balancing_loss(prob_gate, num_sentences) + else: + balance_loss = 0.0 + + # Calculate Importance Loss + importance_loss = self._importance_auxiliary_loss(prob_gate) + + import pdb; pdb.set_trace() + + return expert_mask, combine_tensor, balance_loss, importance_loss + + def _forward_gate_sentence(self, x, attention_mask): + """ + x: query_attention_output , torch.Size([bz, 32, 768]) + attention_mask: torch.ones([bz, 32]) + + ### Notice: + the raw version of expert_attention_mask is the extended_attention_mask, + which will be add to attention_score directly + the values of extended_attention_mask are -0.0 or -10000 + it should be adjust to 1/0 version to be processed by experts + """ + # Forward Router + expert_mask, combine_tensor, balance_loss, importance_loss = self.pre_router(x, attention_mask) + + # Forward Expert FFN + result = [] + for expert_idx in range(self.num_experts): + output_x = self.experts[expert_idx].forward(x) + result.append(output_x.unsqueeze(0)) + expert_output = torch.cat(result).permute(1,0,2,3) # torch.Size([batch_size, num_expert, num_tokens, hidden_states]) + + # multiply outputs of experts by the routing probability + expert_outputs_combined = expert_output * combine_tensor.unsqueeze(-1).unsqueeze(-1) # torch.Size([batch_size, num_expert, num_tokens, hidden_states]) + outputs = torch.sum(expert_outputs_combined, dim=1) # torch.Size([batch_size, num_tokens, hidden_states]) + + import pdb; pdb.set_trace() + + return outputs, (balance_loss+importance_loss), combine_tensor + + + def forward(self, x, attention_mask): + if self.route_method == "gate-token": + x, balance_loss, gate_load = self._forward_gate_token(x) + elif self.route_method == "gate-sentence": + x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask) + elif self.route_method == "gate-sentence-post": + x, balance_loss, gate_load = self._forward_gate_sentence_post(x, attention_mask) + else: + raise KeyError("Routing method not supported.") + # import pdb; pdb.set_trace() + return x, balance_loss, gate_load + +if __name__ == '__main__': + + import sys + sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + from minigpt4.models.QformerRouteMoE import BertConfig + from minigpt4.models.QformerRouteMoE import FeedForward + from minigpt4.models.moe.utils import ( + moe_layer_judge, + ) + + vision_width = 1408 + cross_attention_freq = 2 + num_query_token = 32 + # init_QformerMoE + config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") + config.encoder_width = vision_width + # insert cross-attention layer every other block + config.add_cross_attention = True + config.cross_attention_freq = cross_attention_freq + config.query_length = num_query_token + config.moebert_expert_num = 3 + config.moebert_num_beams = 2 + config.moebert_route_method = 'gate-sentence-post' + config.moe_topk = 1 + config.use_balance_loss = False + # config.moe_weight_type = 'raw_prob, softmax(topk)' + config.moe_weight_type = 'raw_prob, topk(softmax)' + + batch_size = 4 + x2 = torch.randn(batch_size, 32, 768) + beam_scores, expert_route = None, None + + for layer_num in [6, 8, 10]: + layer_judge = moe_layer_judge(layer_num) + ffn = FeedForward(config) + gate = nn.Linear(768, config.moebert_expert_num, bias=False).float() + + experts_moe = MoELayer( + hidden_size=config.hidden_size, + expert=ffn, + num_experts=config.moebert_expert_num, + route_method=config.moebert_route_method, + topk=config.moe_topk, + use_balance_loss=config.use_balance_loss, + weight_type=config.moe_weight_type, + ) + attn_mask = torch.ones([batch_size, 32]) + layer_output = experts_moe(x2, attn_mask) + hidden_states3, aux_loss, combine_tensor = layer_output + + print(combine_tensor) + print(aux_loss) + x2 = hidden_states3 + + print("------------------------------------") + import pdb; pdb.set_trace() \ No newline at end of file diff --git a/minigpt4/models/moe/uniroute_moe_layer.py b/minigpt4/models/moe/uniroute_moe_layer.py new file mode 100644 index 0000000..5cd2069 --- /dev/null +++ b/minigpt4/models/moe/uniroute_moe_layer.py @@ -0,0 +1,234 @@ +import copy +import pickle +import torch +import torch.nn as nn +import torch.nn.functional as F + +class UniRouteMoELayer(nn.Module): + def __init__(self, hidden_size, expert, num_experts, num_beams=2, layer_judge=None, route_method="pre-route", weight_type="ffn_prob"): + # remove hash list + nn.Module.__init__(self) + self.num_experts = num_experts #(1+other) + self.num_route_experts = num_experts-1 + self.num_beams = num_beams + self.num_route_beam = num_beams-1 + + self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)]) + self.hidden_size = hidden_size + self.layer_judge = layer_judge + self.weight_type = weight_type + + self.route_method = route_method + if self.route_method == "pre-route-uni": + self.gate = nn.Linear(hidden_size, num_experts, bias=False).float() + elif self.route_method in ["post-route-uni"]: + gate = nn.Linear(hidden_size, 1, bias=False).float() + self.gate = gate + + def _importance_auxiliary_loss(self, prob_gate): + # From VMOE + # _importance_auxiliary_loss + axis = tuple(range(prob_gate.ndim - 1)) # All except last. + importance_per_expert = torch.sum(prob_gate, dim=axis) + std_importance_per_expert = torch.std(importance_per_expert) + mean_importance_per_expert = torch.mean(importance_per_expert) + # Compute coefficient of variation (i.e. std/mean) squared. + return (std_importance_per_expert / mean_importance_per_expert)**2 + + + def beam_search(self, current_scores_log, beam_scores, expert_route, batch_size): + if self.layer_judge=='first' and self.route_method in ['pre-route-uni', 'post-route-uni']: + # current_scores_log torch.Size([bz, num_experts-1]) + assert beam_scores==None and expert_route==None + current_scores = torch.exp(current_scores_log) + topk_values, gate = torch.topk(current_scores, self.num_route_beam, dim=1) # gate, 每个样本被分配的expert: torch.Size([bz, topk]) + beam_scores = topk_values.view(self.num_route_beam * batch_size) # torch.Size([bz * num_beams]) + expert_route = gate.view(self.num_route_beam * batch_size).unsqueeze(1) # torch.Size([bz * num_beams,1]) + beam_idx = torch.tensor(range(self.num_route_beam * batch_size)) + + else: + batch_size = int(batch_size // self.num_route_beam) + next_scores_raw = current_scores_log + torch.log(beam_scores).unsqueeze(1) # torch.Size([4*3, 5]) # 取log 之后,可以直接相加概率 + next_scores_exp = torch.exp(next_scores_raw) + + next_scores_raw1 = next_scores_exp.view( + batch_size, self.num_route_beam * self.num_route_experts + ) # torch.Size([bz, num_route_beam*num_route_experts]) + + next_scores, next_experts = torch.topk(next_scores_raw1, self.num_route_beam, dim=1, largest=True, sorted=True) + # next_tokens torch.Size([bz, num_route_beam]) + + next_batch_beam = list() + for batch_idx in range(batch_size): + next_sent_beam = list() + for rank, (expert_id, expert_score) in enumerate( + zip(next_experts[batch_idx], next_scores[batch_idx]) + ): + expert_id = expert_id.item() + beam_id = expert_id // self.num_route_experts + ex_id = expert_id % self.num_route_experts + effective_beam_id = batch_idx*self.num_route_beam + beam_id + + next_sent_beam.append((expert_score, ex_id, effective_beam_id)) + next_batch_beam.extend(next_sent_beam) + + beam_scores = beam_scores.new([x[0] for x in next_batch_beam]) + beam_experts = expert_route[:,-1].new([x[1] for x in next_batch_beam]) + beam_idx = expert_route[:,-1].new([x[2] for x in next_batch_beam]) + pre_route = expert_route[beam_idx,:] + expert_route = torch.cat([pre_route, beam_experts.unsqueeze(1)], dim=-1) + + return beam_scores, expert_route, beam_idx + + + def forward_gate(self, x): + """ + TODO: Pre forward gate + x : torch.Size([bz*(num_beams-1), 32, 768]) or torch.Size([bz, 32, 768]) + prob_gate : torch.Size([bz*(num_beams-1), num_experts]) or torch.Size([bz, num_experts]) + """ + attention_mask = torch.ones(x.shape[0], x.shape[1]).to(x.device) + x_masked = x * attention_mask.unsqueeze(-1) # torch.Size([bz*(num_beams-1), 32, 768]) + x_average = torch.mean(x_masked, dim=1) # torch.Size([bz*(num_beams-1), 768]) + logits_gate = self.gate(x_average) # torch.Size([bz*(num_beams-1), num_experts]) + prob_gate = F.softmax(logits_gate, dim=-1) # torch.Size([bz*(num_beams-1), num_experts]) + return prob_gate + + def forward_expert_ffn(self, x, expert_select, current_scores): + """ + x_repeat : [bz*num_beams, 32,768] + expert_select : [bz*num_beams] + current_scores : [bz*num_beams, num_experts] / [bz, num_experts] + """ + # import pdb;pdb.set_trace() + outputs = list() + for i in range(self.num_experts-1): + output_x = self.experts[i].forward(x) + outputs.append(output_x.unsqueeze(1)) + candidate_output = torch.cat(outputs, dim=1) + expert_select_matrix = F.one_hot(expert_select, self.num_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores * expert_select_matrix + candidate_output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + candidate_output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + output = torch.sum(candidate_output, dim=1) + # import pdb;pdb.set_trace() + return output # torch.Size([bz*(num_beams-1), 32, 768]) + + def forward_pre_route(self, x, beam_scores, expert_route, use_log=True): + + current_scores = self.forward_gate(x) # [bz, num_beams] / [bz*num_beams, num_beams] + + importance_loss = self._importance_auxiliary_loss(current_scores) + + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 + else: + current_scores_log = current_scores + # import pdb;pdb.set_trace() + batch_size, num_tokens = x.shape[0], x.shape[1] + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, batch_size) + current_expert_select = expert_route[:,-1] + + if self.layer_judge=='first': # expand first dim to batch_size * num_beams + replicated_tensor = x.unsqueeze(1).expand(batch_size, self.num_beams, num_tokens, self.hidden_size) + x = replicated_tensor.contiguous().view(-1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_beams, self.num_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_experts) # [bz*num_beams, num_experts] + + input_x = x[beam_idx] + candidate_output = self.forward_expert_ffn(input_x, current_expert_select, current_scores) # [bz*num_beams, 32,768] + # import pdb;pdb.set_trace() + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward_post_route_uni(self, x, beam_scores, expert_route, use_log=True): + + if beam_scores == None: + batch_size = x.shape[0] + x_masked, x_uniexpert = x, x # torch.Size([bz, 32, 768]) + elif x.shape[0]/self.num_beams == beam_scores.shape[0]/self.num_route_beam: + batch_size = int(x.shape[0]/self.num_beams) + select_universal = [i*self.num_beams+self.num_route_beam for i in range(batch_size)] + select_expert = [ x for x in range(batch_size*self.num_beams) if x not in select_universal] + x_masked, x_uniexpert = x[select_expert],x[select_universal] + num_tokens = x.shape[1] + + def forward_expert(input_x, expert_idx): + output_x = self.experts[expert_idx].forward(input_x) + return output_x + + #################### + ### route expert + #################### + outputs = list() + logits_gate_lst = list() + for expert_idx in range(self.num_route_experts): # num_expert-1 + output_x = forward_expert(x_masked, expert_idx) + output_x_aver = torch.mean(output_x, dim=1) + gate_score = self.gate(output_x_aver) + logits_gate_lst.append(gate_score) + outputs.append(output_x.unsqueeze(0)) + + candidate_output_raw = torch.cat(outputs) # torch.Size([num_expert-1, bz*(num_beam-1), 32, 768]) + logits_gate = torch.cat(logits_gate_lst,dim=1)# torch.Size([bz*(num_beam-1), num_expert-1]) + current_scores = F.softmax(logits_gate, dim=-1) # torch.Size([bz*(num_beam-1), num_expert-1]) + if use_log: + current_scores_log = torch.log(current_scores) # 取log之后可以直接相加 torch.Size([bz*(num_beam-1), num_expert-1]) + else: + current_scores_log = current_scores + + importance_loss = self._importance_auxiliary_loss(current_scores) + beam_scores, expert_route, beam_idx = self.beam_search(current_scores_log, beam_scores, expert_route, current_scores_log.shape[0]) + # beam_scores torch.Size([bz*(num_beam-1)]), expert_route torch.Size([bz*(num_beam-1), layer_n]) + current_select_expert = expert_route[:,-1] # torch.Size([bz*(num_beam-1)]) + + if self.layer_judge == 'first': + replicated_tensor = candidate_output_raw.unsqueeze(2).expand(self.num_route_experts, batch_size, self.num_route_beam, num_tokens, self.hidden_size) + candidate_output_raw = replicated_tensor.contiguous().view(self.num_route_experts, -1, num_tokens, self.hidden_size) # [bz*num_beams, 32,768] + current_scores_t = current_scores.unsqueeze(1).expand(batch_size, self.num_route_beam, self.num_route_experts) + current_scores = current_scores_t.contiguous().view(-1, self.num_route_experts) # [bz*(num_beams-1), num_experts-1] + + candidate_output = candidate_output_raw.permute(1, 0, 2, 3)[beam_idx] # torch.Size([8, 2, 32, 768]) + expert_select_matrix = F.one_hot(current_select_expert, self.num_route_experts) + if self.weight_type == 'ffn_prob': + tmp_prob = current_scores[beam_idx] * expert_select_matrix + output = candidate_output * tmp_prob.unsqueeze(-1).unsqueeze(-1) + else: + output = candidate_output * expert_select_matrix.unsqueeze(-1).unsqueeze(-1) + experts_output = torch.sum(output, dim=1) # [bz*num_beams-1, 32, 768] + + # import pdb; pdb.set_trace() + + #################### + ### universal expert + #################### + uni_output = forward_expert(x_uniexpert, self.num_experts-1) # [bz, 32, 768] + + #################### + ### Combine expert + #################### + output = list() + for i in range(batch_size): + expert_tmp = experts_output[i*self.num_route_beam: i*self.num_route_beam+self.num_route_beam,:,:] + combine_tmp = torch.cat((expert_tmp, uni_output[i].unsqueeze(0))) + output.append(combine_tmp) + final_output = torch.cat(output) # [bz*num_beam, 32 ,768] + + # import pdb; pdb.set_trace() + + return final_output, beam_scores, expert_route, beam_idx, importance_loss + + def forward(self, x, attention_mask, beam_scores, expert_route, use_log=True): + """ + if first_layer: x [bz, 32, 768] + else: x [bz*num_beams, 32, 768] + """ + if self.route_method == 'pre-route-uni': + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_pre_route(x, beam_scores, expert_route, use_log=True) + elif self.route_method in ['post-route-uni']: + candidate_output, beam_scores, expert_route, beam_idx, importance_loss = self.forward_post_route_uni(x, beam_scores, expert_route, use_log=True) + + return candidate_output, beam_scores, expert_route, beam_idx, importance_loss + + diff --git a/minigpt4/models/moe/utils.py b/minigpt4/models/moe/utils.py new file mode 100644 index 0000000..1489f60 --- /dev/null +++ b/minigpt4/models/moe/utils.py @@ -0,0 +1,171 @@ +import numpy as np +import pickle +import torch +import torch.nn as nn + +from dataclasses import dataclass +from torch import Tensor +from transformers.activations import ACT2FN +from transformers.file_utils import ModelOutput +from typing import Optional, Tuple, List + + +def use_experts(layer_idx): + # if layer_idx % 2 == 0: + # use moe_ffn after cross_attns + if int(layer_idx) in [6,7,8,9,10,11]: + # layer 6/8/10 + return True + else: + return False + +def use_experts_route(layer_idx): + # if layer_idx % 2 == 0: + # use moe_ffn after cross_attns + # if int(layer_idx) in [0,2,4,6,8,10]: + if int(layer_idx) in [6,7,8,9,10,11]: + return True + else: + return False + +def moe_layer_judge(layer_idx): + if layer_idx == 6: + return 'first' + elif layer_idx in [7,8,9,10]: + return 'mid' + elif layer_idx == 11: + return 'last' + else: + return None + + # if layer_idx == 0: + # return 'first' + # elif layer_idx in [2,4,6,8]: + # return 'mid' + # elif layer_idx == 10: + # return 'last' + # else: + # return None + +def process_ffn(model): + if model.config.model_type == "bert": + inner_model = model.bert + else: + raise ValueError("Model type not recognized.") + + for i in range(model.config.num_hidden_layers): + model_layer = inner_model.encoder.layer[i] + if model_layer.use_experts: + model_layer.importance_processor.load_experts(model_layer) + + +class ImportanceProcessor: + def __init__(self, config, layer_idx, num_local_experts, local_group_rank): + self.num_experts = config.moebert_expert_num # total number of experts + self.num_local_experts = num_local_experts # number of experts on this device + self.local_group_rank = local_group_rank # rank in the current process group + self.intermediate_size = config.moebert_expert_dim # FFN hidden dimension + self.share_importance = config.moebert_share_importance # number of shared FFN dimension + + importance = ImportanceProcessor.load_importance_single(config.moebert_load_importance)[layer_idx, :] + self.importance = self._split_importance(importance) + + self.is_moe = False # safety check + + @staticmethod + def load_importance_single(importance_files): + with open(importance_files, "rb") as file: + data = pickle.load(file) + data = data["idx"] + return np.array(data) + + def _split_importance(self, arr): + result = [] + top_importance = arr[:self.share_importance] + remain = arr[self.share_importance:] + all_experts_remain = [] + for i in range(self.num_experts): + all_experts_remain.append(remain[i::self.num_experts]) + all_experts_remain = np.array(all_experts_remain) + + for i in range(self.num_local_experts): + temp = all_experts_remain[self.num_local_experts * self.local_group_rank + i] + temp = np.concatenate((top_importance, temp)) + temp = temp[:self.intermediate_size] + result.append(temp.copy()) + result = np.array(result) + return result + + def load_experts(self, model_layer): + expert_list = model_layer.experts.experts + fc1_weight_data = model_layer.intermediate.dense.weight.data + fc1_bias_data = model_layer.intermediate.dense.bias.data + fc2_weight_data = model_layer.output.dense.weight.data + fc2_bias_data = model_layer.output.dense.bias.data + layernorm_weight_data = model_layer.output.LayerNorm.weight.data + layernorm_bias_data = model_layer.output.LayerNorm.bias.data + for i in range(self.num_local_experts): + idx = self.importance[i] + expert_list[i].fc1.weight.data = fc1_weight_data[idx, :].clone() + expert_list[i].fc1.bias.data = fc1_bias_data[idx].clone() + expert_list[i].fc2.weight.data = fc2_weight_data[:, idx].clone() + expert_list[i].fc2.bias.data = fc2_bias_data.clone() + expert_list[i].LayerNorm.weight.data = layernorm_weight_data.clone() + expert_list[i].LayerNorm.bias.data = layernorm_bias_data.clone() + del model_layer.intermediate + del model_layer.output + self.is_moe = True + + +class FeedForward(nn.Module): + def __init__(self, config, intermediate_size, dropout): + nn.Module.__init__(self) + + # first layer + self.fc1 = nn.Linear(config.hidden_size, intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + # second layer + self.fc2 = nn.Linear(intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(dropout) + + def forward(self, hidden_states: Tensor): + input_tensor = hidden_states + hidden_states = self.fc1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +@dataclass +class MoEModelOutput(ModelOutput): + last_hidden_state: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + gate_loss: torch.FloatTensor = None + gate_loads: Optional[Tuple[torch.FloatTensor]] = None + beam_scores: Optional[Tuple[torch.FloatTensor]] = None + expert_route: Optional[Tuple[torch.FloatTensor]] = None + + + +@dataclass +class MoEModelOutputWithPooling(ModelOutput): + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + gate_loss: torch.FloatTensor = None + gate_loads: Optional[Tuple[torch.FloatTensor]] = None + beam_scores: Optional[Tuple[torch.FloatTensor]] = None + expert_route: Optional[Tuple[torch.FloatTensor]] = None diff --git a/minigpt4/projects/minigpt/eval/minigpt4_eval.yaml b/minigpt4/projects/minigpt/eval/minigpt4_eval.yaml new file mode 100644 index 0000000..365c42f --- /dev/null +++ b/minigpt4/projects/minigpt/eval/minigpt4_eval.yaml @@ -0,0 +1,22 @@ +model: + arch: minigpt4 + model_type: pretrain_vicuna0 + max_txt_len: 160 + end_sym: "###" + low_resource: True + prompt_template: '###Human: {} ###Assistant: ' + ckpt: '/mnt/pfs-guan-ssai/nlu/wanghanzi/models/minigpt4/prerained_minigpt4_7b.pth' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/minigpt4/projects/minigpt/eval/minigpt4_llama2_eval.yaml b/minigpt4/projects/minigpt/eval/minigpt4_llama2_eval.yaml new file mode 100644 index 0000000..93efab1 --- /dev/null +++ b/minigpt4/projects/minigpt/eval/minigpt4_llama2_eval.yaml @@ -0,0 +1,22 @@ +model: + arch: minigpt4 + model_type: pretrain_llama2 + max_txt_len: 160 + end_sym: "" + low_resource: True + prompt_template: '[INST] {} [/INST] ' + ckpt: 'please set this value to the path of pretrained checkpoint' + + +datasets: + cc_sbu_align: + vis_processor: + train: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +run: + task: image_text_pretrain diff --git a/eval_configs/minigptv2_eval.yaml b/minigpt4/projects/minigpt/eval/minigptv2_eval.yaml similarity index 82% rename from eval_configs/minigptv2_eval.yaml rename to minigpt4/projects/minigpt/eval/minigptv2_eval.yaml index 0479f2a..00f3604 100644 --- a/eval_configs/minigptv2_eval.yaml +++ b/minigpt4/projects/minigpt/eval/minigptv2_eval.yaml @@ -5,7 +5,7 @@ model: end_sym: "" low_resource: True prompt_template: '[INST] {} [/INST]' - ckpt: 'please set this value to the path of pretrained checkpoint' + ckpt: '/mnt/pfs-guan-ssai/nlu/wanghanzi/models/minigptv2/minigptv2_checkpoint.pth' lora_r: 64 lora_alpha: 16 diff --git a/train_configs/minigpt4_llama2_stage1_pretrain.yaml b/minigpt4/projects/minigpt/train/minigpt4_llama2_stage1_pretrain.yaml similarity index 100% rename from train_configs/minigpt4_llama2_stage1_pretrain.yaml rename to minigpt4/projects/minigpt/train/minigpt4_llama2_stage1_pretrain.yaml diff --git a/train_configs/minigpt4_llama2_stage2_finetune.yaml b/minigpt4/projects/minigpt/train/minigpt4_llama2_stage2_finetune.yaml similarity index 100% rename from train_configs/minigpt4_llama2_stage2_finetune.yaml rename to minigpt4/projects/minigpt/train/minigpt4_llama2_stage2_finetune.yaml diff --git a/train_configs/minigpt4_stage1_pretrain.yaml b/minigpt4/projects/minigpt/train/minigpt4_stage1_pretrain.yaml similarity index 100% rename from train_configs/minigpt4_stage1_pretrain.yaml rename to minigpt4/projects/minigpt/train/minigpt4_stage1_pretrain.yaml diff --git a/train_configs/minigpt4_stage2_finetune.yaml b/minigpt4/projects/minigpt/train/minigpt4_stage2_finetune.yaml similarity index 100% rename from train_configs/minigpt4_stage2_finetune.yaml rename to minigpt4/projects/minigpt/train/minigpt4_stage2_finetune.yaml diff --git a/minigpt4/projects/minigpt/train/minigptv2_finetune.yaml b/minigpt4/projects/minigpt/train/minigptv2_finetune.yaml new file mode 100644 index 0000000..85dd549 --- /dev/null +++ b/minigpt4/projects/minigpt/train/minigptv2_finetune.yaml @@ -0,0 +1,293 @@ +model: + arch: minigpt_v2 + model_type: pretrain + max_txt_len: 1024 + image_size: 448 + end_sym: "" + llama_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/llama_2_7b_chat" + ckpt: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/minigptv2/checkpoint_stage2.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + +datasets: + # multitask_conversation: # in-house data 12171 + # batch_size: 2 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 50 + + llava_conversation: # 56681 + batch_size: 2 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 30 + + unnatural_instruction: # pure text 65852 + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 10 + + + # refvg: # [refer] return the location + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 40 + + llava_detail: # 23240 + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 20 + + llava_reason: # 76643 + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 80 + + + # flickr_grounded_caption: # [grounding] : TODO + # batch_size: 2 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 80 + + # flickr_CaptionToPhrase: # [detection] + # batch_size: 2 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 80 + + # flickr_ObjectToPhrase: # [detection] + # batch_size: 2 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 80 + + coco_caption: # 414113 train + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 10 + + + textcaps_caption: # 109765 train + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 30 + + # refcoco: + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 25 + + # refcocop: + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 25 + + # refcocog: + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 25 + + # invrefcoco: + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 10 + + # invrefcocop: + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 10 + + # invrefcocog: + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 10 + + + coco_vqa: # 658104 + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 15 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 8 + + aok_vqa: # (17056, 1145, 6702) + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 12 + + gqa: # (943000, 12578, 12578) + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 50 + + ocrvqa: # 207572 + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 30 + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-5 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 1 + num_workers: 6 + warmup_steps: 1000 + iters_per_epoch: 1000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/minigpt/v2/vqa_pretrain_3B_llama2_7b_chat_stage3_train_linear_lora_test_1030" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + wandb_log: True + job_name: minigptv2_finetune \ No newline at end of file diff --git a/minigpt4/projects/minigpt/train/minigptv2_finetune_vqa.yaml b/minigpt4/projects/minigpt/train/minigptv2_finetune_vqa.yaml new file mode 100644 index 0000000..3b3fe74 --- /dev/null +++ b/minigpt4/projects/minigpt/train/minigptv2_finetune_vqa.yaml @@ -0,0 +1,171 @@ +model: + arch: minigpt_v2 + model_type: pretrain + max_txt_len: 1024 + image_size: 448 + end_sym: "" + llama_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/llama_2_7b_chat" + ckpt: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/minigptv2/checkpoint_stage2.pth" + use_grad_checkpoint: True + chat_template: True + lora_r: 64 + lora_alpha: 16 + +datasets: + + # llava_conversation: # 56681 + # batch_size: 2 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 30 + + # unnatural_instruction: # pure text 65852 + # batch_size: 1 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 10 + + # llava_detail: # 23240 + # batch_size: 4 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 20 + + llava_reason: # 76643 + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 80 + + # coco_caption: # 414113 train + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 10 + + + # textcaps_caption: # 109765 train + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 30 + + # coco_vqa: # 658104 + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 15 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 448 + text_processor: + train: + name: "blip_caption" + sample_ratio: 8 + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 12 + + # gqa: # train: 943000, 12578, 12578) + # type: balanced_sft_raw + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 50 + + # ocrvqa: # train 207572 + # batch_size: 6 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 448 + # text_processor: + # train: + # name: "blip_caption" + # sample_ratio: 30 + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-5 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 1 + num_workers: 6 + warmup_steps: 1000 + iters_per_epoch: 1000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/minigpt/v2/vqa_pretrain_3B_llama2_7b_chat_stage3_train_linear_lora_test_1030" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + wandb_log: False + job_name: minigptv2_finetune \ No newline at end of file diff --git a/minigpt4/projects/prompt_moe/eval/gqa_llava_prompt_flant5xxl_eval_1010.yaml b/minigpt4/projects/prompt_moe/eval/gqa_llava_prompt_flant5xxl_eval_1010.yaml new file mode 100644 index 0000000..b496f88 --- /dev/null +++ b/minigpt4/projects/prompt_moe/eval/gqa_llava_prompt_flant5xxl_eval_1010.yaml @@ -0,0 +1,83 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_freeze_qf_train_qt_gate_textt5_textinqf_epo3_1009/20231009234/checkpoint_best.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: True + freeze_t5_proj: True + + # moe + repeat_to_init_qt_candidates: True + num_qt_candidates: 5 + + gate_save_file: "/mnt/pfs-guan-ssai/nlu/wanghanzi/evaluation/BLIP2/GQA/llava_st_257k_raw_freeze_qf_train_qt_gate_textt5_textinf_epo3_1010/gate.txt" + + +datasets: + gqa: # name of the dataset builder + type: balanced_testdev + vis_processor: + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_question" + build_info: + images: + storage: "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images/" + +run: + task: gqa + batch_size_train: 4 + batch_size_eval: 16 + num_workers: 4 + + max_len: 20 + min_len: 1 + num_beams: 5 + inference_method: "generate" + # prompt: "Question: {} Short answer:" + prompt: "" + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/evaluation/BLIP2/GQA/llava_st_257k_raw_freeze_qf_train_qt_gate_textt5_textinf_epo3_1010/" + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/prompt_moe/eval/gqa_llava_prompt_flant5xxl_eval_1012.yaml b/minigpt4/projects/prompt_moe/eval/gqa_llava_prompt_flant5xxl_eval_1012.yaml new file mode 100644 index 0000000..4cfaf49 --- /dev/null +++ b/minigpt4/projects/prompt_moe/eval/gqa_llava_prompt_flant5xxl_eval_1012.yaml @@ -0,0 +1,87 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + # pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_train_qf_train_qt_linear_gate_textt5_20ex_3loss_textinqf_epo3_1012/20231014175/checkpoint_best.pth" + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/instruct_blip_flant5xxl/instruct_blip_flanxxl_trimmed.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + moe_topk: 2 + text_embed_key: "text_input" # "prompt"/"text_input" + eval_gate_save: False + train_gate_save: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_train_qf_train_qt_linear_moe_top2_gate_textt5_textinqf_epo3_1012_test/" + + + +datasets: + gqa: # name of the dataset builder + type: balanced_testdev + vis_processor: + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_question" + build_info: + images: + storage: "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images/" + +run: + task: gqa + batch_size_train: 16 + batch_size_eval: 32 + num_workers: 4 + + max_len: 20 + min_len: 1 + num_beams: 5 + inference_method: "generate" + # prompt: "Question: {} Short answer:" + prompt: "" + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/evaluation/BLIP2/GQA/llava_st_257k_raw_train_qf_train_qt_linear_gate_textt5_20ex_3loss_textinqf_epo3_1015/" + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/prompt_moe/eval/mme_llava_prompt_flant5xxl_eval_1007.yaml b/minigpt4/projects/prompt_moe/eval/mme_llava_prompt_flant5xxl_eval_1007.yaml new file mode 100644 index 0000000..496cd14 --- /dev/null +++ b/minigpt4/projects/prompt_moe/eval/mme_llava_prompt_flant5xxl_eval_1007.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + # intialize stage 2 pretraining from stage 1 pretrained model + # pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_freeze_qf_train_qt_gate_textt5_textinqf_epo3_1009/20231009234/checkpoint_best.pth" + # pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_train_qf_train_qt_gate_textt5_textinqf_epo3_1010/20231010150/checkpoint_best.pth" + # pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_train_qf_train_qt_linear_gate_textt5_textinqf_epo3_1011/20231012002/checkpoint_best.pth" + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + repeat_to_init_qt_candidates: True + num_qt_candidates: 5 + moe_topk: 2 + eval_gate_save: False + train_gate_save: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/evaluation/BLIP2/MME/promt_moe_blip2_pretrain_flant5xxl_1012/" + + +datasets: + gqa: # name of the dataset builder + type: balanced_testdev + vis_processor: + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_question" + build_info: + images: + storage: "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images/" + +run: + task: gqa + batch_size_train: 4 + batch_size_eval: 4 + num_workers: 4 + + max_len: 20 + min_len: 1 + num_beams: 5 + inference_method: "generate" + # prompt: "Question: {} Short answer:" + prompt: "" + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/evaluation/BLIP2/MME/promt_moe_blip2_pretrain_flant5xxl_1012/" + result_output: "mme_result.csv" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/prompt_moe/eval/vqa_coco_flant5xxl_eval_1115.yaml b/minigpt4/projects/prompt_moe/eval/vqa_coco_flant5xxl_eval_1115.yaml new file mode 100644 index 0000000..bee255b --- /dev/null +++ b/minigpt4/projects/prompt_moe/eval/vqa_coco_flant5xxl_eval_1115.yaml @@ -0,0 +1,92 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + # finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1102/20231115061/checkpoint_2.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moe_position: "post" # post (position to insert PromptMoE Part) + embed_extract: "blip2_pretrain" # t5, random (way to extract embeddings of task instruction if moe_position is pre) + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + moe_topk: 2 + eval_gate_save: False + train_gate_save: False + +datasets: + coco_vqa: # 658104 + # batch_size: 6 + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 5 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1115_eval/" + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/prompt_moe/train/gqa_sft_post_prompt_moe.yaml b/minigpt4/projects/prompt_moe/train/gqa_sft_post_prompt_moe.yaml new file mode 100644 index 0000000..c622963 --- /dev/null +++ b/minigpt4/projects/prompt_moe/train/gqa_sft_post_prompt_moe.yaml @@ -0,0 +1,92 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moe_position: "post" # post (position to insert PromptMoE Part) + embed_extract: "blip2_pretrain" # t5, random (way to extract embeddings of task instruction if moe_position is pre) + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + moe_topk: 2 + eval_gate_save: True + train_gate_save: False + gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1101/" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1101/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/prompt_moe/train/gqa_sft_pre_prompt_moe_blip2_embed.yaml b/minigpt4/projects/prompt_moe/train/gqa_sft_pre_prompt_moe_blip2_embed.yaml new file mode 100644 index 0000000..2c95cb2 --- /dev/null +++ b/minigpt4/projects/prompt_moe/train/gqa_sft_pre_prompt_moe_blip2_embed.yaml @@ -0,0 +1,97 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moe_position: "pre" + embed_extract: "blip2_pretrain" # t5 + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + moe_topk: 2 + eval_gate_save: True + train_gate_save: True + gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_train_qf_train_qt_linear_gate_textblip2_20ex_top2_3loss_textinqf_epo3_1021/" + +datasets: + gqa: + type: balanced_sft_raw + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_question" + eval: + name: "blip_question" + build_info: + images: + storage: "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/GQA/images/" + + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + batch_size_train: 8 + batch_size_eval: 16 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_train_qf_train_qt_linear_gate_textblip2_20ex_top2_3loss_textinqf_epo3_1021/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/prompt_moe/train/llava_prompt_moe_single_turn.yaml b/minigpt4/projects/prompt_moe/train/llava_prompt_moe_single_turn.yaml new file mode 100644 index 0000000..bfba3d5 --- /dev/null +++ b/minigpt4/projects/prompt_moe/train/llava_prompt_moe_single_turn.yaml @@ -0,0 +1,94 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + eval_gate_save: True + train_gate_save: True + gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_train_qf_train_qt_linear_gate_textt5_20ex_3loss_textinqf_epo3_1012/" + +datasets: + llava150k_en_sft: + type: prompt_moe + vis_processor: + train: + name: "blip_image_train" + image_size: 224 + eval: + name: "blip_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_question" + eval: + name: "blip_question" + build_info: + images: + storage: "/mnt/pfs-guan-ssai/nlu/dingyifeng/data/COCO" + + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + batch_size_train: 4 + batch_size_eval: 8 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/llava_st_257k_raw_train_qf_train_qt_linear_gate_textt5_20ex_3loss_textinqf_epo3_1012/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/prompt_moe/train/mix_coco_gqa_prompt_moe_post_blip2.yaml b/minigpt4/projects/prompt_moe/train/mix_coco_gqa_prompt_moe_post_blip2.yaml new file mode 100644 index 0000000..4692440 --- /dev/null +++ b/minigpt4/projects/prompt_moe/train/mix_coco_gqa_prompt_moe_post_blip2.yaml @@ -0,0 +1,126 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moe_position: "post" # post (position to insert PromptMoE Part) + embed_extract: "blip2_pretrain" # t5, random (way to extract embeddings of task instruction if moe_position is pre) + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + moe_topk: 2 + eval_gate_save: False + train_gate_save: False + gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1031/" + + +datasets: + # gqa: # train: 943000, 12578, 12578) + # type: balanced_sft_raw + # batch_size: 4 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 658104 + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 5 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1031/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/prompt_moe/train/mix_llava_665k_prompt_moe_post_blip2.yaml b/minigpt4/projects/prompt_moe/train/mix_llava_665k_prompt_moe_post_blip2.yaml new file mode 100644 index 0000000..aca20d8 --- /dev/null +++ b/minigpt4/projects/prompt_moe/train/mix_llava_665k_prompt_moe_post_blip2.yaml @@ -0,0 +1,95 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_instruct_pro_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moe_position: "post" # post (position to insert PromptMoE Part) + embed_extract: "blip2_pretrain" # t5, random (way to extract embeddings of task instruction if moe_position is pre) + repeat_to_init_qt_candidates: True + num_qt_candidates: 20 + moe_topk: 2 + +datasets: + + llava_mix: # 665298 + batch_size: 8 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + + gqa: # val/test = 12578 + type: balanced_sft_raw_eval + batch_size: 4 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 5 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_llava_665k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1114/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe/eval/gqa_qformer_moe_blip2_eval.yaml b/minigpt4/projects/qformer_moe/eval/gqa_qformer_moe_blip2_eval.yaml new file mode 100644 index 0000000..0b5b2fc --- /dev/null +++ b/minigpt4/projects/qformer_moe/eval/gqa_qformer_moe_blip2_eval.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_2loss_textinqf_epo3_1107/20231107103/checkpoint_0.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moebert_expert_num: 20 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 2 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_2loss_textinqf_epo3_1107_eval/" + + amp: True + resume_ckpt_path: null + + evaluate: True + # train_splits: ["train"] + # valid_splits: ["val"] + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe/eval/mix_qformer_moe_blip2_eval.yaml b/minigpt4/projects/qformer_moe/eval/mix_qformer_moe_blip2_eval.yaml new file mode 100644 index 0000000..e43454c --- /dev/null +++ b/minigpt4/projects/qformer_moe/eval/mix_qformer_moe_blip2_eval.yaml @@ -0,0 +1,106 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_2loss_textinqf_epo3_1107/20231107103/checkpoint_0.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moebert_expert_num: 20 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 2 + +datasets: + ok_vqa: # train, valid (9009, 5046) + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + # sample_ratio: 8 + + coco_vqa: # 658104 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + # sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_2loss_textinqf_epo3_1107_eval/" + + amp: True + resume_ckpt_path: null + + evaluate: True + # train_splits: ["train"] + # valid_splits: ["val"] + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2.yaml b/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2.yaml new file mode 100644 index 0000000..51f89b6 --- /dev/null +++ b/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moebert_expert_num: 20 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 2 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_QformerMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1115/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2_336.yaml b/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2_336.yaml new file mode 100644 index 0000000..8cd5ed6 --- /dev/null +++ b/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2_336.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 336 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moebert_expert_num: 20 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 2 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 336 + eval: + name: "blip2_image_eval" + image_size: 336 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_336_even_moe_ffn_QformerMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1117/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2_short_answer_224.yaml b/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2_short_answer_224.yaml new file mode 100644 index 0000000..d1a7e8f --- /dev/null +++ b/minigpt4/projects/qformer_moe/train/gqa_qformer_moe_blip2_short_answer_224.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moebert_expert_num: 5 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 1 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_short_answer_224_even_moe_ffn_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_3loss_textinqf_epo3_1117/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe/train/mix_qformer_moe_blip2.yaml b/minigpt4/projects/qformer_moe/train/mix_qformer_moe_blip2.yaml new file mode 100644 index 0000000..1ccf5d4 --- /dev/null +++ b/minigpt4/projects/qformer_moe/train/mix_qformer_moe_blip2.yaml @@ -0,0 +1,126 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + moebert_expert_num: 20 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 2 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + # batch_size: 4 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046) + # batch_size: 6 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 658104 + # batch_size: 6 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 5 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_QformerMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo5_1115/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_post_vicuna/eval/mix_qformer_moe_post_vicuna7b_evaluation.yaml b/minigpt4/projects/qformer_moe_post_vicuna/eval/mix_qformer_moe_post_vicuna7b_evaluation.yaml new file mode 100644 index 0000000..45cf587 --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/eval/mix_qformer_moe_post_vicuna7b_evaluation.yaml @@ -0,0 +1,105 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_1610k_raw_QformerMoE_Post_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0301/20240301223/checkpoint_best.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: False + moebert_expert_num: 3 + moebert_route_method: "gate-sentence-post" + moe_weight_type: "raw_prob" + moebert_load_balance: 0.05 + moe_topk: 1 + use_balance_loss: False + ln_position: "out" + gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/gate_save/mix_coco_gqa_1610k_raw_QformerMoE_Post_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0301/" + +datasets: + gqa: + type: balanced_sft_raw_eval + batch_size: 4 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + ok_vqa: # train, valid (9009, 5046) + type: ok_vqa_eval + batch_size: 4 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + coco_vqa: # 658104 + type: vqa_v2_eval + batch_size: 4 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/eval/mix_coco_gqa_1610k_raw_QformerMoE_Post_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0301/" + num_workers: 4 + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + + + + + diff --git a/minigpt4/projects/qformer_moe_post_vicuna/eval/mix_qformer_moe_post_vicuna7b_evaluation2.yaml b/minigpt4/projects/qformer_moe_post_vicuna/eval/mix_qformer_moe_post_vicuna7b_evaluation2.yaml new file mode 100644 index 0000000..937e5fd --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/eval/mix_qformer_moe_post_vicuna7b_evaluation2.yaml @@ -0,0 +1,126 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_CLS_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0307/20240307231/checkpoint_7.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: False + moebert_expert_num: 3 + moebert_route_method: "gate-sentence-cls" + moe_weight_type: "raw_prob" + moebert_load_balance: 0.05 + moe_topk: 1 + use_balance_loss: False + ln_position: "out" + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/gate_save/lora_mix_coco_gqa_aokvqa_raw_QformerMoE_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0308/" + +datasets: + # gqa: + # type: balanced_sft_raw_eval + # batch_size: 4 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # ok_vqa: # train, valid (9009, 5046) + # type: ok_vqa_eval + # batch_size: 4 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # coco_vqa: # 658104 + # type: vqa_v2_eval + # batch_size: 4 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + +run: + task: instruction_tuning + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/eval/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_CLS_linear_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0307/" + num_workers: 4 + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + + + + + diff --git a/minigpt4/projects/qformer_moe_post_vicuna/eval/vqav2_okvqa_gqa_evaluation.yaml b/minigpt4/projects/qformer_moe_post_vicuna/eval/vqav2_okvqa_gqa_evaluation.yaml new file mode 100644 index 0000000..f03cb5b --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/eval/vqav2_okvqa_gqa_evaluation.yaml @@ -0,0 +1,124 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_lr5e5_textinqf_6epo_0307/20240308004/checkpoint_7.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: False + use_route_moe: False + # moebert_expert_num: 3 + # moebert_route_method: "gate-sentence-post" + # moe_weight_type: "raw_prob" + # moebert_load_balance: 0.05 + # moe_topk: 1 + # use_balance_loss: False + # ln_position: "out" + +datasets: + # gqa: + # type: balanced_sft_raw_eval + # batch_size: 4 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # ok_vqa: # train, valid (9009, 5046) + # type: ok_vqa_eval + # batch_size: 4 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # coco_vqa: # 658104 + # type: vqa_v2_eval + # batch_size: 4 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_train" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_lr5e5_textinqf_6epo_0307/" + num_workers: 4 + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + + + + + diff --git a/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_cls_route.yaml b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_cls_route.yaml new file mode 100644 index 0000000..6ccfda1 --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_cls_route.yaml @@ -0,0 +1,132 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: False + moebert_expert_num: 3 + moebert_route_method: "gate-sentence-cls" + moe_weight_type: "raw_prob" + moebert_load_balance: 0.05 + moe_topk: 1 + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + # ok_vqa: # train, valid (9009, 5046) + # batch_size: 16 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 1 + + # coco_vqa: # 658104 + # batch_size: 16 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 9 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 6 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_1610k_raw_QformerMoE_Post_CLS_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0305/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + wandb_log: True + job_name: mix_post_5e5_3ex1_cosine_005 \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_cosine_route.yaml b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_cosine_route.yaml new file mode 100644 index 0000000..e2a155c --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_cosine_route.yaml @@ -0,0 +1,132 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: False + moebert_expert_num: 3 + moebert_route_method: "gate-sentence-post-cosine" + moe_weight_type: "raw_prob" + moebert_load_balance: 0.05 + moe_topk: 1 + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 6 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_1610k_raw_QformerMoE_Post_Cosine_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0305/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + wandb_log: True + job_name: mix_post_5e5_3ex1_cosine_005 \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance.yaml b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance.yaml new file mode 100644 index 0000000..8c5e050 --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance.yaml @@ -0,0 +1,128 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + moebert_expert_num: 3 + moebert_route_method: "gate-sentence-post" + moebert_load_balance: 0 + moe_topk: 1 + use_balance_loss: False + moe_weight_type: 'average' + +datasets: + gqa: # train: 94254 + type: balanced_sft_raw_part + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + # coco_vqa: # 214352 vqa_val + # type: vqa_v2_part + # batch_size: 32 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 1 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 1000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance_finetuned.yaml b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance_finetuned.yaml new file mode 100644 index 0000000..806af92 --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance_finetuned.yaml @@ -0,0 +1,127 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/20231201184/checkpoint_best.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + moebert_expert_num: 5 + moebert_route_method: "gate-sentence-post" + moebert_load_balance: 0 + moe_topk: 1 + use_balance_loss: False + moe_weight_type: 'average' + +datasets: + gqa: # train: 94254 + type: balanced_sft_raw_part + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 214352 vqa_val + type: vqa_v2_part + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 1 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_Post_train_qf_train_qt_aver_weight_5ex_top1_1loss_textinqf_epo3_s42_1201/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_route.yaml b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_route.yaml new file mode 100644 index 0000000..a127fc3 --- /dev/null +++ b/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_route.yaml @@ -0,0 +1,135 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + lora_r: 64 # freeze_llm = False + lora_alpha: 16 + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + # freeze_llm: True + freeze_llm: False + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: False + moebert_expert_num: 3 + moebert_route_method: "gate-sentence-post" + moe_weight_type: "raw_prob" + moebert_load_balance: 0.05 + moe_topk: 1 + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 6 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/lora_mix_coco_gqa_1610k_raw_QformerMoE_Post_gate_lnout_lr5e5_3ex_top1_2loss_005_top6layer_textinqf_6epo_0305/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + + wandb_log: True + job_name: mix_post_5e5_3ex1_cosine_005 \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml new file mode 100644 index 0000000..7ad1693 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml @@ -0,0 +1,154 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/20240314230/checkpoint_12.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + # use_moeqformer: True + # use_route_moe: False + # moebert_expert_num: 3 + # moebert_route_method: "gate-sentence-post" + # moe_weight_type: "raw_prob" + # moebert_load_balance: 0.05 + # moe_topk: 1 + # use_balance_loss: False + # ln_position: "out" + + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/" + +datasets: + # gqa: + # type: balanced_sft_raw_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # coco_vqa: # 658104 + # type: vqa_v2_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + coco_caption: # 414113 train + type: coco_cap_eval + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + # ok_vqa: # train, valid (9009, 5046) + # type: ok_vqa_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 3000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/" + + amp: True + resume_ckpt_path: null + + evaluate: True + # test_splits: ["val"] + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/eval/mix_vqa_coco_vicuna_eval_coco_vqa_test.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/eval/mix_vqa_coco_vicuna_eval_coco_vqa_test.yaml new file mode 100644 index 0000000..692f04c --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/eval/mix_vqa_coco_vicuna_eval_coco_vqa_test.yaml @@ -0,0 +1,108 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cap_raw_QformerMoE_Post_Route_linear_gate_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_8epo_0309/20240309174/checkpoint_7.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_lr5e5_textinqf_6epo_0307/20240308004/checkpoint_7.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + # lora_r: 64 # freeze_llm = False + # lora_alpha: 16 + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + + # moe + use_moeqformer: False + use_route_moe: False + + # moe + # use_moeqformer: True + # use_route_moe: False + # moebert_expert_num: 3 + # moebert_route_method: "gate-sentence-post" + # moe_weight_type: "raw_prob" + # moebert_load_balance: 0.05 + # moe_topk: 1 + # use_balance_loss: False + # ln_position: "out" + + # use_moeqformer: True + # use_route_moe: True + # moebert_route_method: "post-route" + # moebert_load_balance: 0.05 + # moebert_expert_num: 3 + # moebert_num_beams: 3 + # moe_weight_type: 'ffn_prob' + # use_balance_loss: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_ao_cap_raw_QformerMoE_Post_Route_linear_gate_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_8epo_0309/" + +datasets: + coco_vqa: # 658104 + type: vqa_v2_eval + batch_size: 48 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 3000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/eval/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_lr5e5_textinqf_6epo_0307/" + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_base_flant5xxl_qformer_blip2.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_base_flant5xxl_qformer_blip2.yaml new file mode 100644 index 0000000..d1d5394 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_base_flant5xxl_qformer_blip2.yaml @@ -0,0 +1,169 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 42 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 42 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 42 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 42 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 42 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 42 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/base/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_lr5e5_top6layer_textinqf_8epo_0328/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: False + job_name: mix6_base_5e5_flant5_0328 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_blip2_vicuna7b_0317.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_blip2_vicuna7b_0317.yaml new file mode 100644 index 0000000..3268015 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_blip2_vicuna7b_0317.yaml @@ -0,0 +1,176 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'uni_route_moe' + moebert_route_method: "post-route-uni" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 12 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_12epo_0317/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: True + job_name: mix6_uni_route_post_5e5_3ex3b_005_12ep_0317 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_blip2_vicuna7b_0328.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_blip2_vicuna7b_0328.yaml new file mode 100644 index 0000000..2a3b186 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_blip2_vicuna7b_0328.yaml @@ -0,0 +1,178 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'uni_route_moe' + moebert_route_method: "post-route-uni" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + # sample_ratio: 1 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + # sample_ratio: 3 + sample_ratio: 4 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_8epo_0328/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: False + # job_name: mix6_uni_route_post_5e5_3ex3b_005_10ep_0319 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_flant5xxl_qformer_blip2.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_flant5xxl_qformer_blip2.yaml new file mode 100644 index 0000000..10eddcd --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix6_qformer_moe_route_uni_flant5xxl_qformer_blip2.yaml @@ -0,0 +1,179 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_t5_qformer_moe + model_type: flant5xxl + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2-flant5-xxl/blip2_pretrained_flant5xxl.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + t5_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/google-flan-t5-xxl" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'uni_route_moe' + moebert_route_method: "post-route-uni" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr3e5_3ex3b_2loss_005_top6layer_textinqf_8epo_0328/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: False + job_name: mix6_qformer_5e5_flant5_0328 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix7_base_vicuna7b_qformer_blip2.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix7_base_vicuna7b_qformer_blip2.yaml new file mode 100644 index 0000000..adaaf91 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix7_base_vicuna7b_qformer_blip2.yaml @@ -0,0 +1,185 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'base' + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + + text_vqa: # train: 34602, val: 5000 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 3 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_textvqa_raw_QformerMoE_Base_lr5e5_textinqf_10epo_0322/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: True + job_name: mix7_base_5e5_10ep_0322 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix7_qformer_moe_route_uni_blip2_vicuna7b_0325.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix7_qformer_moe_route_uni_blip2_vicuna7b_0325.yaml new file mode 100644 index 0000000..cc35e19 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix7_qformer_moe_route_uni_blip2_vicuna7b_0325.yaml @@ -0,0 +1,208 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'uni_route_moe' + moebert_route_method: "post-route-uni" + moebert_load_balance: 0.05 + moebert_expert_num: 4 + moebert_num_beams: 4 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + + text_vqa: # train: 34602, val: 5000 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 3 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_textvqa_ocrvqa_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_10epo_0322/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: True + job_name: mix8_uni_route_post_5e5_3ex3b_005_10ep_0322 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix_qformer_moe_route_uni_blip2_vicuna7b_0317.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix_qformer_moe_route_uni_blip2_vicuna7b_0317.yaml new file mode 100644 index 0000000..7a77caa --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix_qformer_moe_route_uni_blip2_vicuna7b_0317.yaml @@ -0,0 +1,176 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'uni_route_moe' + moebert_route_method: "post-route-uni" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + # textcaps_caption: # train: 109765, val: 15830 + # batch_size: 32 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 4 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_10epo_0317/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: True + job_name: mix5_uni_route_post_5e5_3ex3b_005_10ep_0317 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix_qformer_moe_route_uni_blip2_vicuna7b_test_0317.yaml b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix_qformer_moe_route_uni_blip2_vicuna7b_test_0317.yaml new file mode 100644 index 0000000..d240793 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_universal_vicuna/train/mix_qformer_moe_route_uni_blip2_vicuna7b_test_0317.yaml @@ -0,0 +1,176 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'uni_route_moe' + moebert_route_method: "post-route-uni" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +datasets: + # gqa: # train: 943000, 12578, 12578) + # type: balanced_sft_raw + # batch_size: 32 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + # coco_vqa: # 658104 + # batch_size: 32 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 9 + + # coco_caption: # 414113 train + # batch_size: 32 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 7 + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 32 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 2 + + # textcaps_caption: # train: 109765, val: 15830 + # batch_size: 32 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 4 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_10epo_0317/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + wandb_log: False + job_name: mix6_post_5e5_3ex3b_005_15ep_0314 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml new file mode 100644 index 0000000..7ad1693 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml @@ -0,0 +1,154 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/20240314230/checkpoint_12.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + # use_moeqformer: True + # use_route_moe: False + # moebert_expert_num: 3 + # moebert_route_method: "gate-sentence-post" + # moe_weight_type: "raw_prob" + # moebert_load_balance: 0.05 + # moe_topk: 1 + # use_balance_loss: False + # ln_position: "out" + + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/" + +datasets: + # gqa: + # type: balanced_sft_raw_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # coco_vqa: # 658104 + # type: vqa_v2_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + coco_caption: # 414113 train + type: coco_cap_eval + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + # ok_vqa: # train, valid (9009, 5046) + # type: ok_vqa_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 3000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/" + + amp: True + resume_ckpt_path: null + + evaluate: True + # test_splits: ["val"] + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_coco_vqa_test.yaml b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_coco_vqa_test.yaml new file mode 100644 index 0000000..692f04c --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_coco_vqa_test.yaml @@ -0,0 +1,108 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cap_raw_QformerMoE_Post_Route_linear_gate_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_8epo_0309/20240309174/checkpoint_7.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_lr5e5_textinqf_6epo_0307/20240308004/checkpoint_7.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + # lora_r: 64 # freeze_llm = False + # lora_alpha: 16 + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + + # moe + use_moeqformer: False + use_route_moe: False + + # moe + # use_moeqformer: True + # use_route_moe: False + # moebert_expert_num: 3 + # moebert_route_method: "gate-sentence-post" + # moe_weight_type: "raw_prob" + # moebert_load_balance: 0.05 + # moe_topk: 1 + # use_balance_loss: False + # ln_position: "out" + + # use_moeqformer: True + # use_route_moe: True + # moebert_route_method: "post-route" + # moebert_load_balance: 0.05 + # moebert_expert_num: 3 + # moebert_num_beams: 3 + # moe_weight_type: 'ffn_prob' + # use_balance_loss: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_ao_cap_raw_QformerMoE_Post_Route_linear_gate_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_8epo_0309/" + +datasets: + coco_vqa: # 658104 + type: vqa_v2_eval + batch_size: 48 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 3000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/eval/mix_coco_gqa_aokvqa_cap_raw_QformerMoE_lr5e5_textinqf_6epo_0307/" + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_copy.yaml b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_copy.yaml new file mode 100644 index 0000000..eace74d --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_copy.yaml @@ -0,0 +1,154 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/20240314230/checkpoint_10.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + # use_moeqformer: True + # use_route_moe: False + # moebert_expert_num: 3 + # moebert_route_method: "gate-sentence-post" + # moe_weight_type: "raw_prob" + # moebert_load_balance: 0.05 + # moe_topk: 1 + # use_balance_loss: False + # ln_position: "out" + + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314_epoch10/" + +datasets: + # gqa: + # type: balanced_sft_raw_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # coco_vqa: # 658104 + # type: vqa_v2_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + coco_caption: # 414113 train + type: coco_cap_eval + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + # ok_vqa: # train, valid (9009, 5046) + # type: ok_vqa_eval + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 32 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 3000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314_epoch10/" + + amp: True + resume_ckpt_path: null + + evaluate: True + # test_splits: ["val"] + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_lora.yaml b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_lora.yaml new file mode 100644 index 0000000..56fa20f --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/eval/mix_vqa_coco_vicuna_eval_lora.yaml @@ -0,0 +1,125 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_post/lora_mix_coco_gqa_aokvqa_raw_QformerMoE_lr5e5_textinqf_6epo_0308/20240309000/checkpoint_7.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + lora_r: 64 # freeze_llm = False + lora_alpha: 16 + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: False + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: False + use_route_moe: False + # use_moeqformer: True + # use_route_moe: True + # moebert_route_method: "post-route" + # moebert_load_balance: 0.05 + # moebert_expert_num: 3 + # moebert_num_beams: 3 + # moe_weight_type: 'ffn_prob' + # use_balance_loss: False + # ln_position: "out" + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/route_save/lora_mix_coco_gqa_ao_cap_raw_QformerMoE_Post_Route_linear_gate_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_8epo_0309/" + +datasets: + # gqa: + # type: balanced_sft_raw_eval + # batch_size: 8 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + coco_vqa: # 658104 + type: vqa_v2_eval + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + # coco_caption: # 414113 train + # type: coco_cap_eval + # batch_size: 8 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # ok_vqa: # train, valid (9009, 5046) + # type: ok_vqa_eval + # batch_size: 8 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + # aok_vqa: # train: 17056, val: 1145 + # batch_size: 8 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" +run: + task: instruction_tuning + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/lora_mix_coco_gqa_aokvqa_raw_QformerMoE_lr5e5_textinqf_6epo_0308/" + num_workers: 4 + + amp: True + resume_ckpt_path: null + + evaluate: True + # test_splits: ["val"] + test_splits: ["test"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance.yaml new file mode 100644 index 0000000..7ae5cbc --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance.yaml @@ -0,0 +1,128 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route-dp" + moebert_load_balance: 0.05 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_DP_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_005_5e5lr_top6layer_textinqf_epo8_0121/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance_0122.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance_0122.yaml new file mode 100644 index 0000000..b2cf35b --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance_0122.yaml @@ -0,0 +1,128 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_1loss_5e5lr_top6layer_textinqf_epo8_0123/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance_1220.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance_1220.yaml new file mode 100644 index 0000000..8818143 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_balance_1220.yaml @@ -0,0 +1,129 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +# 0107test + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/route_save/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/" + +datasets: + # gqa: # train: 943000, 12578, 12578) + # type: balanced_sft_raw + # batch_size: 1 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + # coco_vqa: # 658104 + # batch_size: 1 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 9 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 5 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_balance_raw_QformerMoE_Route_linear_gate_5ex_2beam_1loss_textinqf_epo5_toplayer3_1209/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco.yaml new file mode 100644 index 0000000..15117b3 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco.yaml @@ -0,0 +1,145 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: False + use_route_moe: False + moebert_route_method: "post-route" + moebert_load_balance: 0 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + # output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_vqa_cococap_2024k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_005_5e5lr_top6layer_textinqf_epo8_0122/" + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_vqa_cococap_2024k_raw_QformerMoE_Base_top6layer_textinqf_epo8_0124/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco_0128.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco_0128.yaml new file mode 100644 index 0000000..e1fbc8f --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco_0128.yaml @@ -0,0 +1,145 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.01 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + bal_loss_decay_epoch: 3 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_vqa_cococap_2024k_raw_QformerMoE_Route_Post_ffn_prob_lnout_linear_1gate_2ex_2beam_2loss_001_loss_decay_5e5lr_top6layer_textinqf_epo8_0129/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco_0310_lora.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco_0310_lora.yaml new file mode 100644 index 0000000..8570a81 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_coco_0310_lora.yaml @@ -0,0 +1,164 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + lora_r: 64 # freeze_llm = False + lora_alpha: 16 + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: False + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 15 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/lora_mix_coco_gqa_ao_cap_raw_QformerMoE_Post_Route_linear_gate_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0310/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + wandb_log: True + job_name: mix_post_5e5_3ex3b_005_15ep + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix.yaml new file mode 100644 index 0000000..16b3ef5 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix.yaml @@ -0,0 +1,188 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + +datasets: + gqa: + type: balanced_sft_raw_eval + batch_size: 16 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 12 + + ocrvqa: # train 207572 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 30 + + llava_reason: # 76643 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 80 + + llava_conversation: # 56681 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 30 + + llava_detail: # 23240 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 20 + + coco_caption: # 414113 train + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 10 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_1048k_raw_QformerMoE_Route_Post_ffn_prob_linear_1gate_2ex_2beam_2loss_5e5lr_top6layer_textinqf_epo8_0118/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix6_0314.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix6_0314.yaml new file mode 100644 index 0000000..c8603bf --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix6_0314.yaml @@ -0,0 +1,178 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + ln_position: "out" + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 15 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + wandb_log: True + job_name: mix6_post_5e5_3ex3b_005_15ep_0314 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix6_0314_base.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix6_0314_base.yaml new file mode 100644 index 0000000..25c5a91 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_data_mix6_0314_base.yaml @@ -0,0 +1,170 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + general_version: 'base' + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + + coco_caption: # 414113 train + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 7 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 2 + + textcaps_caption: # train: 109765, val: 15830 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 4 + +run: + task: instruction_tuning + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 10 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_tcap_raw_Qformer_base_lr5e5_10epo_0320_instruct/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + wandb_log: True + job_name: mix6_base_12ep_0316 + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_pretrain.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_pretrain.yaml new file mode 100644 index 0000000..877f53e --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_blip2_vicuna7b_pretrain.yaml @@ -0,0 +1,82 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_pretrained/blip2_pretrained.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_pretrained/blip2_pretrained.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 8 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_expert_num: 3 + moebert_num_beams: 3 + moebert_route_method: "post-route" + +datasets: + llava_pretrain: + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + + +run: + task: image_text_pretrain + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 1e-4 + min_lr: 8e-5 + warmup_lr: 1e-6 + + weight_decay: 0.05 + max_epoch: 4 + num_workers: 4 + warmup_steps: 5000 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route_pretrain/llava_pretrain_595k_Qformer_MoE_Route_Post_5ksteps_4e_3ex3b_42seed_1216/" + + amp: True + resume_ckpt_path: null + + train_splits: ["train"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_dp_blip2_vicuna7b_data_balance.yaml b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_dp_blip2_vicuna7b_data_balance.yaml new file mode 100644 index 0000000..be65c17 --- /dev/null +++ b/minigpt4/projects/qformer_moe_route_vicuna/train/mix_qformer_moe_route_dp_blip2_vicuna7b_data_balance.yaml @@ -0,0 +1,128 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route-dp" + moebert_load_balance: 0.05 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 10 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 1 + + coco_vqa: # 658104 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 9 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_DP_Route_Post_ffn_prob_linear_1gate_2ex_2beam_2loss_5e5lr_top6layer_textinqf_epo8_0118/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml b/minigpt4/projects/qformer_moe_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml new file mode 100644 index 0000000..27567d0 --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml @@ -0,0 +1,110 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo3_1204/20231204165/checkpoint_best.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + moebert_expert_num: 5 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0 + moe_topk: 1 + use_balance_loss: False + moe_weight_type: 'l2_norm' + gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo3_1204/" + +datasets: + gqa: + type: balanced_sft_raw_eval + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + ok_vqa: # train, valid (9009, 5046) + type: ok_vqa_eval + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + coco_vqa: # 658104 + type: vqa_v2_eval + batch_size: 32 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 5 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/eval/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo3_1204/mix/" + + amp: True + resume_ckpt_path: null + + evaluate: True + test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True + diff --git a/minigpt4/projects/qformer_moe_vicuna/eval/vqa_benchmark_evaluation.yaml b/minigpt4/projects/qformer_moe_vicuna/eval/vqa_benchmark_evaluation.yaml new file mode 100644 index 0000000..98de298 --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/eval/vqa_benchmark_evaluation.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: True + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + finetuned: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/20240111145/checkpoint_best.pth" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # T5 + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0 + moebert_expert_num: 2 + moebert_num_beams: 2 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + +datasets: + ok_vqa: # train, valid (9009, 5046) + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + +evaluation_datasets: + vizwiz: + eval_file_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VizWiz/val.json + img_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/VizWiz/val + max_new_tokens: 20 + batch_size: 10 + # iconvqa: + # eval_file_path: /path/to/eval/annotation/path + # img_path: /path/to/eval/image/path + # max_new_tokens: 20 + # batch_size: 10 + # vsr: + # eval_file_path: cambridgeltl/vsr_zeroshot + # img_path: /path/to/eval/image/path + # max_new_tokens: 20 + # batch_size: 10 + hm: + eval_file_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/hm_data/dev_seen.jsonl + img_path: /mnt/pfs-guan-ssai/nlu/wanghanzi/data/hm_data/ + max_new_tokens: 20 + batch_size: 100 + +run: + task: instruction_tuning + name: vqa_benchmark_evaluation + save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/eval/benchmarks/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_2ex_2beam_1gate_1loss_5e5lr_top6layer_textinqf_epo8_0111/" + seed: 42 + + + + + diff --git a/minigpt4/projects/qformer_moe_vicuna/train/gqa_qformer_moe_blip2_vicuna7b.yaml b/minigpt4/projects/qformer_moe_vicuna/train/gqa_qformer_moe_blip2_vicuna7b.yaml new file mode 100644 index 0000000..c6945d7 --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/gqa_qformer_moe_blip2_vicuna7b.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna 7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + moebert_expert_num: 5 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 2 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gqa_943k_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top2_3loss_textinqf_epo3_1119/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/train/gqa_raw_qformer_blip2_vicuna7b.yaml b/minigpt4/projects/qformer_moe_vicuna/train/gqa_raw_qformer_blip2_vicuna7b.yaml new file mode 100644 index 0000000..0ad217f --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/gqa_raw_qformer_blip2_vicuna7b.yaml @@ -0,0 +1,90 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: False + moebert_expert_num: 1 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 1 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gqa_943k_raw_224_vicuna7b_Qformer_train_qf_train_qt_textinqf_epo3_1119/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/train/llava_mix_qformer_moe_blip2_vicuna7b.yaml b/minigpt4/projects/qformer_moe_vicuna/train/llava_mix_qformer_moe_blip2_vicuna7b.yaml new file mode 100644 index 0000000..9b655ed --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/llava_mix_qformer_moe_blip2_vicuna7b.yaml @@ -0,0 +1,125 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + moebert_expert_num: 5 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.00 + moe_topk: 1 + use_balance_loss: False + +datasets: + llava_mix: # train: + type: mix_coco_gqa + batch_size: 8 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + + gqa: + type: balanced_sft_raw_eval + batch_size: 8 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + ok_vqa: + type: ok_vqa_eval + batch_size: 8 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + + coco_vqa: # 658104 + type: vqa_v2_eval + batch_size: 8 + vis_processor: + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + eval: + name: "blip_caption" + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_together_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_textinqf_epo3_1121/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/train/mix_coco_llava_qformer_moe_blip2_vicuna7b.yaml b/minigpt4/projects/qformer_moe_vicuna/train/mix_coco_llava_qformer_moe_blip2_vicuna7b.yaml new file mode 100644 index 0000000..1d40f30 --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/mix_coco_llava_qformer_moe_blip2_vicuna7b.yaml @@ -0,0 +1,174 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_expert_num: 5 + moebert_num_beams: 2 + +datasets: + # gqa: + # type: balanced_sft_raw_eval + # batch_size: 8 + # vis_processor: + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # eval: + # name: "blip_caption" + + ok_vqa: # train, valid (9009, 5046) + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 658104 + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + + aok_vqa: # train: 17056, val: 1145 + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 12 + + ocrvqa: # train 207572 + batch_size: 6 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 30 + + llava_reason: # 76643 + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 80 + + llava_conversation: # 56681 + batch_size: 2 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 30 + + llava_detail: # 23240 + batch_size: 4 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + text_processor: + train: + name: "blip_caption" + sample_ratio: 20 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 6 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_1048k_raw_QformerMoE_Route_Post_NoNorm_5ex__1loss_textinqf_epo3_seed40_1127/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b.yaml b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b.yaml new file mode 100644 index 0000000..7e1245e --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b.yaml @@ -0,0 +1,126 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + moebert_expert_num: 5 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0 + moe_topk: 1 + use_balance_loss: False + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 658104 + batch_size: 64 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 3 + num_workers: 4 + warmup_steps: 600 + + # seed: 42 + seed: 40 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_1610k_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_epo3_seed40_1127/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_3ex3beam_0112.yaml b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_3ex3beam_0112.yaml new file mode 100644 index 0000000..979e0a1 --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_3ex3beam_0112.yaml @@ -0,0 +1,131 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + use_route_moe: True + moebert_route_method: "post-route" + moebert_load_balance: 0.05 + moebert_expert_num: 3 + moebert_num_beams: 3 + moe_weight_type: 'ffn_prob' + use_balance_loss: False + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + # batch_size: 16 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046) + # batch_size: 16 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 658104 + # batch_size: 16 + batch_size: 32 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + # init_lr: 2e-5 + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_1610k_raw_QformerMoE_Route_Post_ffn_prob_linear_gate_3ex_3beam_1gate_2loss_5e5lr_top6layer_textinqf_epo8_0112/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_balance.yaml b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_balance.yaml new file mode 100644 index 0000000..d3f21ec --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_balance.yaml @@ -0,0 +1,130 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: True + moebert_expert_num: 3 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0 + moe_topk: 1 + use_balance_loss: False + moe_weight_type: 'raw_prob' + # gate_save_path: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/gate_save/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1206/" + + +datasets: + gqa: # train: 94254 + type: balanced_sft_raw_part + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046 + batch_size: 1 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + # coco_vqa: # 214352 vqa_val + # type: vqa_v2_part + # batch_size: 1 + # vis_processor: + # train: + # name: "blip2_image_train" + # image_size: 224 + # eval: + # name: "blip2_image_eval" + # image_size: 224 + # text_processor: + # train: + # name: "blip_caption" + # eval: + # name: "blip_caption" + # sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + init_lr: 2e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 1 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 1000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_balance_raw_QformerMoE_train_qf_train_qt_linear_gate_5ex_top1_1loss_textinqf_training_epo5_toplayer3_1220_test/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + # test_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_raw_0112.yaml b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_raw_0112.yaml new file mode 100644 index 0000000..afdb4eb --- /dev/null +++ b/minigpt4/projects/qformer_moe_vicuna/train/mix_qformer_moe_blip2_vicuna7b_data_raw_0112.yaml @@ -0,0 +1,125 @@ + # Copyright (c) 2022, salesforce.com, inc. + # All rights reserved. + # SPDX-License-Identifier: BSD-3-Clause + # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause + +model: + arch: blip2_vicuna_instruct + model_type: vicuna7b_pretrain + load_pretrained: True + load_finetuned: False + vit_model: eva_clip_g + pretrained: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + # finetuned: "" + q_former_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/blip2/blip2_vicuna7b/blip2_pretrained_vicuna7b.pth" + + # vit encoder + image_size: 224 + drop_path_rate: 0 + use_grad_checkpoint: False + vit_precision: "fp16" + + # Q-Former + num_query_token: 32 + qformer_text_input: True + + # vicuna7b + llm_model: "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/vicuna-7b-v1.1" + prompt: "" + max_txt_len: 256 + max_output_txt_len: 256 + + # freeze + freeze_vit: True + freeze_llm: True + freeze_qformer: False + freeze_t5_proj: False + + # moe + use_moeqformer: False + moebert_expert_num: 1 + moebert_route_method: "gate-sentence" + moebert_load_balance: 0.05 + moe_topk: 1 + +datasets: + gqa: # train: 943000, 12578, 12578) + type: balanced_sft_raw + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 50 + + ok_vqa: # train, valid (9009, 5046) + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 8 + + coco_vqa: # 658104 + batch_size: 16 + vis_processor: + train: + name: "blip2_image_train" + image_size: 224 + eval: + name: "blip2_image_eval" + image_size: 224 + text_processor: + train: + name: "blip_caption" + eval: + name: "blip_caption" + sample_ratio: 15 + +run: + task: instruction_tuning + # optimizer + lr_sched: "linear_warmup_cosine_lr" + # init_lr: 2e-5 + init_lr: 5e-5 + min_lr: 1e-6 + warmup_lr: 1e-6 + log_freq: 5 + save_freq: 1500 + + weight_decay: 0.05 + max_epoch: 8 + num_workers: 4 + warmup_steps: 600 + iters_per_epoch: 5000 + + seed: 42 + output_dir: "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe/mix_coco_gqa_1610k_raw_QformerMoE_train_qf_train_qt_1ex_top1_textinqf_epo8_lr5e5_seed42_0112/" + + amp: True + resume_ckpt_path: null + + evaluate: False + train_splits: ["train"] + valid_splits: ["val"] + + device: "cuda" + world_size: 1 + dist_url: "env://" + distributed: True \ No newline at end of file diff --git a/minigpt4/runners/__init__.py b/minigpt4/runners/__init__.py index 64e7a4d..ab0b676 100644 --- a/minigpt4/runners/__init__.py +++ b/minigpt4/runners/__init__.py @@ -6,5 +6,6 @@ """ from minigpt4.runners.runner_base import RunnerBase +from minigpt4.runners.runner_iter import RunnerIter -__all__ = ["RunnerBase"] +__all__ = ["RunnerBase", "RunnerIter"] \ No newline at end of file diff --git a/minigpt4/runners/runner_base.py b/minigpt4/runners/runner_base.py index bc8dc1d..831d07b 100644 --- a/minigpt4/runners/runner_base.py +++ b/minigpt4/runners/runner_base.py @@ -110,6 +110,7 @@ class RunnerBase: else: p_wd.append(p) num_parameters += p.data.nelement() + # import pdb; pdb.set_trace() # 0107test logging.info("number of trainable parameters: %d" % num_parameters) optim_params = [ { @@ -191,7 +192,7 @@ class RunnerBase: If train_dataset_ratio is provided, create a MultiIterLoader to sample each dataset by ratios during training. - Currently do not support multiple datasets for validation and test. + Currently do not support multiple datasets for validation and test.(Han Done) Returns: dict: {split_name: (tuples of) dataloader} @@ -208,9 +209,12 @@ class RunnerBase: batch_sizes = {dataset_name: getattr(self.config.datasets_cfg, dataset_name).batch_size for dataset_name in self.datasets.keys()} + print(batch_sizes) + datasets, batch_sizes = reorg_datasets_by_split(self.datasets, batch_sizes) self.datasets = datasets # self.datasets = concat_datasets(datasets) + print(self.datasets.keys()) # dict_keys(['train', 'val', 'test']) # print dataset statistics after concatenation/chaining for split_name in self.datasets: @@ -252,7 +256,9 @@ class RunnerBase: batch_sizes = [batch_sizes[split] for split in split_names] is_trains = [split in self.train_splits for split in split_names] - print("batch sizes", batch_sizes) + print("split_names: ",split_names) + print("is_trains: ",is_trains) + print("batch sizes: ", batch_sizes) collate_fns = [] for dataset in datasets: @@ -381,6 +387,8 @@ class RunnerBase: if len(self.valid_splits) > 0: for split_name in self.valid_splits: logging.info("Evaluating on {}.".format(split_name)) + + self._save_checkpoint(cur_epoch, is_best=False) val_log = self.eval_epoch( split_name=split_name, cur_epoch=cur_epoch @@ -396,7 +404,7 @@ class RunnerBase: best_epoch, best_agg_metric = cur_epoch, agg_metrics self._save_checkpoint(cur_epoch, is_best=True) - + val_log.update({"best_epoch": best_epoch}) self.log_stats(val_log, split_name) @@ -459,7 +467,7 @@ class RunnerBase: During testing, we will use provided weights and skip reloading the best checkpoint . """ data_loader = self.dataloaders.get(split_name, None) - assert data_loader, "data_loader for split {} is None.".format(split_name) + # assert data_loader, "data_loader for split {} is None.".format(split_name) # TODO In validation, you need to compute loss as well as metrics # TODO consider moving to model.before_evaluation() @@ -550,17 +558,26 @@ class RunnerBase: return loader loaders = [] - - for dataset, bsz, is_train, collate_fn in zip( + for dataset, bsz, is_train, collate_fn in zip( # 分别遍历 test valid train datasets, batch_sizes, is_trains, collate_fns ): + print("dataset:",dataset) + print("bsz:",bsz) + print("is_train:",is_train) + print("collate_fn:",collate_fn) + print("len(dataset[0]):",len(dataset[0])) + if not is_train: + dataset_ratios = None + if isinstance(dataset, list) or isinstance(dataset, tuple): if hasattr(dataset[0], 'sample_ratio') and dataset_ratios is None: dataset_ratios = [d.sample_ratio for d in dataset] + print("dataset_ratios:", dataset_ratios) + loader = MultiIterLoader( loaders=[ _create_loader(d, num_workers, bsz[i], is_train, collate_fn[i]) - for i, d in enumerate(dataset) + for i, d in enumerate(dataset) # dataset相当于list(Dataset) 遍历里面不同类型的dataset [OKVQADataset, COCOVQADataset] ], ratios=dataset_ratios, ) diff --git a/minigpt4/runners/runner_iter.py b/minigpt4/runners/runner_iter.py new file mode 100644 index 0000000..b82488a --- /dev/null +++ b/minigpt4/runners/runner_iter.py @@ -0,0 +1,332 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import datetime +import logging +import os +import time + +import torch +import torch.distributed as dist +import webdataset as wds +from minigpt4.common.dist_utils import download_cached_file, is_main_process, main_process +from minigpt4.common.registry import registry +from minigpt4.common.utils import is_url +from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split +from minigpt4.runners.runner_base import RunnerBase +from torch.utils.data.dataset import ChainDataset + + +@registry.register_runner("runner_iter") +class RunnerIter(RunnerBase): + """ + Run training based on the number of iterations. This is common when + the training dataset size is large. Underhood logic is similar to + epoch-based training by considering every #iters_per_inner_epoch as an + inner epoch. + + In iter-based runner, after every #iters_per_inner_epoch steps, we + + 1) do a validation epoch; + 2) schedule the learning rate; + 3) save the checkpoint. + + We refer every #iters_per_inner_epoch steps as an inner epoch. + """ + + def __init__(self, cfg, task, model, datasets, job_id): + super().__init__(cfg, task, model, datasets, job_id) + + self.start_iters = 0 + + self.max_iters = int(self.config.run_cfg.get("max_iters", -1)) + assert self.max_iters > 0, "max_iters must be greater than 0." + + self.iters_per_inner_epoch = int( + self.config.run_cfg.get("iters_per_inner_epoch", -1) + ) + assert ( + self.iters_per_inner_epoch > 0 + ), "iters_per_inner_epoch must be greater than 0." + + @property + def max_epoch(self): + return int(self.max_iters / self.iters_per_inner_epoch) + + @property + def cur_epoch(self): + try: + return self.train_loader.epoch + except AttributeError: + # pipeline data (e.g. LAION) is streaming, have no concept of epoch + return 0 + + def _progress(self, cur_iters): + return "{}_iters={}".format(self.cur_epoch, cur_iters) + + def train(self): + start_time = time.time() + best_agg_metric = 0 + best_iters = 0 + + self.log_config() + + # resume from checkpoint if specified + if not self.evaluate_only and self.resume_ckpt_path is not None: + self._load_checkpoint(self.resume_ckpt_path) + + for start_iters in range( + self.start_iters, self.max_iters, self.iters_per_inner_epoch + ): + end_iters = start_iters + self.iters_per_inner_epoch + + # training phase + if not self.evaluate_only: + logging.info( + "Start training, max_iters={}, in total {} inner epochs.".format( + self.max_iters, int(self.max_iters / self.iters_per_inner_epoch) + ) + ) + if start_iters == self.start_iters: + self.task.before_training( + model=self.unwrap_dist_model(self.model), + dataset=self.datasets, + ) + train_stats = self.train_iters(self.cur_epoch, start_iters) + self.log_stats(split_name="train", stats=train_stats) + + # evaluation phase + if len(self.valid_splits) > 0: + for split_name in self.valid_splits: + logging.info("Evaluating on {}.".format(split_name)) + + val_log = self.eval_epoch( + split_name=split_name, cur_epoch=self._progress(end_iters) + ) + if val_log is not None: + if is_main_process(): + assert ( + "agg_metrics" in val_log + ), "No agg_metrics found in validation log." + + agg_metrics = val_log["agg_metrics"] + if agg_metrics > best_agg_metric and split_name == "val": + best_iters, best_agg_metric = end_iters, agg_metrics + + self._save_checkpoint(end_iters, is_best=True) + + val_log.update({"best_iters": best_iters}) + self.log_stats(val_log, split_name) + + else: + # if no validation split is provided, we just save the checkpoint at the end of each inner epoch. + if not self.evaluate_only: + self._save_checkpoint(end_iters, is_best=False) + + if self.evaluate_only: + break + dist.barrier() + + # testing phase + self.evaluate(cur_epoch=self.cur_epoch) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logging.info("Training time {}".format(total_time_str)) + + def train_iters(self, epoch, start_iters): + # train by iterations + self.model.train() + + return self.task.train_iters( + epoch=epoch, + start_iters=start_iters, + iters_per_inner_epoch=self.iters_per_inner_epoch, + model=self.model, + data_loader=self.train_loader, + optimizer=self.optimizer, + scaler=self.scaler, + lr_scheduler=self.lr_scheduler, + cuda_enabled=self.cuda_enabled, + log_freq=self.log_freq, + accum_grad_iters=self.accum_grad_iters, + ) + + @main_process + def _save_checkpoint(self, cur_iters, is_best=False): + model_no_ddp = self.unwrap_dist_model(self.model) + param_grad_dic = { + k: v.requires_grad for (k, v) in model_no_ddp.named_parameters() + } + + state_dict = model_no_ddp.state_dict() + for k in list(state_dict.keys()): + if k in param_grad_dic.keys() and not param_grad_dic[k]: + # delete parameters that do not require gradient + del state_dict[k] + + save_obj = { + "model": state_dict, + "optimizer": self.optimizer.state_dict(), + "config": self.config.to_dict(), + "scaler": self.scaler.state_dict() if self.scaler else None, + "iters": cur_iters, + } + save_to = os.path.join( + self.output_dir, + "checkpoint_{}.pth".format("best" if is_best else cur_iters), + ) + logging.info("Saving checkpoint at iters {} to {}.".format(cur_iters, save_to)) + torch.save(save_obj, save_to) + + def _load_checkpoint(self, url_or_filename): + """ + Resume from a checkpoint. + """ + if is_url(url_or_filename): + cached_file = download_cached_file( + url_or_filename, check_hash=False, progress=True + ) + checkpoint = torch.load(cached_file, map_location=self.device) + elif os.path.isfile(url_or_filename): + checkpoint = torch.load(url_or_filename, map_location=self.device) + else: + raise RuntimeError("checkpoint url or path is invalid") + + state_dict = checkpoint["model"] + self.unwrap_dist_model(self.model).load_state_dict(state_dict) + + self.optimizer.load_state_dict(checkpoint["optimizer"]) + if self.scaler and "scaler" in checkpoint: + self.scaler.load_state_dict(checkpoint["scaler"]) + + self.start_iters = checkpoint["iters"] + 1 + logging.info("Resume checkpoint from {}".format(url_or_filename)) + + @property + def dataloaders(self) -> dict: + """ + A property to get and create dataloaders by split just in need. + + If no train_dataset_ratio is provided, concatenate map-style datasets and + chain wds.DataPipe datasets separately. Training set becomes a tuple + (ConcatDataset, ChainDataset), both are optional but at least one of them is + required. The resultant ConcatDataset and ChainDataset will be sampled evenly. + + If train_dataset_ratio is provided, create a MultiIterLoader to sample + each dataset by ratios during training. + + Currently do not support multiple datasets for validation and test. + + Returns: + dict: {split_name: (tuples of) dataloader} + """ + if self._dataloaders is None: + # reoganize datasets by split and concatenate/chain if necessary + dataset_ratios = self.config.run_cfg.get("train_dataset_ratios", None) + + if dataset_ratios is None: + # concatenate map-style datasets and chain wds.DataPipe datasets separately + # training set becomes a tuple (ConcatDataset, ChainDataset), both are + # optional but at least one of them is required. The resultant ConcatDataset + # and ChainDataset will be sampled evenly. + logging.info( + "dataset_ratios not specified, datasets will be concatenated (map-style datasets) or chained (webdataset.DataPipeline)." + ) + + datasets = reorg_datasets_by_split(self.datasets) + self.datasets = concat_datasets(datasets) + else: + # create multi-loader with the provided ratios, without concatenating or chaining + missing_keys = [k for k in dataset_ratios if k not in self.datasets] + if len(missing_keys) > 0: + raise ValueError( + "Datasets with the following split names are not found: {}".format( + missing_keys + ) + ) + + unexpected_keys = [k for k in self.datasets if k not in dataset_ratios] + if len(unexpected_keys) > 0: + raise ValueError( + "Datasets with the following split names are not expected: {}".format( + unexpected_keys + ) + ) + + dataset_ratios = [float(dataset_ratios[k]) for k in self.datasets] + self.datasets = reorg_datasets_by_split(self.datasets) + # to keep the same structure as return value of concat_datasets + self.datasets = { + k: v[0] if len(v) == 1 else v for k, v in datasets.items() + } + + # print dataset statistics after concatenation/chaining + for split_name in self.datasets: + if isinstance(self.datasets[split_name], tuple) or isinstance( + self.datasets[split_name], list + ): + # mixed wds.DataPipeline and torch.utils.data.Dataset + num_records = sum( + [ + len(d) + if not type(d) in [wds.DataPipeline, ChainDataset] + else 0 + for d in self.datasets[split_name] + ] + ) + + else: + try: + # a single map-style dataset + num_records = len(self.datasets[split_name]) + except TypeError: + # a single wds.DataPipeline or ChainDataset + num_records = -1 + logging.info( + "Only a single wds.DataPipeline dataset, no __len__ attribute." + ) + + if num_records >= 0: + logging.info( + "Loaded {} records for {} split from the dataset.".format( + num_records, split_name + ) + ) + + # create dataloaders + split_names = sorted(self.datasets.keys()) + + datasets = [self.datasets[split] for split in split_names] + is_trains = [split in self.train_splits for split in split_names] + + batch_sizes = [ + self.config.run_cfg.batch_size_train + if split == "train" + else self.config.run_cfg.batch_size_eval + for split in split_names + ] + + collate_fns = [] + for dataset in datasets: + if isinstance(dataset, tuple) or isinstance(dataset, list): + collate_fns.append([getattr(d, "collater", None) for d in dataset]) + else: + collate_fns.append(getattr(dataset, "collater", None)) + + dataloaders = self.create_loaders( + datasets=datasets, + num_workers=self.config.run_cfg.num_workers, + batch_sizes=batch_sizes, + is_trains=is_trains, + collate_fns=collate_fns, + dataset_ratios=dataset_ratios, + ) + + self._dataloaders = {k: v for k, v in zip(split_names, dataloaders)} + + return self._dataloaders diff --git a/minigpt4/tasks/__init__.py b/minigpt4/tasks/__init__.py index ab1fb1c..f8f30af 100644 --- a/minigpt4/tasks/__init__.py +++ b/minigpt4/tasks/__init__.py @@ -8,7 +8,7 @@ from minigpt4.common.registry import registry from minigpt4.tasks.base_task import BaseTask from minigpt4.tasks.image_text_pretrain import ImageTextPretrainTask - +from minigpt4.tasks.instruction_tuning import InstructionTask def setup_task(cfg): assert "task" in cfg.run_cfg, "Task name must be provided." @@ -23,4 +23,5 @@ def setup_task(cfg): __all__ = [ "BaseTask", "ImageTextPretrainTask", + "InstructionTask", ] diff --git a/minigpt4/tasks/base_task.py b/minigpt4/tasks/base_task.py index 1cfa46c..7aa60d2 100644 --- a/minigpt4/tasks/base_task.py +++ b/minigpt4/tasks/base_task.py @@ -10,10 +10,11 @@ import os import torch import torch.distributed as dist -from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized +from minigpt4.common.dist_utils import get_rank, get_world_size, is_main_process, is_dist_avail_and_initialized, main_process from minigpt4.common.logger import MetricLogger, SmoothedValue from minigpt4.common.registry import registry from minigpt4.datasets.data_utils import prepare_sample +from torch.utils.tensorboard import SummaryWriter import wandb class BaseTask: @@ -34,6 +35,20 @@ class BaseTask: model_cls = registry.get_model_class(model_config.arch) return model_cls.from_config(model_config) + def build_tensorboard(self, cfg): + """ + Build a tensorboard monitoring the global training process. + """ + setattr(self, 'writer', None) + if is_main_process(): + writer = SummaryWriter(log_dir=cfg.run_cfg.output_dir) + setattr(self, 'writer', writer) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return None + def build_datasets(self, cfg): """ Build a dictionary of datasets, keyed by split 'train', 'valid', 'test'. @@ -57,8 +72,12 @@ class BaseTask: builder = registry.get_builder_class(name)(dataset_config) dataset = builder.build_datasets() + try: + dataset['train'].name = name + except Exception as e: + print(e,'\n No train dataset') + dataset['val'].name = name - dataset['train'].name = name if 'sample_ratio' in dataset_config: dataset['train'].sample_ratio = dataset_config.sample_ratio @@ -219,13 +238,17 @@ class BaseTask: with torch.cuda.amp.autocast(enabled=use_amp): loss = self.train_step(model=model, samples=samples) - + # after_train_step() if use_amp: + # torch.autograd.set_detect_anomaly(True) + # 反向传播时检测是否有异常值,定位code + # with torch.autograd.detect_anomaly(): scaler.scale(loss).backward() else: loss.backward() + # import pdb; pdb.set_trace() # 0107test # update gradients every accum_grad_iters iterations if (i + 1) % accum_grad_iters == 0: if use_amp: @@ -233,13 +256,24 @@ class BaseTask: scaler.update() else: optimizer.step() + + # import pdb; pdb.set_trace()# 0107test + optimizer.zero_grad() - # if self.cfg.wandb_log: - if self.cfg.run_cfg.wandb_log: - wandb.log({"epoch": inner_epoch, "loss": loss}) + if self.cfg.run_cfg.wandb_log and i%100==0: + wandb.log({"epoch": inner_epoch, "loss": loss.item()}) + + if self.cfg.run_cfg.wandb_log and i%10==0: + source = samples['source'][0] + wandb.log({f"{source}_loss": loss.item()}) + + metric_logger.update(loss=loss.item()) metric_logger.update(lr=optimizer.param_groups[0]["lr"]) - + + if (i + 1) % log_freq == 0: + self.update_writer(metric_logger, i+1) + # after train_epoch() # gather the stats from all processes metric_logger.synchronize_between_processes() @@ -288,3 +322,9 @@ class BaseTask: print("result file saved to %s" % final_result_file) return final_result_file + + @main_process + def update_writer(self, metric_logger, iter_num): + for name, meter in metric_logger.meters.items(): + # meter: Instance of class SmoothedValue + self.writer.add_scalar(name, float(str(meter)), iter_num) diff --git a/minigpt4/tasks/caption_utils.py b/minigpt4/tasks/caption_utils.py new file mode 100644 index 0000000..d3a514a --- /dev/null +++ b/minigpt4/tasks/caption_utils.py @@ -0,0 +1,67 @@ + +from collections import defaultdict +from pycocoevalcap.eval import COCOEvalCap +class COCO_Annotation: + def __init__(self, annotation_file): + self.coco_cn_file = annotation_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = defaultdict(list) + with open(self.coco_cn_file, "r", encoding="UTF-8") as fin: + for line in fin: + line = line.strip() + temp = eval(line) + annotations = temp['annotations'] + for ann in annotations: + image_id = str(ann['image_id']).zfill(6) + imgToAnns[image_id].append({'image_id':image_id,'caption':ann['caption'],'image': ann['image_id']}) + return imgToAnns + + def getImgIds(self): + return self.imgToAnns.keys() + +class COCO_Result: + def __init__(self,result_file): + self.coco_cn_file = result_file + self.imgToAnns = self.build_imgToAnns() + + def build_imgToAnns(self): + imgToAnns = dict() + data = json.load(open(self.coco_cn_file, "r")) + for d in data: + tmp = { + 'image_id':d['question_id'][-6:], + 'caption':d['answer'] + } + imgToAnns[d['question_id'][-6:]] = [tmp] + return imgToAnns + +def coco_caption_eval(coco_gt_root, results_file, split_name): + files = { + "val":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_val_gt.json", + "test":"/mnt/pfs-guan-ssai/nlu/wanghanzi/data/COCO_Cap/coco_karpathy_test_gt.json" + } + + # create coco object and coco_result object + annotation_file = files[split_name] + coco = COCO_Annotation(annotation_file) + coco_result = COCO_Result(results_file) + + # create coco_eval object by taking coco and coco_result + coco_eval = COCOEvalCap(coco, coco_result) + + # evaluate on a subset of images by setting + # coco_eval.params['image_id'] = coco_result.getImgIds() + # please remove this line when evaluating the full validation set + # coco_eval.params['image_id'] = coco_result.getImgIds() + + # evaluate results + # SPICE will take a few minutes the first time, but speeds up due to caching + coco_eval.evaluate() + + # print output evaluation scores + for metric, score in coco_eval.eval.items(): + print(f"{metric}: {score:.3f}") + + return coco_eval \ No newline at end of file diff --git a/minigpt4/tasks/instruction_tuning.py b/minigpt4/tasks/instruction_tuning.py new file mode 100644 index 0000000..7fc7e9d --- /dev/null +++ b/minigpt4/tasks/instruction_tuning.py @@ -0,0 +1,316 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import logging +import json +import os +import torch.distributed as dist +from collections import defaultdict +from minigpt4.common.registry import registry +from minigpt4.tasks.base_task import BaseTask +import minigpt4.common.dist_utils as dist_utils +from minigpt4.common.logger import MetricLogger +from minigpt4.datasets.data_utils import prepare_sample +from minigpt4.common.dist_utils import is_dist_avail_and_initialized +from minigpt4.common.vqa_tools.vqa import VQA +from minigpt4.common.vqa_tools.vqa_eval import VQAEval +from minigpt4.common.caption_tools.caption_utils import coco_caption_eval, textcaps_caption_eval + + +@registry.register_task("instruction_tuning") +class InstructionTask(BaseTask): + def __init__( + self, + num_beams, + max_len, + min_len, + evaluate, + num_ans_candidates, + inference_method="rank", + prompt="", + ): + super().__init__() + + self.num_beams = num_beams + self.max_len = max_len + self.min_len = min_len + + self.evaluate = evaluate + self.inference_method = inference_method + self.num_ans_candidates = num_ans_candidates + self.prompt = prompt + + self.answer_list = None + + self.ques_files = defaultdict(dict) + self.anno_files = defaultdict(dict) + + @classmethod + def setup_task(cls, cfg): + run_cfg = cfg.run_cfg + + num_beams = run_cfg.get("num_beams", 3) + max_len = run_cfg.get("max_len", 30) + min_len = run_cfg.get("min_len", 1) + + evaluate = run_cfg.get("evaluate", False) + + inference_method = run_cfg.get("inference_method", "rank") + num_ans_candidates = run_cfg.get("num_ans_candidates", 128) + prompt = run_cfg.get("prompt", "") + + return cls( + num_beams=num_beams, + max_len=max_len, + min_len=min_len, + evaluate=evaluate, + num_ans_candidates=num_ans_candidates, + inference_method=inference_method, + prompt=prompt, + ) + + def build_datasets(self, cfg): + datasets = super().build_datasets(cfg) + + # get question file, annotation file and anwser list in COCO format + for dataset in datasets.values(): + for split in dataset: + source = dataset[split].source + if ( + hasattr(dataset[split], "coco_fmt_qust_file") + and dataset[split].coco_fmt_qust_file is not None + ): + self.ques_files[split][source] = dataset[split].coco_fmt_qust_file + self.anno_files[split][source] = dataset[split].coco_fmt_anno_file + + # try: + # self.answer_list = dataset[split].answer_list + # except AttributeError: + # # if answer_list is not provided, then set it to None + # pass + + if len(self.ques_files) > 0: + assert len(self.ques_files) == len( + self.anno_files + ), "Only support one split for evaluation." + + return datasets + + def valid_step(self, model, samples): + answers = model.predict_answers( + samples=samples, + answer_list=self.answer_list, + inference_method=self.inference_method, + num_beams=self.num_beams, + max_len=self.max_len, + min_len=self.min_len, + num_ans_candidates=self.num_ans_candidates, + prompt=self.prompt, + ) + pred_qa_pairs = [] + + text_inputs = samples["text_input"] + + sources = samples["source"] + source = samples["source"][0] + + if source in ['vqav2','okvqa','gqa']: + sample_ids = [int(sample_id.item()) for sample_id in samples["question_id"]] + elif source in ['aokvqa']: + sample_ids = [sample_id for sample_id in samples["question_id"]] + elif source in ['coco_cap', 'text_cap', 'text_vqa']: + sample_ids = samples["image_id"] + + # For GQA + full_answers = samples.get("fullAnswer", ["" for i in range(len(sample_ids))]) + + # For AOKVQA & GQA & TextVQA + gt_answers = samples.get("gt_answers", ["" for i in range(len(sample_ids))]) + + # For AOKVQA + choices = samples.get("choices", ["" for i in range(len(sample_ids))]) + + for answer, sample_id, text_input, full_answer, gt_answer, choice, source in zip(answers, sample_ids, text_inputs, full_answers, gt_answers, choices, sources): + pred_qa_pairs.append({ + "question_id": sample_id, + "question": text_input, + "full_answer": full_answer, + "answer": answer, + "gt_ans": gt_answer, + "choice": choice, + "source": source}) + return pred_qa_pairs + + def evaluation(self, model, data_loader, cuda_enabled=True): + metric_logger = MetricLogger(delimiter=" ") + header = "Evaluation" + # TODO make it configurable + print_freq = 10 + + total_results = list() + for sub_data_loader in data_loader.loaders: + results = [] + for samples in metric_logger.log_every(sub_data_loader, print_freq, header): + + samples = prepare_sample(samples, cuda_enabled=cuda_enabled) + eval_output = self.valid_step(model=model, samples=samples) + + results.extend(eval_output) + + total_results.append(results) + + if is_dist_avail_and_initialized(): + dist.barrier() + + return total_results + + + def after_evaluation(self, val_result, split_name, **kwargs): + + final_metrics = dict() + for i in range(len(val_result)): + source = val_result[i][0]["source"] + result_file = self.save_result( + val_result[i], + result_dir=registry.get_path("result_dir"), + filename=f"{split_name}_vqa_result_{source}", + remove_duplicate="question_id", + ) + + if source in ['vqav2','okvqa']: + try: + metrics = self._report_metrics_coco_vqa(result_file=result_file, split=split_name, source=source) + except Exception as e: + metrics = None + print(f"Report Metrics {source} Error: {e}") + elif source in ['gqa','aokvqa','text_vqa']: + try: + metrics = self._report_metrics_gqa_aokvqa_textvqa(result_file=result_file, source=source) + except Exception as e: + metrics = None + print(f"Report Metrics {source} Error: {e}") + elif source in ['coco_cap','text_cap']: + try: + metrics = self._report_metrics_caption(result_file=result_file, split_name=split_name, source=source) + except Exception as e: + metrics = None + print(f"Report Metrics {source} Error: {e}") + else: + metrics = None + final_metrics[source] = metrics + + try: + agg_metrics_lst = [v["agg_metrics"] for k,v in final_metrics.items()] + final_metrics["agg_metrics"] = sum(agg_metrics_lst)/len(agg_metrics_lst) + except Exception as e: + print("Calculate agg metrics error... ", e) + final_metrics = None + + return final_metrics + + @dist_utils.main_process + def _report_metrics_coco_vqa(self, result_file, split, source='vqav2'): + """ + Use official VQA evaluation script to report metrics. + """ + metrics = {} + + if split in self.ques_files and split in self.anno_files: + vqa = VQA(self.anno_files[split][source], self.ques_files[split][source]) + vqa_result = vqa.loadRes( + resFile=result_file, quesFile=self.ques_files[split][source] + ) + + # create vqaEval object by taking vqa and vqaRes + # n is precision of accuracy (number of places after decimal), default is 2 + vqa_scorer = VQAEval(vqa, vqa_result, n=2) + logging.info("Start VQA evaluation.") + vqa_scorer.evaluate() + + # print accuracies + overall_acc = vqa_scorer.accuracy["overall"] + metrics["agg_metrics"] = overall_acc + + logging.info("Overall Accuracy is: %.02f\n" % overall_acc) + logging.info("Per Answer Type Accuracy is the following:") + + for ans_type in vqa_scorer.accuracy["perAnswerType"]: + logging.info( + "%s : %.02f" + % (ans_type, vqa_scorer.accuracy["perAnswerType"][ans_type]) + ) + metrics[ans_type] = vqa_scorer.accuracy["perAnswerType"][ans_type] + + with open( + os.path.join(registry.get_path("output_dir"), f"evaluate_{source}.txt"), "a" + ) as f: + f.write(json.dumps(metrics) + "\n") + + return metrics + + + @dist_utils.main_process + def _report_metrics_gqa_aokvqa_textvqa(self, result_file, source='gqa'): + """ + Validation of GQA & aokvqa + source = 'gqa' / 'aokvqa' + """ + # measuring accuracy compared to answer + results = json.load(open(result_file, "r")) + acc = [] + vqa_tool = VQAEval() + + for res in results: + + gt_ans = res["gt_ans"] + pred = res["answer"] + + pred = vqa_tool.processPunctuation(pred) + pred = vqa_tool.processDigitArticle(pred) + + # vqa_acc = 1 if pred == gt_ans else 0 + vqa_acc = 1 if gt_ans in pred else 0 + + acc.append(vqa_acc) + + accuracy = sum(acc) / len(acc) * 100 + metrics = {"agg_metrics": accuracy, "acc": accuracy} + + with open( + os.path.join(registry.get_path("output_dir"), f"evaluate_{source}.txt"), "a" + ) as f: + f.write(json.dumps(metrics) + "\n") + + logging.info(metrics) + + return metrics + + + @dist_utils.main_process + def _report_metrics_caption(self, result_file, split_name, source='coco_cap'): + """ + Use official COCO Cap evaluation script to report metrics. + """ + if source == 'coco_cap': + coco_gt_root = os.path.join(registry.get_path("cache_root"), "coco_gt") + eval = coco_caption_eval(coco_gt_root, result_file, split_name) + elif source == 'text_cap': + annotaion_file = "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextCap/TextCaps_0.1_val.json" + eval = textcaps_caption_eval(annotaion_file, result_file) + + agg_metrics = eval.eval["CIDEr"] + eval.eval["Bleu_4"] + log_stats = {split_name: {k: v for k, v in eval.eval.items()}} + + with open( + os.path.join(registry.get_path("output_dir"), f"evaluate_{source}.txt"), "a" + ) as f: + f.write(json.dumps(log_stats) + "\n") + + result = {k: v for k, v in eval.eval.items()} + result["agg_metrics"] = agg_metrics + + return result diff --git a/process_log/load_log.py b/process_log/load_log.py new file mode 100644 index 0000000..4ad4599 --- /dev/null +++ b/process_log/load_log.py @@ -0,0 +1,42 @@ +import json +import os +import pandas as pd +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cocap_tcap_raw_QformerMoE_Post_Route_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_15epo_0314/20240314230/" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_route/mix_coco_gqa_ao_cocap_tcap_raw_Qformer_base_lr5e5_10epo_0316/20240316192" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_10epo_0317/20240317165/" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_12epo_0317/20240317165/" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_10epo_0318/20240318145/" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_tcap_raw_Qformer_base_lr5e5_10epo_0318/20240318113" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_4ex4b_2loss_005_top6layer_textinqf_10epo_0319/20240319110" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_4ex4b_2loss_005_top6layer_textinqf_10epo_0319/20240319110/" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr6e5_3ex3b_2loss_005_top6layer_textinqf_10epo_0319/20240319105" +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_tcap_raw_Qformer_base_lr5e5_10epo_0320_instruct/20240320230/" +path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr3e5_3ex3b_2loss_005_top6layer_textinqf_10epo_0326/20240326134" +modes = ['gqa','okvqa','vqav2','aokvqa'] +file_name = dict() +data = dict() +for mode in modes: + file_name[mode] = os.path.join(path, f"evaluate_{mode}.txt") + tmp = list() + with open(file_name[mode], "r") as f: + for line in f: + tmp.append(json.loads(line)) + accs = [dd['agg_metrics'] for dd in tmp] + data[mode] = accs +print(data) +df = pd.DataFrame(data) +df.to_csv("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/process_log/results.csv",index=False) +print(df) +modes = ['coco_cap','text_cap'] +for mode in modes: + file_name[mode] = os.path.join(path, f"evaluate_{mode}.txt") + tmp = list() + with open(file_name[mode], "r") as f: + for line in f: + tmp.append(json.loads(line)["val"]) + df1 = pd.DataFrame(tmp) + df1.to_csv("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/process_log/results_{}.csv".format(mode),index=False) + print("\n",df1) + + + diff --git a/process_log/results.csv b/process_log/results.csv new file mode 100644 index 0000000..7b7729c --- /dev/null +++ b/process_log/results.csv @@ -0,0 +1,11 @@ +gqa,okvqa,vqav2,aokvqa +56.20130386388933,54.9,73.99,70.65502183406113 +59.98568929877563,56.83,75.05,71.87772925764192 +60.81252981396088,56.97,76.33,72.31441048034935 +61.217999681984416,58.72,76.9,72.13973799126637 +62.63316902528224,57.77,77.9,73.53711790393012 +62.63316902528224,58.7,78.62,72.31441048034935 +63.0863412307203,58.44,79.17,73.01310043668123 +63.99268564159644,58.3,79.67,73.1004366812227 +63.85752901892193,58.26,79.89,73.1877729257642 +64.05628875814915,58.36,79.96,73.27510917030567 diff --git a/process_log/results_coco_cap.csv b/process_log/results_coco_cap.csv new file mode 100644 index 0000000..ce73625 --- /dev/null +++ b/process_log/results_coco_cap.csv @@ -0,0 +1,11 @@ +Bleu_1,Bleu_2,Bleu_3,Bleu_4,METEOR,ROUGE_L,CIDEr +0.823223287530701,0.6717046486084919,0.5289578677489075,0.40968911321073265,0.29959733857127835,0.6000811689013733,1.3453305422681376 +0.8223586805364383,0.6761940235369243,0.5372021051535613,0.42002107530341415,0.3007070731886293,0.6033562057362387,1.3625435411445188 +0.8268533010152048,0.6768505208938964,0.5354603117460063,0.41692252223016474,0.3025309435581166,0.6047665704453451,1.372728338529115 +0.8330185365456568,0.6839215475169012,0.543703824076164,0.42514888177425736,0.3056707793309544,0.6059406486352524,1.3861562304957278 +0.82851059467086,0.6809176651975785,0.540587141996957,0.4231733852045105,0.30531255145802344,0.6074083433970782,1.3834468097216504 +0.826759683808046,0.6786303228237849,0.5393123491853253,0.4218864249035804,0.30435744091261596,0.6061024359922585,1.373555774390539 +0.8276490281491424,0.6807227995567375,0.5428624362839755,0.4269084803683001,0.30608939915890326,0.6078544738015538,1.3817698541866794 +0.8294752997477592,0.684006846582908,0.5462002712862497,0.43036414820903174,0.3072188868161371,0.6112755876824707,1.3946012917260822 +0.8295967394767231,0.6833360908793101,0.5455277195646591,0.4292003829675252,0.3075074477217475,0.6103970688371728,1.392473266614681 +0.8311746625208477,0.685154908394367,0.5473453025692251,0.4310962369321893,0.3077283192093123,0.6112527230048272,1.3960092087389069 diff --git a/process_log/results_text_cap.csv b/process_log/results_text_cap.csv new file mode 100644 index 0000000..b825d8c --- /dev/null +++ b/process_log/results_text_cap.csv @@ -0,0 +1,11 @@ +Bleu_1,Bleu_2,Bleu_3,Bleu_4,METEOR,ROUGE_L,CIDEr +0.7354185510459915,0.5532517544390539,0.40861793000486973,0.29983800193167554,0.24116643313333005,0.49331369213983123,1.0410874398111594 +0.7337559429476805,0.5532662736631931,0.4093585176779233,0.2995878219607148,0.24114315075166085,0.49493696747386196,1.0409495699510138 +0.7320402508385379,0.5518304887400776,0.4110219411525107,0.3038313373404956,0.24291044377257376,0.496058725623593,1.0575894770967331 +0.735500933603824,0.557825793289975,0.41590254371174706,0.30692236406930773,0.2383527664105175,0.49362765900900546,1.044158754983385 +0.7398393194706776,0.5620893308030489,0.41820732364444946,0.3078230351155729,0.24483177062749348,0.49847300976639547,1.0688814357108025 +0.7371954616083516,0.5585083902477329,0.4158548739958792,0.30707857737975475,0.244730703390385,0.49852757085109406,1.0729374604921766 +0.743162583518909,0.5634215669347339,0.42011923174763044,0.311173553245093,0.2441994031592775,0.49809622365158585,1.0774442793527572 +0.7394105997457735,0.5596190520647151,0.4174127065616514,0.3086643930692146,0.24516092911412224,0.4984150187439041,1.0745130576356021 +0.7401475669028302,0.5612229049005432,0.4184684096329846,0.3101423931360949,0.2459814560570504,0.49998501882495094,1.0809003006778692 +0.7420413230468877,0.5634982579203833,0.4203003411022801,0.3115589386663839,0.245912807863165,0.4989393568537978,1.0805529900845852 diff --git a/prompts/alignment.txt b/prompts/alignment.txt deleted file mode 100644 index 38ae75a..0000000 --- a/prompts/alignment.txt +++ /dev/null @@ -1,4 +0,0 @@ - Describe this image in detail. - Take a look at this image and describe what you notice. - Please provide a detailed description of the picture. - Could you describe the contents of this image for me? \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0d7634c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,49 @@ +contexttimer +decord +diffusers==0.14.0 +einops>=0.4.1 +fairscale==0.4.4 +ftfy +iopath +ipython +omegaconf==2.3.0 +opencv-python-headless==4.5.5.64 +opencv-python==4.7.0.72 +opendatasets +webdataset==0.2.48 +packaging +pandas +plotly +pre-commit +pycocoevalcap +pycocotools +python-magic +scikit-image +sentencepiece +spacy +streamlit +timm==0.6.13 +torch==2.0.0 +torchaudio +torchvision +huggingface-hub==0.18.0 +matplotlib==3.7.0 +psutil==5.9.4 +pyyaml==6.0 +regex==2022.10.31 +tokenizers==0.13.2 +tqdm==4.64.1 +transformers==4.30.0 +peft==0.2.0 +sentence-transformers +gradio==3.47.1 +accelerate==0.20.3 +bitsandbytes==0.37.0 +wandb +wheel +visualizer +tensorboard +kmeans_pytorch +visual_genome +gpustat +torchviz \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..6a67455 --- /dev/null +++ b/setup.py @@ -0,0 +1,36 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +from setuptools import setup, find_namespace_packages +import platform + +DEPENDENCY_LINKS = [] +if platform.system() == "Windows": + DEPENDENCY_LINKS.append("https://download.pytorch.org/whl/torch_stable.html") + + +def fetch_requirements(filename): + with open(filename) as f: + return [ln.strip() for ln in f.read().split("\n")] + + +setup( + name="PromptMoE", + version="1.0.1", + author="Hanzi Wang", + description="PromptMoE & QformerMoE Based on LAVIS", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + keywords="Vision-Language, Multimodal, Image Captioning, Generative AI, Deep Learning, Library, PyTorch", + license="3-Clause BSD", + packages=find_namespace_packages(include="lavis.*"), + install_requires=fetch_requirements("requirements.txt"), + python_requires=">=3.7.0", + include_package_data=True, + dependency_links=DEPENDENCY_LINKS, + zip_safe=False, +) \ No newline at end of file diff --git a/test.txt b/test.txt index e69de29..516c092 100644 --- a/test.txt +++ b/test.txt @@ -0,0 +1,360 @@ +tmp_name = [name for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)] + +tmp = [p for name, p in model.named_parameters() if (p.requires_grad and '10.expert' in name)] + +tensor([[-1.4032e-02, 3.7242e-03, 8.4997e-03, -3.4016e-03, -6.4855e-03, + 4.3595e-02, 3.4423e-02, -8.6274e-03, -1.9702e-02, 9.1813e-03, + 1.1643e-02, 2.3939e-02, -2.0908e-02, 3.4555e-03, 9.1636e-03, + 1.5413e-02, 2.4148e-02, -1.0880e-03, 1.1193e-02, -1.3591e-02, + 9.3484e-03, 1.5999e-02, -9.6086e-04, 3.8322e-02, -8.0687e-03, + -1.4056e-02, 3.9486e-02, 3.5167e-02, -9.3226e-03, -1.0493e-02, + -2.5795e-02, -9.7541e-03, 4.4437e-03, 7.7226e-03, 7.5210e-03, + -1.3526e-02, -5.0316e-03, -1.1149e-02, 6.0583e-03, 2.0564e-02, + -6.4477e-03, 1.4170e-02, -3.7847e-02, 1.1780e-02, 1.3321e-02, + -8.2501e-03, -1.0298e-02, 1.4805e-02, -1.2432e-02, -1.9159e-02, + -5.7095e-04, -3.8618e-02, -2.4230e-02, -1.4991e-03, -1.4114e-02, + -1.5365e-02, 1.5640e-02, -4.8623e-02, -2.9991e-02, 1.2796e-02, + -4.9917e-03, 2.3846e-03, 7.7368e-03, 1.2913e-02, 1.5300e-02, + 8.5125e-03, 1.1582e-02, 8.1161e-03, 4.2259e-03, 7.6109e-03, + -2.0747e-02, -3.5099e-03, 2.2282e-02, 5.0493e-02, -1.7849e-02, + -3.7106e-02, -1.4944e-02, -1.4582e-02, -2.2458e-02, -4.6173e-05, + -8.1270e-03, 1.9037e-02, -2.0086e-02, 3.0980e-03, -9.3947e-03, + 1.3054e-02, 2.3203e-02, -9.9304e-03, -2.6038e-02, 1.8679e-02, + 9.2081e-03, -2.1770e-02, -1.6568e-03, -3.6503e-02, 2.0054e-02, + 1.2886e-02, -1.8021e-02, 3.4457e-02, -1.3704e-02, -6.1498e-03, + -8.6769e-03, 1.5024e-02, -1.3875e-02, 1.7416e-02, -1.1178e-02, + -2.4088e-02, -1.7802e-02, 3.3326e-02, -1.1216e-02, -8.6330e-03, + -5.5359e-03, -1.1939e-02, -1.7777e-02, -2.8666e-02, -3.8280e-02, + 4.2682e-02, 1.4946e-02, 9.6427e-03, 8.2754e-03, -1.0516e-03, + 2.9560e-02, 2.4552e-03, -4.8354e-02, 1.5568e-02, 2.5881e-02, + -1.7354e-02, -3.1232e-02, 2.3683e-02, -2.3239e-02, 2.2966e-02, + 5.6349e-03, -8.7595e-03, 1.5173e-02, 2.7660e-02, -4.3304e-03, + -2.5330e-02, -2.1795e-02, 1.6856e-02, -2.1587e-04, 2.3707e-02, + -2.3667e-02, 3.5378e-02, -7.9245e-03, 7.1029e-04, -3.2800e-02, + -1.5402e-03, -8.5634e-03, -1.1356e-02, -2.1935e-03, -1.8854e-02, + -1.9705e-03, -3.8333e-02, 2.9131e-02, -4.4470e-02, -2.0893e-03, + 1.2937e-02, -1.7116e-02, 2.7778e-02, 1.0311e-02, -6.4017e-03, + 3.7647e-02, -1.9953e-02, -5.3925e-03, 3.6978e-02, -1.5534e-02, + 1.2241e-02, 1.3597e-02, 2.0703e-03, 2.4213e-03, 9.2604e-03, + 6.6108e-03, -5.8213e-03, 9.8167e-03, -9.8300e-04, -1.0236e-02, + 2.9581e-02, 1.0987e-02, 2.0046e-02, -1.0500e-02, -3.2221e-03, + -2.6303e-02, 1.3688e-02, -2.2529e-02, -5.7654e-03, 1.1784e-02, + 1.6221e-02, 2.8743e-02, 5.7565e-03, 1.8129e-02, 1.5140e-02, + -1.1748e-02, -1.7528e-02, 4.7977e-02, 1.5568e-02, 4.7030e-04, + 3.2757e-03, 1.6631e-02, 1.9986e-02, -7.3463e-03, 1.1435e-02, + -1.4739e-02, -3.2959e-03, -2.8770e-03, 2.9260e-02, 1.7007e-02, + 3.0611e-02, 2.2102e-02, -3.3819e-02, -1.9403e-02, 2.5524e-02, + 3.0738e-02, -1.9951e-02, -1.4553e-02, -1.5796e-02, -2.3143e-02, + -2.8826e-02, 2.4739e-02, -5.8602e-03, 4.1871e-02, 5.0821e-04, + 3.3493e-02, 2.3524e-02, 2.3191e-02, 9.0416e-03, 3.3262e-02, + -1.6805e-02, 1.1545e-02, -1.7195e-02, -3.8696e-02, -8.4358e-04, + -8.1605e-03, 3.1372e-03, 1.0726e-03, 1.0865e-03, 1.0760e-02, + -5.2421e-03, 1.3039e-02, 3.6873e-04, 1.0464e-02, -1.1544e-02, + -2.2775e-02, -4.8439e-02, -1.0711e-02, 4.4236e-03, 2.0351e-02, + 2.4479e-03, -1.9968e-02, -2.2941e-02, -2.0486e-02, -1.9528e-02, + -2.3176e-02, -3.2731e-03, 1.1789e-02, 2.0921e-02, 2.9809e-03, + -8.8507e-03, -3.5716e-02, 8.8418e-03, 5.3665e-05, -1.1288e-02, + -7.5571e-03, 2.1053e-02, -3.7381e-03, -4.0165e-03, -2.2628e-03, + 3.7554e-03, -1.6597e-02, 7.6946e-03, -3.2689e-02, 2.2016e-02, + 5.5122e-03, 4.5455e-02, 6.7586e-03, 1.5714e-02, 5.2125e-03, + 3.9596e-03, 1.8134e-02, 1.5834e-03, -1.6239e-02, -1.3889e-02, + -2.3522e-02, 1.4738e-02, 5.5867e-03, -7.0727e-03, -2.8140e-03, + 1.6849e-02, -3.1327e-02, -3.2443e-02, 4.7851e-03, 1.2980e-02, + -2.0014e-04, -9.9475e-03, 8.0657e-03, 1.9468e-02, -1.5774e-02, + 1.7017e-02, -8.7196e-03, -4.0681e-03, -6.9754e-03, -2.2007e-02, + -6.6217e-03, -1.8219e-02, 4.2186e-02, -5.6621e-03, -9.3449e-03, + -1.1662e-02, 2.8700e-02, -9.0654e-03, 3.1569e-02, -2.9825e-03, + -3.8198e-02, -5.2723e-02, -4.8325e-02, -2.7871e-03, 5.1127e-03, + 1.4511e-02, 9.3245e-03, -2.3339e-02, -8.6658e-03, 1.5276e-02, + -1.5823e-02, -3.4476e-03, 1.4601e-02, 6.3504e-03, -1.4307e-02, + 2.2817e-02, 2.1998e-02, 1.7330e-02, -2.4448e-02, 4.0178e-03, + 3.2280e-03, -1.2721e-02, 1.9661e-02, 7.5263e-03, 2.0245e-02, + 4.5525e-02, -1.5658e-02, -4.0676e-02, 9.3160e-03, 1.1920e-02, + -1.9317e-02, 1.7848e-02, -5.8601e-03, 1.1786e-03, 8.3864e-03, + -1.8341e-02, 2.5985e-02, -1.1387e-02, -1.5069e-02, -2.8097e-02, + 2.4966e-02, 1.4790e-02, 2.0424e-02, -1.3062e-02, 3.1314e-02, + 1.7811e-02, 7.2393e-03, 1.4413e-02, -1.2746e-02, 3.1039e-02, + -1.1697e-02, -1.4826e-02, -8.8397e-03, 1.5157e-02, -1.5855e-02, + -1.8157e-03, 1.3024e-02, -1.8902e-03, 2.5212e-02, -3.4886e-02, + 4.3029e-02, -4.0842e-02, 1.1362e-02, -1.4654e-02, -1.3337e-02, + -3.1832e-02, 3.6222e-03, 8.2804e-03, -1.4269e-02, 2.8399e-03, + -1.2008e-02, 2.4685e-02, -4.3070e-03, 6.3163e-03, -1.3517e-02, + -1.3807e-02, 2.4617e-02, 2.1453e-02, 4.7332e-03, 9.1636e-03, + -1.2881e-02, 1.9077e-02, 1.7571e-04, -5.2817e-03, -2.8821e-02, + 5.8223e-03, -3.0979e-02, 2.4609e-02, 3.6666e-02, -1.0950e-02, + 2.0421e-02, -2.6378e-03, 3.1825e-02, -9.6689e-04, -2.8398e-02, + -2.7513e-02, 1.6946e-02, -2.4110e-02, -1.3575e-02, -1.3443e-02, + 8.4217e-03, 2.6754e-02, -2.3309e-03, -2.5086e-02, 1.1844e-02, + 1.4152e-02, 1.2989e-02, -5.7336e-03, 4.7391e-03, 3.4106e-02, + 1.0142e-02, -1.8029e-02, -1.5410e-04, -1.3548e-02, 9.1742e-03, + -3.0150e-02, 1.5666e-02, 4.3049e-03, 1.6273e-02, 2.0672e-02, + -1.2458e-02, 4.5496e-02, 3.2131e-02, -3.0967e-03, 2.1891e-02, + 2.5524e-02, -1.1998e-02, -1.8866e-03, -1.0945e-02, 5.9930e-03, + -8.4233e-03, -8.9095e-03, -1.8261e-02, 1.9308e-02, -1.9728e-02, + -1.4216e-02, 1.4952e-02, 5.7355e-04, -2.4753e-02, -1.0948e-02, + 1.0965e-02, 1.3607e-03, 3.4974e-02, -4.1396e-03, 2.5519e-02, + 1.0364e-02, -1.5851e-02, -4.9224e-03, 1.0903e-02, -1.0523e-04, + 3.1355e-02, -1.5105e-02, 5.6972e-03, -8.4078e-03, -1.9868e-02, + 1.7186e-03, 2.9396e-02, -4.1439e-02, 1.4124e-02, -3.7745e-03, + 3.3007e-02, 8.0368e-04, 8.5574e-03, 1.7269e-02, 1.1955e-02, + 8.8142e-03, -1.3123e-02, 1.6817e-02, -1.5456e-02, -1.3868e-02, + 2.4139e-02, -9.1566e-03, -1.8477e-02, -4.7972e-03, -6.8459e-03, + 1.6818e-02, 3.1645e-03, -3.0901e-02, -5.6036e-03, -1.4758e-02, + 2.0473e-02, -7.5411e-05, 2.0673e-03, -7.0061e-03, 9.5544e-03, + 1.6600e-02, -1.7315e-02, -2.0168e-02, -5.3008e-03, 2.0206e-02, + 2.4209e-03, 2.1205e-02, -8.9188e-03, -4.1350e-04, -1.0638e-02, + 1.3705e-02, 9.5925e-05, 3.8877e-02, 3.2884e-02, -2.7730e-03, + 1.0052e-02, 1.9311e-02, 1.1341e-02, -1.2988e-02, -1.7157e-02, + 3.2095e-02, -1.8493e-02, -9.2551e-03, -2.6509e-03, -1.1130e-02, + 1.6581e-02, 1.0216e-02, 1.3687e-02, 1.1860e-02, -3.0462e-03, + -1.2082e-02, 2.8502e-03, -1.2620e-02, 8.8330e-03, 1.7357e-02, + 1.8383e-02, -2.3130e-02, -3.2654e-02, 1.2853e-02, -7.8144e-03, + 1.9418e-04, 3.8635e-03, 4.9333e-02, 1.9350e-02, -2.0643e-02, + 8.4650e-04, 5.0242e-02, 1.6576e-02, -8.9166e-03, -5.8805e-03, + -4.1484e-02, 9.3217e-03, -1.1292e-02, -8.7944e-03, -3.3190e-03, + 5.7970e-03, -6.6078e-03, -2.4052e-02, -5.6347e-03, 8.4539e-03, + 1.9250e-02, 7.9559e-03, -3.0055e-03, -3.0398e-04, 2.7007e-02, + 3.1046e-03, 1.8332e-02, 5.5470e-03, 6.6815e-03, 1.1466e-02, + 1.9738e-02, 1.2176e-02, -2.0220e-02, 8.6928e-03, 4.2451e-03, + 4.4517e-03, -5.1524e-03, 1.0805e-02, -2.1935e-02, -1.7575e-02, + -1.2529e-02, -2.2191e-02, -1.0854e-02, -9.4462e-03, -2.9102e-02, + 2.6752e-02, -1.0919e-02, -2.6724e-02, 8.3694e-04, 2.9832e-03, + 1.4416e-02, -2.9906e-02, 2.3556e-02, -6.6624e-03, 2.6671e-02, + -3.6474e-02, 1.7237e-02, -2.5176e-02, 6.5560e-03, -2.6062e-02, + -2.3838e-02, 3.0629e-02, 2.5382e-02, 1.2302e-02, -1.1665e-02, + -7.0603e-03, 1.9931e-02, 2.3401e-02, -2.6047e-03, -2.7728e-02, + -1.7212e-02, 2.3061e-02, -2.5961e-02, 3.9764e-04, -2.9022e-02, + -1.5546e-03, 4.5519e-03, 2.3589e-02, -3.5005e-02, 4.1890e-03, + -1.5586e-02, 1.2389e-02, -2.1045e-02, 1.6377e-03, -1.1328e-02, + 1.0195e-02, 6.4322e-03, -3.8431e-02, 2.2918e-02, -4.0123e-03, + 6.6680e-02, 4.1135e-02, -1.5031e-02, -1.3550e-02, -2.2566e-02, + -2.3622e-03, -2.9323e-02, 2.1756e-02, 1.8399e-03, -4.2460e-03, + -1.5128e-03, -2.4731e-02, 1.8663e-02, 1.3469e-02, -1.3897e-02, + 2.6399e-02, -8.0740e-03, -4.6753e-03, 3.9857e-02, 6.2364e-03, + 2.2371e-03, 2.1501e-03, 5.9443e-02, 1.3574e-02, 7.6483e-03, + -6.2290e-03, 1.4324e-02, 1.2572e-02, 2.7331e-02, -6.0165e-03, + -5.9154e-03, -3.7000e-02, 1.4001e-02, 1.2869e-02, -2.8854e-02, + -9.4147e-03, 8.3965e-03, -1.4530e-03, -7.4215e-03, 9.0369e-03, + -2.4612e-02, 2.0625e-02, 2.2329e-02, -1.5216e-02, 1.4947e-03, + -3.6020e-02, -2.0702e-02, -4.0410e-02, -1.3157e-02, -1.5085e-02, + 1.2911e-02, -2.7552e-02, -2.9781e-02, -4.7424e-03, 2.0521e-02, + -4.0043e-02, -4.8763e-02, -1.3175e-02, 2.6802e-02, 2.8869e-02, + 6.5014e-03, -2.3213e-02, 1.4438e-02, -7.6318e-03, -1.9928e-03, + 1.8509e-03, 2.9728e-03, 1.5225e-02, -2.9405e-03, -7.2875e-03, + 2.9562e-05, -1.8661e-02, 9.1341e-03, -2.4919e-02, 2.9786e-02, + 9.5186e-03, 1.5435e-02, -1.1080e-02, 1.1192e-02, -2.7315e-03, + 6.9769e-05, -1.5392e-02, 4.9892e-03, 7.9857e-03, 2.0063e-02, + -2.0283e-02, -1.2596e-02, -4.1985e-04, -6.9686e-03, -5.4704e-02, + -1.9142e-02, 9.9706e-03, 2.3217e-02, -5.0579e-03, -4.9132e-02, + 2.0023e-02, -2.6238e-02, 1.0709e-02, 2.1528e-02, -1.6390e-03, + -6.7829e-03, 1.3211e-02, -9.6793e-03, 1.3130e-02, -1.2878e-02, + 1.7365e-02, 1.2509e-02, 1.2986e-03, -3.9292e-02, 9.5784e-03, + -8.0514e-03, -3.5619e-02, -3.2298e-02, 6.5933e-04, 9.9298e-03, + 3.7268e-02, -3.4047e-02, -7.8385e-03, 2.3999e-02, 1.0386e-02, + 1.7853e-02, -1.0122e-04, 5.2483e-04, -7.3150e-03, 1.0818e-02, + 1.6245e-02, -3.5619e-02, -9.9190e-03, 4.0132e-03, 9.7788e-03, + 2.7039e-02, -4.7858e-02, -2.0010e-02, -2.3702e-02, 7.8376e-04, + -2.5326e-02, 1.1698e-02, -1.3041e-02, 3.8634e-03, 9.3083e-03, + 4.8204e-03, 3.9503e-02, -4.1356e-03]], requires_grad=True) +model.Qformer.bert.encoder.layer[10].experts.gate.weight + +layer 11 +0: +model.Qformer.bert.encoder.layer[11].output.dense.weight.grad +model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad + +nan: +model.Qformer.bert.encoder.layer[11].attention.output.dense.weight.grad +model.Qformer.bert.encoder.layer[11].attention.self.query.weight.grad +model.Qformer.bert.encoder.layer[11].experts.intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[11].experts.output_query.dense.weight.grad + +None: +model.Qformer.bert.encoder.layer[11].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[11].output_query.dense.weight.grad + +layer 8 +0: +model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[2].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[0].output_query.dense.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[2].output_query.dense.weight.grad + +nan: +model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[1].output_query.dense.weight.grad +(Qformer)model.Qformer.bert.encoder.layer[8].intermediate_query.dense.weight.grad + +None: +model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad == None +model.Qformer.bert.encoder.layer[8].experts.gate.weight.requires_grad == True + + +model.Qformer.bert.encoder.layer[6].experts.gate.weight +Qformer.bert.encoder.layer.6.experts.gate.weight + +tensor([[-0.0089, -0.0123, -0.0168, ..., -0.0072, 0.0295, -0.0167], + [ 0.0305, 0.0277, -0.0215, ..., 0.0149, 0.0016, -0.0415], + [ 0.0199, 0.0151, 0.0237, ..., 0.0007, 0.0023, 0.0167]], + requires_grad=True) + +tensor([[-0.0089, -0.0123, -0.0168, ..., -0.0072, 0.0295, -0.0167], + [ 0.0305, 0.0277, -0.0215, ..., 0.0149, 0.0016, -0.0415], + [ 0.0199, 0.0151, 0.0237, ..., 0.0007, 0.0023, 0.0167]], + requires_grad=True) + + +tensor([[ 4.5972e-02, -1.5231e-02, -6.9533e-03, 3.2431e-02, -7.9703e-03, + 1.5567e-02, 2.9619e-03, -2.2609e-04, 1.8580e-02, -2.8783e-02, + 1.3093e-02, -1.0594e-02, 1.1918e-02, 4.4701e-02, 2.0108e-02, + -1.1011e-03, -8.2449e-03, 8.8876e-03, 4.6096e-03, 2.3274e-02, + -9.2557e-03, 2.5704e-03, 1.8919e-02, -5.3251e-03, -3.2665e-03, + -3.2663e-02, -5.6756e-02, -2.3400e-02, 1.3674e-02, -6.6185e-03, + 1.4429e-03, 1.2354e-02, 2.5934e-03, 2.1895e-02, -1.9793e-02, + 1.5497e-03, 4.3056e-03, -4.0023e-02, 9.8740e-03, 3.8631e-03, + -1.2918e-02, -3.6782e-02, -9.8365e-03, 3.2182e-02, 2.3729e-02, + 2.3509e-03, 1.8473e-02, 1.5583e-02, -1.1029e-02, -1.0738e-02, + -3.0278e-02, -9.8731e-03, -1.0500e-02, 7.9832e-05, -1.0345e-02, + 8.2803e-03, -5.9923e-03, -1.2669e-02, 1.2065e-03, 7.5720e-03, + -1.9286e-02, 4.0070e-02, 3.6221e-03, -1.7486e-02, 2.1725e-02, + -3.3231e-02, 7.3948e-03, -1.0924e-02, 3.1448e-02, 1.2101e-02, + 6.1737e-03, -2.0851e-02, -3.7964e-02, 8.0938e-03, -8.8967e-03, + 2.5925e-02, -7.8063e-04, 8.6102e-03, 2.7370e-02, 1.2323e-02, + 4.0606e-03, 3.9316e-02, -1.0837e-02, -2.6835e-03, 3.1941e-03, + -1.2017e-02, -2.3022e-02, 8.3533e-03, -2.2668e-02, 1.4438e-02, + -2.3664e-02, 4.5595e-02, -1.0962e-02, 1.7547e-02, -1.6739e-03, + 1.2048e-02, 2.0544e-02, 2.8837e-02, -1.6736e-02, 2.1207e-02, + 8.7612e-03, 2.8757e-02, -3.8561e-03, 8.4050e-03, -1.1503e-02, + -5.8332e-03, 1.5734e-02, -1.0773e-02, 7.5827e-03, 6.5794e-03, + 2.4291e-02, 2.6811e-02, 1.1681e-02, -3.3246e-02, 4.5776e-03, + -9.0628e-04, -2.9400e-02, 4.2933e-03, 1.5885e-03, 5.5757e-02, + 7.5518e-03, 1.0099e-02, 5.3507e-03, -3.0182e-02, 2.0830e-02, + 1.0102e-02, -9.3074e-03, 3.1161e-02, -1.7800e-02, -4.4445e-03, + -3.1503e-02, 2.3028e-02, 8.3472e-03, 7.4444e-03, 1.8838e-02, + -1.1977e-02, -2.6713e-02, 1.1364e-02, 8.3522e-04, 3.3736e-03, + 6.9425e-03, -2.0632e-02, 1.8155e-02, -2.1711e-02, -3.4703e-02, + -3.6268e-03, -4.8810e-03, -2.8142e-02, -1.5781e-02, -3.3166e-02, + -2.9910e-02, -9.7459e-03, -6.7474e-03, 1.7988e-02, 9.0176e-03, + 1.9452e-02, 4.2009e-02, 1.7217e-02, 1.4959e-02, -1.6552e-02, + -3.8206e-03, -2.4889e-02, 7.7993e-03, -1.9285e-02, -1.9770e-02, + 2.6936e-02, -5.0484e-03, -2.5117e-02, -2.3122e-02, 1.3754e-02, + 1.6025e-02, -9.1569e-03, -2.0068e-02, -1.6013e-02, -2.1775e-02, + -2.4154e-02, 6.2840e-03, -1.3684e-02, 2.5378e-02, -1.3166e-02, + -1.2201e-02, 1.0011e-02, -8.2324e-03, -5.6623e-03, -1.0383e-02, + -1.6251e-02, 1.0723e-02, -3.0207e-03, -6.9374e-03, -2.3161e-03, + -2.0850e-03, -3.4216e-02, 3.3997e-02, 3.7444e-02, -3.4273e-02, + 1.5051e-02, -9.5605e-03, -2.6979e-03, 1.8848e-02, 2.3090e-02, + 1.9669e-02, -3.9656e-02, 1.0453e-02, 5.2222e-03, -7.2493e-03, + 1.4122e-02, 5.6583e-04, -1.3991e-02, 4.0975e-02, 1.3947e-02, + 4.6919e-03, 7.9121e-03, 2.6936e-02, 1.2338e-02, 1.9048e-02, + 7.7740e-03, -6.4494e-03, -5.2965e-02, 8.1929e-03, -1.3503e-02, + 3.7466e-03, -3.3504e-02, -8.1192e-03, 1.0463e-02, -2.1568e-02, + 1.0076e-02, -1.3420e-02, -6.3353e-04, 7.4253e-03, 2.2281e-02, + 5.2829e-03, 1.4102e-02, 1.4427e-02, 1.6331e-02, -2.3305e-04, + -4.4875e-02, 6.5300e-03, 2.4963e-02, 2.2141e-03, 3.9830e-02, + 1.1405e-02, 8.6810e-03, -2.0404e-03, -1.8579e-03, 1.4765e-02, + 5.4752e-03, -1.3364e-02, -1.3082e-03, 1.5873e-03, 1.9309e-02, + 3.4367e-02, 1.8459e-02, -1.1323e-02, -1.8764e-02, -1.5370e-02, + 3.6180e-03, 2.8253e-02, -1.6867e-03, 3.5884e-03, -2.1952e-02, + -1.5026e-02, -2.1070e-02, -1.2149e-02, 1.1162e-02, -3.0343e-02, + -4.1372e-02, 1.0880e-02, 2.2365e-02, 1.2896e-02, 2.9694e-02, + -8.4248e-03, -7.8876e-03, -6.7049e-03, 2.3700e-02, 4.7528e-03, + -7.8350e-03, -5.9220e-03, 3.8396e-02, -4.1598e-02, -2.3161e-03, + 1.3419e-02, 7.1029e-03, 1.4195e-02, -1.1124e-02, 1.5812e-02, + -1.9789e-02, -2.3883e-02, -8.2788e-04, 1.4670e-02, -2.1482e-02, + -1.1182e-02, -1.6532e-02, -8.0637e-03, -3.7822e-02, 3.9402e-02, + -1.4097e-03, -7.6648e-03, -3.7156e-02, 2.5791e-02, 6.1038e-03, + -6.3429e-03, 3.2865e-03, 3.6277e-02, 9.4312e-03, -2.1003e-02, + -3.6885e-03, 1.7147e-02, -1.3079e-02, -4.9414e-02, -3.2066e-02, + 1.4835e-02, -2.9742e-02, 1.8358e-02, -2.1733e-02, 3.0256e-03, + 1.7825e-02, 1.1079e-02, 1.1619e-02, -2.3680e-02, -7.8721e-03, + 2.4456e-03, 4.3608e-02, -4.5674e-03, -3.6818e-02, 3.3952e-02, + 3.3108e-02, -3.1665e-03, -2.3468e-03, 1.5091e-02, 7.0856e-03, + 1.1723e-02, -2.0713e-02, -6.9180e-03, 3.7929e-02, 3.7671e-03, + 4.6663e-02, 9.5301e-03, 1.2638e-02, -6.5623e-03, -3.1771e-03, + -1.7568e-02, 1.8711e-03, -1.2310e-02, 2.1518e-02, 4.3408e-03, + -6.7171e-03, -5.0451e-03, 2.6870e-02, -1.9832e-02, 7.0422e-03, + 1.1274e-02, -2.4637e-02, -4.8450e-03, 2.1892e-02, -2.6059e-02, + 1.5605e-02, -1.1617e-02, -1.9273e-02, -8.6735e-04, -9.8002e-04, + -1.8553e-02, 2.1239e-02, 2.1078e-02, -1.2091e-02, 9.7025e-03, + 1.3426e-02, -1.1710e-02, -2.2242e-03, 6.4133e-03, -1.4820e-02, + 1.4682e-02, 3.0679e-02, 1.1526e-02, 1.0072e-02, -1.1572e-02, + 2.6128e-02, 4.0879e-03, -1.7936e-02, 1.3715e-02, -2.3667e-02, + 2.0419e-03, -1.6887e-02, 1.2595e-02, -2.1988e-02, -2.3777e-02, + -1.0399e-02, 2.4868e-03, -1.2265e-02, -1.8543e-02, 3.4672e-02, + 2.1114e-02, 2.0523e-02, 7.6818e-03, 2.9282e-02, -5.9593e-03, + -2.8496e-02, 2.8482e-03, 3.6874e-04, 4.7455e-02, -2.9770e-02, + -2.0684e-02, -2.0749e-02, -5.7681e-02, -2.6175e-03, -2.4488e-02, + -5.2550e-03, -7.1191e-03, 3.8192e-02, 4.3438e-02, 5.4181e-03, + 2.8392e-02, 1.9493e-02, -3.5262e-02, 1.4839e-02, 4.6481e-03, + 1.7219e-02, 2.0160e-02, 4.9998e-03, 2.1316e-02, -8.7929e-04, + -2.1542e-02, 3.9816e-03, 1.5879e-02, 9.9231e-03, 1.3962e-02, + -5.3418e-03, 3.9857e-02, 2.0997e-02, -2.1291e-05, 1.8133e-02, + -1.2472e-02, 4.9437e-03, -1.5099e-02, 4.8860e-02, 6.1980e-03, + 2.0197e-02, 1.3141e-04, -3.1087e-03, -2.2718e-03, 2.3804e-02, + 6.0726e-03, -2.0485e-02, -2.0514e-02, -2.7679e-02, -3.0412e-02, + -1.7661e-02, -1.7462e-02, 7.5216e-03, 2.2238e-02, 1.1413e-03, + 2.6647e-02, -2.3855e-02, 2.2652e-03, -4.3256e-03, -9.3274e-03, + 2.5149e-02, 6.8432e-03, 4.2664e-03, 3.8221e-02, 7.7480e-03, + 8.7203e-03, -1.2851e-03, -1.1325e-02, -1.0650e-02, -2.8079e-02, + -1.5375e-02, 2.2630e-02, -4.3439e-03, 1.3493e-02, -1.8223e-02, + 9.9750e-03, -2.4560e-02, 1.0904e-03, -3.1198e-02, 4.7331e-03, + 1.6713e-02, -1.7653e-02, -3.8674e-02, 1.5458e-02, 4.0555e-02, + 6.9451e-03, 1.1988e-03, 8.0718e-04, 3.9985e-03, -2.2781e-02, + 8.1173e-04, 2.0106e-02, -1.2800e-02, -1.2961e-02, -2.1273e-02, + -4.4104e-05, -3.6080e-02, -1.9392e-02, 3.2862e-02, -5.6041e-03, + 2.3288e-02, -4.6795e-02, 1.7282e-02, 5.7052e-03, 2.2405e-02, + 1.9871e-03, -1.4333e-02, 5.3773e-03, 4.3568e-02, 9.8980e-03, + -1.9403e-03, 1.8981e-02, -2.5712e-02, -3.3621e-03, 2.9886e-02, + 1.3326e-03, 1.1318e-02, -3.3238e-03, -1.5494e-02, -3.0565e-02, + 1.7137e-02, -2.7874e-02, -1.1257e-02, 3.2250e-02, -2.5293e-02, + -3.0693e-03, -2.7787e-02, 1.4931e-02, 2.4202e-03, -4.0572e-03, + 5.0273e-03, 9.7496e-03, 2.2601e-02, 3.2389e-02, -1.1910e-02, + 9.1037e-03, 5.6000e-02, -1.9640e-02, 1.5469e-02, -3.3027e-02, + 1.4839e-02, 2.5071e-02, -1.2687e-02, -1.3466e-02, 1.9031e-02, + -7.3403e-03, -1.5207e-02, -1.4486e-02, 2.0678e-02, -4.1996e-02, + 1.0585e-02, 3.6276e-02, 6.1149e-03, 1.6405e-02, 1.5643e-02, + 1.5060e-02, -5.1235e-03, -2.2824e-02, -1.3752e-02, -1.5742e-02, + 2.4032e-02, -2.1782e-03, -1.3158e-02, 3.9482e-03, 3.2267e-02, + -2.2632e-03, 1.2055e-02, 4.4731e-02, 1.8271e-02, -1.1486e-02, + 1.7836e-02, 1.7886e-03, -2.4020e-02, 2.6064e-02, -2.2122e-04, + 1.8643e-02, -2.9808e-02, -6.1845e-03, -4.4464e-03, 8.8374e-04, + 1.5268e-02, 1.7205e-03, 5.7832e-02, -1.7486e-02, 1.1897e-02, + 5.8081e-02, 1.7667e-02, -7.7282e-03, 1.4036e-02, -1.4936e-03, + 6.0635e-04, 1.6124e-03, -1.6916e-02, -1.1239e-02, 1.8497e-02, + 1.2334e-03, -2.0706e-02, 3.2959e-03, 2.9186e-02, 3.7506e-02, + 1.2037e-02, -1.4903e-02, 8.5606e-03, 3.4136e-03, 1.1850e-02, + -7.4782e-03, 5.3924e-03, -2.4772e-02, 2.6840e-02, -2.7656e-02, + -3.2637e-02, -1.2779e-02, 1.0730e-02, 1.4096e-03, 3.1572e-02, + 7.8976e-04, 3.1674e-02, 8.5333e-03, -1.2679e-02, 1.1176e-02, + -2.0446e-02, 1.8628e-02, -4.0158e-02, -2.3358e-02, -2.2504e-02, + -2.8759e-02, -1.4597e-02, -8.5879e-03, 1.0550e-02, -3.5556e-02, + -1.9046e-02, -1.9159e-02, -2.2703e-02, -7.2056e-03, 4.2380e-02, + -9.7475e-03, -2.4754e-02, 1.3992e-03, -1.0411e-02, 1.5708e-02, + -8.2899e-03, -6.4856e-03, 1.6359e-02, -5.1969e-04, -5.0958e-03, + -4.1232e-02, 2.7349e-03, -1.7723e-02, 1.3388e-02, 2.2776e-03, + -2.0786e-02, -1.8082e-02, -2.4866e-03, 2.2141e-02, 6.9998e-03, + -5.5714e-03, 2.1088e-02, 5.8745e-03, 1.2788e-02, 4.2977e-03, + 5.8631e-03, -1.8121e-02, 1.9242e-03, 2.3622e-02, 1.4917e-02, + -5.3198e-03, -3.9222e-02, -2.4697e-02, 9.1218e-03, -1.0711e-02, + 1.0268e-02, 1.5148e-02, -4.4508e-02, 4.6783e-03, 2.8093e-03, + 9.1253e-03, -7.3281e-03, 1.0114e-03, -9.2369e-04, 1.4841e-02, + 2.2642e-02, 2.3675e-02, 1.3902e-02, -5.6343e-03, 1.4851e-02, + -9.5169e-03, -3.1721e-02, 1.6696e-02, 2.9285e-02, -1.4090e-02, + 2.1128e-02, 4.8656e-02, 3.8431e-02, -3.5470e-02, -4.8230e-03, + -1.6513e-02, 4.1917e-02, 8.9090e-03, -1.4022e-04, 4.0182e-03, + 7.1723e-03, 3.1419e-02, -4.8508e-03, 1.7768e-03, -7.3688e-03, + 3.4637e-03, -2.3227e-02, 3.9606e-05, -2.4731e-02, -1.3640e-02, + -5.1718e-03, 2.6662e-02, -1.2871e-02, -1.6009e-02, -5.3720e-03, + 2.7397e-04, -3.4016e-03, 2.6429e-02, 3.8069e-02, 1.0929e-02, + -1.0620e-02, 1.2165e-02, -2.6018e-02, 1.6021e-02, 4.0644e-02, + -8.0898e-03, -3.5198e-02, -1.9602e-02, 2.4986e-02, -5.8400e-03, + 3.2070e-02, -1.8265e-02, -5.4518e-03, 2.8195e-02, 5.5598e-02, + -3.9959e-02, 1.5521e-02, -2.8416e-02, 3.1130e-02, -1.0038e-02, + 2.1522e-02, -1.1654e-02, 2.2382e-02, -5.4467e-03, -2.2840e-02, + 2.7036e-03, -4.4607e-02, -4.1953e-02, 2.0079e-02, -5.0121e-03, + -1.7495e-02, 4.4070e-03, 3.7400e-04, 1.0899e-02, 1.7008e-02, + -1.6307e-02, -1.9986e-02, -2.3865e-02, -2.5618e-02, -2.9981e-02, + -2.7230e-03, 2.7079e-02, 5.2920e-03, 2.1069e-02, -2.5896e-02, + -1.6256e-02, -1.4182e-03, 1.1829e-02, 1.0360e-02, 2.8883e-02, + -6.8762e-03, 1.4032e-02, -4.3389e-03]], requires_grad=True) \ No newline at end of file diff --git a/test/datasets/test_dataset.py b/test/datasets/test_dataset.py new file mode 100644 index 0000000..c4f64a8 --- /dev/null +++ b/test/datasets/test_dataset.py @@ -0,0 +1,58 @@ +import datasets +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer +from trl import SFTTrainer, DataCollatorForCompletionOnlyLM +import random +from tqdm import tqdm + +# path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/models/cmrc2018_trial.json" +# dataset = load_dataset("json", data_files=[path], field="data", split="train") +# tokenizer = AutoTokenizer.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") +# def preprocess_function(example): +# import pdb; pdb.set_trace() +# model_inputs = tokenizer(example["content"], max_length=512, truncation=True) +# labels = tokenizer(example["title"], max_length=32, truncation=True) +# # label就是title编码的结果 +# model_inputs["labels"] = labels["input_ids"] +# return model_inputs +# processed_datasets = dataset.map(preprocess_function) + +dataset = load_dataset("/mnt/pfs-guan-ssai/nlu/wanghanzi/data/alpaca_20k") +train_dataset = dataset['train'] + + +for i in tqdm(range(1, len(train_dataset))): + import pdb; pdb.set_trace() + + idx = random.randint(0,i) + memory = train_dataset[idx] + memory_text = f"Instruction: {memory['instruction']}\n Answer: {memory['output']} \n" + train_dataset[i]['text'] = f"{memory_text} Instruction:{train_dataset[i]['instruction']}" + + +import pdb; pdb.set_trace() + + +model_path = "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/opt_350m" +model = AutoModelForCausalLM.from_pretrained(model_path) +tokenizer = AutoTokenizer.from_pretrained(model_path) + + +def formatting_prompts_func(example): + import pdb; pdb.set_trace() + output_texts = [] + for i in range(len(example['instruction'])): + text = f"### Question: {example['instruction'][i]}\n ### Answer: {example['output'][i]}" + output_texts.append(text) + return output_texts + +response_template = " ### Answer:" +collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) + +trainer = SFTTrainer( + model, + train_dataset=train_dataset, + formatting_func=formatting_prompts_func, + data_collator=collator, +) +trainer.train() \ No newline at end of file diff --git a/test/datasets/test_eval_gqa_dataset.py b/test/datasets/test_eval_gqa_dataset.py new file mode 100644 index 0000000..540d782 --- /dev/null +++ b/test/datasets/test_eval_gqa_dataset.py @@ -0,0 +1,222 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +# import wandb + +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank +from minigpt4.common.logger import setup_logger +from minigpt4.common.registry import registry +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + # parser.add_argument("-f", help="jupyter notebook") + parser.add_argument( + "--cfg-path", + default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/qformer_moe_vicuna/eval/mix_vqa_coco_vicuna_eval.yaml", + help="path to configuration file.") + parser.add_argument( + "--gpu-id", + type=int, + default=4, + help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +# Test About Building Task +# build config +device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu") +cfg = Config(parse_args()) +setup_seeds(cfg) +print(cfg._convert_node_to_json(cfg.config)) + +setup_logger() +cfg.pretty_print() + +task = tasks.setup_task(cfg) +datasets = task.build_datasets(cfg) + +def get_runner_class(cfg): + """ + Get runner class from config. Default to epoch-based runner. + """ + runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) + + return runner_cls + +job_id = now() +# model = task.build_model(cfg) +model = None +# task.build_tensorboard(cfg) + +runner = get_runner_class(cfg)( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets +) + + +""" + Dataset & DataLoader Setup +""" +from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset +import webdataset as wds +import logging + +batch_sizes = {dataset_name: getattr(runner.config.datasets_cfg, dataset_name).batch_size + for dataset_name in runner.datasets.keys()} +print(batch_sizes) + + +datasets, batch_sizes = reorg_datasets_by_split(runner.datasets, batch_sizes) +runner.datasets = datasets +# self.datasets = concat_datasets(datasets) +print(runner.datasets.keys()) # dict_keys(['train', 'val', 'test']) + +# print dataset statistics after concatenation/chaining +for split_name in runner.datasets: + if isinstance(runner.datasets[split_name], tuple) or isinstance( + runner.datasets[split_name], list + ): + # mixed wds.DataPipeline and torch.utils.data.Dataset + num_records = sum( + [ + len(d) + if not type(d) in [wds.DataPipeline, ChainDataset] + else 0 + for d in runner.datasets[split_name] + ] + ) + + else: + if hasattr(runner.datasets[split_name], "__len__"): + # a single map-style dataset + num_records = len(runner.datasets[split_name]) + else: + # a single wds.DataPipeline + num_records = -1 + logging.info( + "Only a single wds.DataPipeline dataset, no __len__ attribute." + ) + + if num_records >= 0: + logging.info( + "Loaded {} records for {} split from the dataset.".format( + num_records, split_name + ) + ) + +split_names = sorted(runner.datasets.keys()) + +datasets = [runner.datasets[split] for split in split_names] +batch_sizes = [batch_sizes[split] for split in split_names] +is_trains = [split in runner.train_splits for split in split_names] + +print("split_names: ",split_names) +print("is_trains: ",is_trains) +print("batch sizes: ", batch_sizes) + + +collate_fns = [] +for dataset in datasets: + if isinstance(dataset, tuple) or isinstance(dataset, list): + collate_fns.append([getattr(d, "collater", None) for d in dataset]) + else: + collate_fns.append(getattr(dataset, "collater", None)) + +dataloaders = runner.create_loaders( + datasets=datasets, + num_workers=runner.config.run_cfg.num_workers, + batch_sizes=batch_sizes, + is_trains=is_trains, + collate_fns=collate_fns, +) +_dataloaders = {k: v for k, v in zip(split_names, dataloaders)} + + +loader = _dataloaders['train'] +loader = _dataloaders['val'] +loader_idx = random.choices(range(len(loader.loaders)), loader.ratios, k=1)[0] +print(loader_idx) +next(loader.loaders[loader_idx])['question_id'], next(loader.loaders[loader_idx])['source'] + + + + +import json +file = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/gqa_943k_raw_QformerMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1115/20231115234/result/val_vqa_result_gqa.json" +data = json.load(open(file, "r")) +cnt = 0 +for i in range(len(data)): + d = data[i] + # if d['gt_ans'] in d['pred_ans'].lower(): + if d['gt_ans'] in d['answer'].lower(): + cnt += 1 + else: + print(d) + if i == 100: + break + +print(cnt/len(data)) + + +# from minigpt4.common.vqa_tools.vqa import VQA +# from minigpt4.common.vqa_tools.vqa_eval import VQAEval + +# split = 'val' +# source = 'vqav2' +# print(task.anno_files[split][source],task.ques_files[split][source]) +# vqa = VQA(task.anno_files[split][source], task.ques_files[split][source]) + +# result_file = '/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1101/20231031224/result/val_vqa_result_2.json' +# print('result_file: ',result_file) +# vqa_result = vqa.loadRes(resFile=result_file, quesFile=task.ques_files[split][source]) + +# vqa_scorer = VQAEval(vqa, vqa_result, n=2) +# vqa_scorer.evaluate() +# overall_acc = vqa_scorer.accuracy["overall"] +# perAnswerType = vqa_scorer.accuracy["perAnswerType"] + +# print(overall_acc) +# print(perAnswerType) diff --git a/test/datasets/test_llava_mix_dataset.py b/test/datasets/test_llava_mix_dataset.py new file mode 100644 index 0000000..6d9c5d1 --- /dev/null +++ b/test/datasets/test_llava_mix_dataset.py @@ -0,0 +1,99 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +# import wandb + +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank +from minigpt4.common.logger import setup_logger +from minigpt4.common.registry import registry +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + # parser.add_argument("-f", help="jupyter notebook") + parser.add_argument( + "--cfg-path", + default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance_finetuned.yaml", + help="path to configuration file.") + parser.add_argument( + "--gpu-id", + type=int, + default=5, + help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + +def get_runner_class(cfg): + """ + Get runner class from config. Default to epoch-based runner. + """ + runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) + + return runner_cls + +# Test About Building Task +# build config +device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") +cfg = Config(parse_args()) +setup_seeds(cfg) +print(cfg._convert_node_to_json(cfg.config)) + +setup_logger() +cfg.pretty_print() + +task = tasks.setup_task(cfg) +datasets = task.build_datasets(cfg) + +job_id = now() +# model = task.build_model(cfg) +# model = None +task.build_tensorboard(cfg) + +runner = get_runner_class(cfg)( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets +) + +data_loader = runner.train_loader +data_loader = runner.dataloaders['val'] diff --git a/test/datasets/test_pretrain_task_dataset.py b/test/datasets/test_pretrain_task_dataset.py new file mode 100644 index 0000000..b4b60cc --- /dev/null +++ b/test/datasets/test_pretrain_task_dataset.py @@ -0,0 +1,202 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +# import wandb + +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank +from minigpt4.common.logger import setup_logger +from minigpt4.common.registry import registry +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + # parser.add_argument("-f", help="jupyter notebook") + parser.add_argument( + "--cfg-path", + # default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/minigpt/train/minigptv2_finetune_vqa.yaml", + default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/prompt_moe/train/mix_coco_gqa_prompt_moe_post_blip2.yaml", + help="path to configuration file.") + parser.add_argument( + "--gpu-id", + type=int, + default=5, + help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + + +# Test About Building Task +# build config +device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") +cfg = Config(parse_args()) +setup_seeds(cfg) +print(cfg._convert_node_to_json(cfg.config)) + +setup_logger() +cfg.pretty_print() + +task = tasks.setup_task(cfg) +datasets = task.build_datasets(cfg) + +def get_runner_class(cfg): + """ + Get runner class from config. Default to epoch-based runner. + """ + runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) + + return runner_cls + +job_id = now() +model = task.build_model(cfg) +# model = None +# task.build_tensorboard(cfg) + +runner = get_runner_class(cfg)( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets +) + + +# from minigpt4.common.vqa_tools.vqa import VQA +# from minigpt4.common.vqa_tools.vqa_eval import VQAEval + +# split = 'val' +# source = 'vqav2' +# print(task.anno_files[split][source],task.ques_files[split][source]) +# vqa = VQA(task.anno_files[split][source], task.ques_files[split][source]) + +# result_file = '/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/flant5xxl/prompt_moe/mix_coco_gqa_1610k_raw_postMoE_train_qf_train_qt_linear_gate_20ex_top2_3loss_textinqf_epo3_1101/20231031224/result/val_vqa_result_2.json' +# print('result_file: ',result_file) +# vqa_result = vqa.loadRes(resFile=result_file, quesFile=task.ques_files[split][source]) + +# vqa_scorer = VQAEval(vqa, vqa_result, n=2) +# vqa_scorer.evaluate() +# overall_acc = vqa_scorer.accuracy["overall"] +# perAnswerType = vqa_scorer.accuracy["perAnswerType"] + +# print(overall_acc) +# print(perAnswerType) + + + +from minigpt4.datasets.data_utils import concat_datasets, reorg_datasets_by_split, ChainDataset +import webdataset as wds +import logging + +batch_sizes = {dataset_name: getattr(runner.config.datasets_cfg, dataset_name).batch_size + for dataset_name in runner.datasets.keys()} +print(batch_sizes) + + +datasets, batch_sizes = reorg_datasets_by_split(runner.datasets, batch_sizes) +runner.datasets = datasets +# self.datasets = concat_datasets(datasets) +print(runner.datasets.keys()) # dict_keys(['train', 'val', 'test']) + +# print dataset statistics after concatenation/chaining +for split_name in runner.datasets: + if isinstance(runner.datasets[split_name], tuple) or isinstance( + runner.datasets[split_name], list + ): + # mixed wds.DataPipeline and torch.utils.data.Dataset + num_records = sum( + [ + len(d) + if not type(d) in [wds.DataPipeline, ChainDataset] + else 0 + for d in runner.datasets[split_name] + ] + ) + + else: + if hasattr(runner.datasets[split_name], "__len__"): + # a single map-style dataset + num_records = len(runner.datasets[split_name]) + else: + # a single wds.DataPipeline + num_records = -1 + logging.info( + "Only a single wds.DataPipeline dataset, no __len__ attribute." + ) + + if num_records >= 0: + logging.info( + "Loaded {} records for {} split from the dataset.".format( + num_records, split_name + ) + ) + +split_names = sorted(runner.datasets.keys()) + +datasets = [runner.datasets[split] for split in split_names] +batch_sizes = [batch_sizes[split] for split in split_names] +is_trains = [split in runner.train_splits for split in split_names] + +print("split_names: ",split_names) +print("is_trains: ",is_trains) +print("batch sizes: ", batch_sizes) + + +collate_fns = [] +for dataset in datasets: + if isinstance(dataset, tuple) or isinstance(dataset, list): + collate_fns.append([getattr(d, "collater", None) for d in dataset]) + else: + collate_fns.append(getattr(dataset, "collater", None)) + +dataloaders = runner.create_loaders( + datasets=datasets, + num_workers=runner.config.run_cfg.num_workers, + batch_sizes=batch_sizes, + is_trains=is_trains, + collate_fns=collate_fns, +) +_dataloaders = {k: v for k, v in zip(split_names, dataloaders)} + + +loader = _dataloaders['train'] +loader_idx = random.choices(range(len(loader.loaders)), loader.ratios, k=1)[0] +print(loader_idx) +next(loader.loaders[loader_idx])['question_id'], next(loader.loaders[loader_idx])['source'] + diff --git a/test/models/test_moe_model.py b/test/models/test_moe_model.py new file mode 100644 index 0000000..b8154a7 --- /dev/null +++ b/test/models/test_moe_model.py @@ -0,0 +1,181 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") +from minigpt4.models.QformerMoE import ( + BertConfig, + BertMoELMHeadModel +) +vision_width = 1408 +cross_attention_freq = 2 +num_query_token = 32 + +# init_QformerMoE +moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") +moe_encoder_config.encoder_width = vision_width +# insert cross-attention layer every other block +moe_encoder_config.add_cross_attention = True +moe_encoder_config.cross_attention_freq = cross_attention_freq +moe_encoder_config.query_length = num_query_token +moe_encoder_config.moebert_expert_num = 4 +moe_encoder_config.moebert_route_method = "gate-sentence" +moe_encoder_config.moe_topk = 2 +moe_encoder_config.moebert_load_balance = 0.1 +moe_encoder_config.moebert_share_importance = 512 # TODO: meaning? +MoEQformer = BertMoELMHeadModel.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config +) + + +""" + Compare Qformer & QformerMoE +""" +# blip2_qformer +# calculate parameters +from minigpt4.models import load_model +model = load_model("blip2", "pretrain") +model.QformerMoE, model.query_tokens_moe = model.init_QformerMoE( + num_query_token, model.visual_encoder.num_features, cross_attention_freq +) +model.Qformer, model.query_tokens = model.init_Qformer( + num_query_token, model.visual_encoder.num_features, cross_attention_freq +) +state_dict = model.Qformer.state_dict() +for name, param in model.Qformer.named_parameters(): + if "_query" in name: + key_orig = name.replace("_query", "") + param.data.copy_(state_dict[key_orig]) + if "10" in name: + print(name) + + +""" + blip2_t5_qformer_moe + Calculate Num Parameters +""" +import torch +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") +from minigpt4.models import model_zoo +from minigpt4.models import load_model +print(model_zoo) +device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") +model = load_model("blip2_t5_qformer_moe", "flant5xxl", device=device) + +num_parameters=0 +for n, p in model.Qformer.named_parameters(): + if not p.requires_grad: + continue # frozen weights + if "11.experts.experts" in n: + print(n) + num_parameters += p.data.nelement() +print(num_parameters) # 23,619,840 +# total trainable parameter: 415,631,104 + +num_parameters=0 +for n, p in model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + num_parameters += p.data.nelement() +print(num_parameters) # 23,619,840 +# total trainable parameter: 415,631,104 + + +num_parameters=0 +for n, p in model.named_parameters(): + if not p.requires_grad: + continue # frozen weights + if 'Qformer.bert.encoder.layer.6.crossattention' in n: + num_parameters += p.data.nelement() + # if 'Qformer.bert.encoder.layer.11.output' in n: + # num_parameters += p.data.nelement() +print(num_parameters) + + +""" + forward +""" +import torch +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") +from minigpt4.models import load_model +device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") +model = load_model("blip2", "pretrain", device=device) + +samples = { + 'q_input':["What is around the open window?", # n23181 + "Is the ground blue or brown?", # n168412 + "What color are the pants?", # n446242 + "What is the airplane flying above?"], # n414992 + 'llm_input':["What is around the open window?", # n23181 + "Is the ground blue or brown?", # n168412 + "What color are the pants?", # n446242 + "What is the airplane flying above?"], # n414992 + 'text_output':["drapes", + "brown", + "red", + "ocean" + ], + 'image': torch.randn(4, 3, 224, 224).half().to(device) + # 'image': torch.randn(4, 3, 336, 336).half().to(device) +} + +Qformer, query_tokens = model.init_QformerMoE( + num_query_token=32, + vision_width=1408, + moebert_expert_num=5, + moebert_route_method="gate-sentence", + moebert_load_balance=0.1, + moe_topk=2, + cross_attention_freq=2 + ) +Qformer = Qformer.to(device) + +def maybe_autocast(device, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = device != torch.device("cpu") + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + +image = samples["image"] +with maybe_autocast(device): + image_embeds = model.ln_vision(model.visual_encoder(image)) +image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) +bs = image.size(0) +query_tokens = query_tokens.expand(bs, -1, -1).to(device) + + + +# image = samples["image"] +# image_atts = torch.ones(4, 257).to(device) +# image_embeds = torch.randn(4, 257, 1408).to(device) +# bz = image_embeds.shape[0] +# query_tokens = query_tokens.expand(bz, -1, -1).to(device) + +text_Qformer = model.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=32, + return_tensors="pt", +).to(image.device) +query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) +Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1).to(device) + +query_output = Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, +) + diff --git a/test/models/test_moe_qformer_model.py b/test/models/test_moe_qformer_model.py new file mode 100644 index 0000000..6b2049c --- /dev/null +++ b/test/models/test_moe_qformer_model.py @@ -0,0 +1,225 @@ +""" + Copyright (c) 2022, salesforce.com, inc. + All rights reserved. + SPDX-License-Identifier: BSD-3-Clause + For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause +""" + +import argparse +import os +import random + +import numpy as np +import torch +import torch.backends.cudnn as cudnn +# import wandb + +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") + +import minigpt4.tasks as tasks +from minigpt4.common.config import Config +from minigpt4.common.dist_utils import get_rank +from minigpt4.common.logger import setup_logger +from minigpt4.common.registry import registry +from minigpt4.common.utils import now + +# imports modules for registration +from minigpt4.datasets.builders import * +from minigpt4.models import * +from minigpt4.processors import * +from minigpt4.runners import * +from minigpt4.tasks import * + + +def parse_args(): + parser = argparse.ArgumentParser(description="Demo") + # parser.add_argument("-f", help="jupyter notebook") + parser.add_argument( + "--cfg-path", + default="/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/minigpt4/projects/qformer_moe_post_vicuna/train/mix_qformer_moe_post_blip2_vicuna7b_data_balance_finetuned.yaml", + help="path to configuration file.") + parser.add_argument( + "--gpu-id", + type=int, + default=5, + help="specify the gpu to load the model.") + parser.add_argument( + "--options", + nargs="+", + help="override some settings in the used config, the key-value pair " + "in xxx=yyy format will be merged into config file (deprecate), " + "change to --cfg-options instead.", + ) + args = parser.parse_args() + return args + + +def setup_seeds(config): + seed = config.run_cfg.seed + get_rank() + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + cudnn.benchmark = False + cudnn.deterministic = True + +def get_runner_class(cfg): + """ + Get runner class from config. Default to epoch-based runner. + """ + runner_cls = registry.get_runner_class(cfg.run_cfg.get("runner", "runner_base")) + + return runner_cls + +# Test About Building Task +# build config +device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") +cfg = Config(parse_args()) +setup_seeds(cfg) +print(cfg._convert_node_to_json(cfg.config)) + +setup_logger() +cfg.pretty_print() + +task = tasks.setup_task(cfg) +datasets = task.build_datasets(cfg) + +job_id = now() +model = task.build_model(cfg) +# model = None +task.build_tensorboard(cfg) + +runner = get_runner_class(cfg)( + cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets +) + +for name, param in model.named_parameters(): + if param.requires_grad == True: + if name == 'Qformer.bert.encoder.layer.10.experts.experts.2.intermediate_query.dense.weight': + print(name) + print(param) + if name == 'Qformer.bert.encoder.layer.10.intermediate_query.dense.weight': + print(name) + print(param) + + +for name, param in model.named_parameters(): + if param.requires_grad == False: + if 'Qformer' in name and '10' in name: + print(name) + +for key in m1['model'].keys(): + if 'Qformer' in key: + print(key) + +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") +from minigpt4.models.QformerMoE import ( + BertConfig, + BertMoELMHeadModel +) +vision_width = 1408 +cross_attention_freq = 2 +num_query_token = 32 + +# init_QformerMoE +moe_encoder_config = BertConfig.from_pretrained("/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased") +moe_encoder_config.encoder_width = vision_width +# insert cross-attention layer every other block +moe_encoder_config.add_cross_attention = True +moe_encoder_config.cross_attention_freq = cross_attention_freq +moe_encoder_config.query_length = num_query_token +moe_encoder_config.moebert_expert_num = 4 +moe_encoder_config.moebert_route_method = "gate-sentence" +moe_encoder_config.moe_topk = 2 +moe_encoder_config.moebert_load_balance = 0.1 +moe_encoder_config.moebert_share_importance = 512 # TODO: meaning? +MoEQformer = BertMoELMHeadModel.from_pretrained( + "/mnt/pfs-guan-ssai/nlu/wanghanzi/models/bert-base-uncased", config=moe_encoder_config +) + + + +""" + forward +""" +import torch +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") +from minigpt4.models import load_model +device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") +model = load_model("blip2", "pretrain", device=device) + +samples = { + 'q_input':["What is around the open window?", # n23181 + "Is the ground blue or brown?", # n168412 + "What color are the pants?", # n446242 + "What is the airplane flying above?"], # n414992 + 'llm_input':["What is around the open window?", # n23181 + "Is the ground blue or brown?", # n168412 + "What color are the pants?", # n446242 + "What is the airplane flying above?"], # n414992 + 'text_output':["drapes", + "brown", + "red", + "ocean" + ], + 'image': torch.randn(4, 3, 224, 224).half().to(device) +} + +Qformer, query_tokens = model.init_QformerMoE( + num_query_token=32, + vision_width=1408, + moebert_expert_num=5, + moebert_route_method="gate-sentence", + moebert_load_balance=0.1, + moe_topk=2, + cross_attention_freq=2 + ) +Qformer = Qformer.to(device) + +def maybe_autocast(device, dtype=torch.float16): + # if on cpu, don't use autocast + # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 + enable_autocast = device != torch.device("cpu") + if enable_autocast: + return torch.cuda.amp.autocast(dtype=dtype) + else: + return contextlib.nullcontext() + +image = samples["image"] +with maybe_autocast(device): + image_embeds = model.ln_vision(model.visual_encoder(image)) +image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) +bs = image.size(0) +query_tokens = query_tokens.expand(bs, -1, -1).to(device) + + + +# image = samples["image"] +# image_atts = torch.ones(4, 257).to(device) +# image_embeds = torch.randn(4, 257, 1408).to(device) +# bz = image_embeds.shape[0] +# query_tokens = query_tokens.expand(bz, -1, -1).to(device) + +text_Qformer = model.tokenizer( + samples["q_input"], + padding='longest', + truncation=True, + max_length=32, + return_tensors="pt", +).to(image.device) +query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) +Qformer_atts = torch.cat([query_atts,text_Qformer.attention_mask],dim=1).to(device) + +query_output = Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, +) + diff --git a/test/models/test_vicuna_blip2.py b/test/models/test_vicuna_blip2.py new file mode 100644 index 0000000..f3d94b1 --- /dev/null +++ b/test/models/test_vicuna_blip2.py @@ -0,0 +1,114 @@ +import sys +sys.path.append("/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE") +from minigpt4.models import load_model +import torch +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +# model = load_model("blip2_vicuna_instruct", "vicuna7b_pretrain", device="cpu") +# model = load_model("blip2_vicuna_instruct", "vicuna7b_instruct", device=device) +model = load_model("blip2_vicuna_instruct", "vicuna7b_qfmoe_post", device=device) + + +use_nucleus_sampling=False +num_beams=5 +max_length=256 +min_length=1 +top_p=0.9 +repetition_penalty=1.5 +length_penalty=1 +num_captions=1 +temperature=1 + + +if "prompt" in samples.keys(): + prompt = samples["prompt"] +else: + prompt = model.prompt + +image = samples["image"] + +bs = image.size(0) + +if isinstance(prompt, str): + prompt = [prompt] * bs +else: + assert len(prompt) == bs, "The number of prompts must be equal to the batch size." + +# For TextCaps +if "ocr_tokens" in samples.keys() and "{}" in prompt[0]: + prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)] + +query_tokens = model.query_tokens.expand(bs, -1, -1) + +text_Qformer = model.tokenizer( + prompt, + padding='longest', + truncation=True, + max_length=model.max_txt_len, + return_tensors="pt", +).to(image.device) +query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device) +Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1) + +with model.maybe_autocast(): + image_embeds = model.ln_vision(model.visual_encoder(image)) +image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device) + +query_output = model.Qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, +) +inputs_llm = model.llm_proj(query_output.last_hidden_state[:,:query_tokens.size(1),:]) +atts_llm = torch.ones(inputs_llm.size()[:-1], dtype=torch.long).to(image.device) + +llm_tokens = model.llm_tokenizer( + prompt, + padding="longest", + return_tensors="pt" +).to(image.device) + +inputs_embeds = model.llm_model.get_input_embeddings()(llm_tokens.input_ids) +inputs_embeds = torch.cat([inputs_llm, inputs_embeds], dim=1) +attention_mask = torch.cat([atts_llm, llm_tokens.attention_mask], dim=1) +print(inputs_embeds.shape) + +inputs_embeds0 = inputs_embeds[:2] +attention_mask0 = attention_mask[:2] + +outputs = model.llm_model.generate( + inputs_embeds=inputs_embeds0, # torch.Size([4, 41, 4096]) + attention_mask=attention_mask0, # torch.Size([4, 41]) + do_sample=use_nucleus_sampling, + top_p=top_p, + temperature=temperature, + num_beams=num_beams, + max_length=max_length, + min_length=min_length, + # eos_token_id=self.eos_token_id, + repetition_penalty=repetition_penalty, + length_penalty=length_penalty, + num_return_sequences=num_captions, +) + +from PIL import Image +image = "/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/examples_v2/KFC-20-for-20-Nuggets.jpg" +raw_image = Image.open(image).convert('RGB') +image = model.vis_processor(raw_image).unsqueeze(0).to("cpu") +sample = {'text_input': ["What is around the open window?", "Is the ground blue or brown?"], + 'image':torch.randn(2, 3, 224, 224).to("cpu")} + +samples = { + 'text_input':["What is around the open window?", # n23181 + "Is the ground blue or brown?", # n168412 + "What color are the pants?", # n446242 + "What is the airplane flying above?"], # n414992 + 'image': torch.randn(4, 3, 224, 224).to("cpu") +} + + +for key in mo['model'].keys(): + if 'Qformer' not in key: + print(key) \ No newline at end of file diff --git a/test/test.txt b/test/test.txt new file mode 100644 index 0000000..fe2b17a --- /dev/null +++ b/test/test.txt @@ -0,0 +1,92 @@ +datasets: + multitask_conversation: + batch_size: 2 + sample_ratio: 50 + + llava_conversation: + batch_size: 2 + sample_ratio: 30 + + unnatural_instruction: + batch_size: 1 + sample_ratio: 10 + + refvg: + batch_size: 6 + sample_ratio: 40 + + llava_detail: + batch_size: 4 + sample_ratio: 20 + + llava_reason: + batch_size: 4 + sample_ratio: 80 + + + flickr_grounded_caption: + batch_size: 2 + sample_ratio: 80 + + flickr_CaptionToPhrase: + batch_size: 2 + sample_ratio: 80 + + flickr_ObjectToPhrase: + batch_size: 2 + sample_ratio: 80 + + coco_caption: + batch_size: 6 + sample_ratio: 10 + + + textcaps_caption: + batch_size: 6 + sample_ratio: 30 + + refcoco: + batch_size: 6 + sample_ratio: 25 + + + refcocop: + batch_size: 6 + sample_ratio: 25 + + refcocog: + batch_size: 6 + sample_ratio: 25 + + invrefcoco: + batch_size: 6 + sample_ratio: 10 + + invrefcocop: + batch_size: 6 + sample_ratio: 10 + + invrefcocog: + batch_size: 6 + sample_ratio: 10 + + + coco_vqa: + batch_size: 6 + sample_ratio: 15 + + ok_vqa: + batch_size: 6 + sample_ratio: 8 + + aok_vqa: + batch_size: 6 + sample_ratio: 12 + + gqa: + batch_size: 6 + sample_ratio: 50 + + ocrvqa: + batch_size: 6 + sample_ratio: 30 \ No newline at end of file diff --git a/test1.txt b/test1.txt new file mode 100644 index 0000000..336dc71 --- /dev/null +++ b/test1.txt @@ -0,0 +1,109 @@ +from torchviz import make_dot +dot = make_dot(query_output.last_hidden_state, params=dict(self.Qformer.bert.named_parameters())) +log_dir = '/mnt/pfs-guan-ssai/nlu/wanghanzi/multimodal/PromptMoE/' +dot.render(filename="Post_Route_Universal_PromptMoE_RawProb_backward_graph", directory=log_dir, format="pdf") + + +# Pre-Prompt-MoE +model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[6].experts.experts[0].dense1.weight.grad +model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[0].dense1.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[1].dense1.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[2].dense1.weight.grad + + +model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight +model.Qformer.bert.encoder.layer[9].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[10].intermediate.dense.weight.grad +model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad + +model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight +model.Qformer.bert.encoder.layer[10].experts.experts[2].dense1.weight +model.Qformer.bert.encoder.layer[10].experts.experts[1].dense1.weight +model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight +model.Qformer.bert.encoder.layer[10].intermediate_query.dense.weight == model.Qformer.bert.encoder.layer[10].experts.experts[0].dense1.weight + +# Pre-MoE gate-sentence +# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad 不更新 + +# Pre-MoE gate-token +# 正常更新 + +# Post-MoE gate-sentence +model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad +# model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad 正常更新 +# model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad 全是0/-0 +# model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad 全是0/-0 + +# Route-MoE +# Pre-MoE 算的beam_scores有问题 + +# Post-Route 会更新多个expert的参数;会更新gate的参数 +# Layer 6 更新了两个expert的参数 (layer 6 layer 8) +# model.Qformer.bert.encoder.layer[11].intermediate.dense.weight.grad 是0?都是0 +# model.Qformer.bert.encoder.layer[11].output.dense.weight.grad + +model.Qformer.bert.encoder.layer[6].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[6].experts.experts[0].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[6].experts.experts[1].intermediate_query.dense.weight.grad + +model.Qformer.bert.encoder.layer[7].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[7].experts.experts[0].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[7].experts.experts[1].intermediate_query.dense.weight.grad + +model.Qformer.bert.encoder.layer[8].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[0].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[8].experts.experts[1].intermediate_query.dense.weight.grad + +model.Qformer.bert.encoder.layer[9].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[9].experts.experts[0].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[9].experts.experts[1].intermediate_query.dense.weight.grad + +model.Qformer.bert.encoder.layer[10].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[10].experts.experts[0].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[10].experts.experts[1].intermediate_query.dense.weight.grad + +model.Qformer.bert.encoder.layer[11].experts.gate.weight.grad +model.Qformer.bert.encoder.layer[11].experts.experts[0].intermediate_query.dense.weight.grad +model.Qformer.bert.encoder.layer[11].experts.experts[1].intermediate_query.dense.weight.grad + + +(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.10.experts.experts.0.dense1.weight'] +[Parameter containing: +tensor([[-0.0328, 0.0414, 0.0010, ..., -0.0068, 0.0244, 0.0587], + [ 0.0120, 0.0458, 0.0171, ..., -0.0439, -0.0107, -0.0397], + [ 0.0239, 0.0191, -0.0145, ..., 0.0008, -0.0067, 0.0090], + ..., + [ 0.0174, -0.0465, -0.0106, ..., -0.0095, 0.0153, -0.0195], + [-0.0151, -0.0082, -0.0320, ..., -0.0016, -0.0232, -0.0147], + [ 0.0142, -0.0286, 0.0161, ..., -0.0160, -0.0306, -0.0272]], + device='cuda:0', requires_grad=True)] +(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.8.experts.experts.0.dense1.weight'] +[Parameter containing: +tensor([[ 0.0024, 0.0218, -0.0186, ..., -0.0178, -0.0067, 0.0820], + [-0.0759, -0.0002, -0.0548, ..., 0.0292, 0.0531, 0.0779], + [-0.0220, -0.0037, -0.0520, ..., -0.0426, -0.0261, -0.0357], + ..., + [-0.0448, 0.0471, 0.0133, ..., -0.0062, -0.0217, -0.0203], + [ 0.0532, 0.0197, 0.0320, ..., -0.0010, -0.0838, 0.0682], + [ 0.0284, 0.0038, -0.0007, ..., -0.0305, 0.0296, 0.0056]], + device='cuda:0', requires_grad=True)] +(Pdb) [p for n, p in self.model.named_parameters() if n == 'Qformer.bert.encoder.layer.6.experts.experts.0.dense1.weight'] +[Parameter containing: +tensor([[ 6.5176e-02, -4.6473e-02, -2.7396e-02, ..., 2.1774e-03, + 6.1457e-02, 1.9180e-03], + [ 7.3707e-03, 6.1392e-02, -2.7108e-02, ..., 4.0778e-02, + -1.9791e-02, -1.1612e-02], + [ 2.1193e-02, -3.8323e-02, -6.0238e-02, ..., -1.4539e-02, + 9.2965e-02, 3.9153e-02], + ..., + [ 5.3203e-03, -1.7276e-02, -3.2191e-02, ..., -1.6435e-02, + -1.8553e-02, -2.8158e-02], + [-6.9853e-02, 9.2719e-03, -1.8895e-03, ..., -2.6425e-02, + 1.4880e-03, 3.4505e-02], + [-1.2168e-03, 3.7038e-02, 4.8047e-02, ..., -3.4523e-03, + -1.3030e-05, -1.4778e-02]], device='cuda:0', requires_grad=True)] \ No newline at end of file diff --git a/test_text_cap.py b/test_text_cap.py new file mode 100644 index 0000000..67dcb01 --- /dev/null +++ b/test_text_cap.py @@ -0,0 +1,5 @@ +from minigpt4.common.caption_tools.caption_utils import coco_caption_eval, textcaps_caption_eval +result_file = "/mnt/pfs-guan-ssai/nlu/wanghanzi/experiments/blip2/vicuna7b/qformer_moe_uni_route/mix_coco_gqa_ao_cocap_textcap_raw_QformerMoE_Post_Route_Universal_lnout_lr5e5_3ex3b_2loss_005_top6layer_textinqf_12epo_0317/20240317165/result/val_vqa_result_text_cap.json" +annotaion_file = "/mnt/pfs-guan-ssai/nlu/wanghanzi/data/TextCap/TextCaps_0.1_val.json" +eval = textcaps_caption_eval(annotaion_file, result_file) +print(eval.eval.item()) \ No newline at end of file diff --git a/train.py b/train.py index 4dead8e..944cebd 100644 --- a/train.py +++ b/train.py @@ -88,15 +88,17 @@ def main(): task = tasks.setup_task(cfg) datasets = task.build_datasets(cfg) model = task.build_model(cfg) + task.build_tensorboard(cfg) if cfg.run_cfg.wandb_log: wandb.login() - wandb.init(project="minigptv", name=cfg.run_cfg.job_name) + wandb.init(project="promptmoe", name=cfg.run_cfg.job_name) wandb.watch(model) runner = get_runner_class(cfg)( cfg=cfg, job_id=job_id, task=task, model=model, datasets=datasets ) + runner.train() diff --git a/train_configs/minigptv2_finetune.yaml b/train_configs/minigptv2_finetune.yaml deleted file mode 100644 index 114d7e9..0000000 --- a/train_configs/minigptv2_finetune.yaml +++ /dev/null @@ -1,294 +0,0 @@ -model: - arch: minigpt_v2 - model_type: pretrain - max_txt_len: 1024 - image_size: 448 - end_sym: "" - llama_model: "/path/to/llama_checkpoint" - ckpt: "/path/to/pretrained_checkpoint" - use_grad_checkpoint: True - chat_template: True - lora_r: 64 - lora_alpha: 16 - -datasets: - multitask_conversation: - batch_size: 2 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 50 - - llava_conversation: - batch_size: 2 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 30 - - unnatural_instruction: - batch_size: 1 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 10 - - - refvg: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 40 - - llava_detail: - batch_size: 4 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 20 - - llava_reason: - batch_size: 4 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 80 - - - flickr_grounded_caption: - batch_size: 2 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 80 - - flickr_CaptionToPhrase: - batch_size: 2 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 80 - - flickr_ObjectToPhrase: - batch_size: 2 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 80 - - coco_caption: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 10 - - - textcaps_caption: # - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 30 - - refcoco: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 25 - - - refcocop: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 25 - - refcocog: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 25 - - - - invrefcoco: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 10 - - invrefcocop: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 10 - - invrefcocog: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 10 - - - coco_vqa: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 15 - - ok_vqa: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 8 - - aok_vqa: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 12 - - gqa: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 50 - - ocrvqa: - batch_size: 6 - vis_processor: - train: - name: "blip2_image_train" - image_size: 448 - text_processor: - train: - name: "blip_caption" - sample_ratio: 30 - - -run: - task: image_text_pretrain - # optimizer - lr_sched: "linear_warmup_cosine_lr" - init_lr: 1e-5 - min_lr: 8e-5 - warmup_lr: 1e-6 - - weight_decay: 0.05 - max_epoch: 50 - num_workers: 6 - warmup_steps: 1000 - iters_per_epoch: 1000 - - seed: 42 - output_dir: "/path/to/save_checkpoint" - - amp: True - resume_ckpt_path: null - - evaluate: False - train_splits: ["train"] - - device: "cuda" - world_size: 1 - dist_url: "env://" - distributed: True - - wandb_log: True - job_name: minigptv2_finetune \ No newline at end of file