"""Unit tests for ``src.models.transformer_model.TransformerModel``."""
import tempfile

import pytest
import torch

from src.models import transformer_model


def test_transformer_model():
    """Forward pass yields a ``hidden_size`` vector per token on the model's device."""
    batch_size, seq_len = 1, 20
    # NOTE(review): token ids are drawn from [0, 100); this assumes the
    # model's vocabulary has at least 100 entries — TODO confirm.
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    attention_mask = torch.ones((batch_size, seq_len))

    model = transformer_model.TransformerModel()

    # Gradients are not needed to check output shape and device.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    assert outputs.size() == torch.Size([batch_size, seq_len, model.hidden_size])
    assert outputs.device == model.device


def test_model_save_load():
    """A model round-tripped through save_pretrained/from_pretrained keeps its parameters."""
    model = transformer_model.TransformerModel()

    # Save into a throwaway directory so the test leaves no artifacts behind
    # and neither depends on nor clobbers the current working directory.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        loaded_model = transformer_model.TransformerModel.from_pretrained(save_dir)

        original_params = list(model.named_parameters())
        loaded_params = list(loaded_model.named_parameters())

        # zip() stops at the shorter sequence, so compare counts explicitly;
        # otherwise missing or extra parameters would go unnoticed.
        assert len(original_params) == len(loaded_params)
        for (name, p1), (loaded_name, p2) in zip(original_params, loaded_params):
            assert name == loaded_name
            # torch.equal also requires matching shapes, unlike an
            # elementwise eq(), which broadcasts.
            assert torch.equal(p1, p2), f"parameter {name!r} changed after reload"