auto-trading/tests/test_transformer_model.py

35 lines
1016 B
Python

import os
import tempfile

import pytest
import torch

from src.models import transformer_model
def test_transformer_model():
    """A forward pass yields one hidden vector per token, on the model's device."""
    batch_size, seq_len, vocab_size = 1, 20, 100

    # Instantiate the model. eval() disables train-time randomness (e.g.
    # dropout); it does not affect the output shape being checked.
    model = transformer_model.TransformerModel()
    model.eval()

    # Create mock input data directly on the model's device so the test also
    # works when the model is initialised on a GPU.
    input_ids = torch.randint(
        0, vocab_size, (batch_size, seq_len), device=model.device
    )
    # NOTE(review): float mask matches the original test; confirm whether
    # TransformerModel expects an integer (long) mask instead.
    attention_mask = torch.ones((batch_size, seq_len), device=model.device)

    # Forward pass without building an autograd graph (no backward here).
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)

    # Check output dimensions: (batch, seq_len, hidden_size).
    assert outputs.size() == torch.Size([batch_size, seq_len, model.hidden_size])
    # Check that the output lives on the same device as the model.
    assert outputs.device == model.device
def test_model_save_load():
    """Round-tripping through save_pretrained/from_pretrained preserves every parameter."""
    model = transformer_model.TransformerModel()

    # Save into a throwaway directory so the test leaves no artefacts in the
    # current working directory. The comparison stays inside the `with` so
    # the checkpoint files still exist if the loader memory-maps them.
    with tempfile.TemporaryDirectory() as tmp_dir:
        save_dir = os.path.join(tmp_dir, "test_model")
        model.save_pretrained(save_dir)
        loaded_model = transformer_model.TransformerModel.from_pretrained(save_dir)

        original_params = dict(model.named_parameters())
        loaded_params = dict(loaded_model.named_parameters())

        # Both models must expose exactly the same parameters. A plain zip()
        # would silently stop at the shorter sequence and miss a weight that
        # was dropped or added during the round trip.
        assert original_params.keys() == loaded_params.keys()

        for name, original in original_params.items():
            # torch.equal requires identical shape AND values; Tensor.eq
            # broadcasts and could let a shape mismatch pass unnoticed.
            assert torch.equal(original, loaded_params[name]), name