"""Unit tests for the transformer model defined in ``src/models/transformer_model.py``."""
import tempfile

import pytest
import torch

from src.models import transformer_model
def test_transformer_model():
    """Forward pass yields one hidden vector per token, on the model's device."""
    # Mock batch: a single sequence of 20 token ids from a 100-token vocab,
    # with a full (all-ones) attention mask.
    token_ids = torch.randint(0, 100, (1, 20))
    mask = torch.ones((1, 20))

    model = transformer_model.TransformerModel()

    result = model(token_ids, mask)

    # Expect shape (batch, seq_len, hidden_size): one hidden state per position.
    assert result.size() == torch.Size([1, 20, model.hidden_size])
    # The output tensor should live on the same device as the model itself.
    assert result.device == model.device
|
def test_model_save_load():
    """A save_pretrained/from_pretrained round trip preserves every parameter.

    The original version saved to a hard-coded ``'test_model'`` directory in
    the current working directory and never removed it, leaving artifacts
    behind and letting stale files from a previous run leak into later runs.
    A temporary directory makes the test hermetic.
    """
    model = transformer_model.TransformerModel()

    # Round-trip through a throwaway directory that is deleted on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        loaded_model = transformer_model.TransformerModel.from_pretrained(tmp_dir)

    # torch.equal also verifies matching shape/dtype, unlike the original
    # torch.all(p1.eq(p2)), which can silently broadcast mismatched shapes.
    for (name, p1), (_, p2) in zip(
        model.named_parameters(), loaded_model.named_parameters()
    ):
        assert torch.equal(p1, p2), f"parameter {name!r} changed after save/load"