auto-trading/tests/test_models.py

72 lines
1.9 KiB
Python

import tempfile

import torch

from src.models import rl_model, trading_model, transformer_model
def test_rl_model():
    """Smoke-test ``RLModel`` with a single forward pass on random data.

    Verifies that the output has shape ``(1, model.output_size)`` and is
    located on ``model.device`` (both attributes are expected to be exposed
    by ``RLModel``).
    """
    # The mock input is built before the model so the random draws happen in
    # the same order as model weight initialisation would otherwise disturb.
    batch = torch.rand((1, 10))
    model = rl_model.RLModel()

    prediction = model(batch)

    expected_shape = torch.Size([1, model.output_size])
    assert prediction.size() == expected_shape
    assert prediction.device == model.device
def test_trading_model():
    """Smoke-test ``TradingModel`` with a single forward pass on random data.

    Verifies that the output has shape ``(1, model.output_size)`` and is
    located on ``model.device`` (both attributes are expected to be exposed
    by ``TradingModel``).
    """
    # Random (1, 10) input, created before the model is instantiated.
    batch = torch.rand((1, 10))
    model = trading_model.TradingModel()

    prediction = model(batch)

    expected_shape = torch.Size([1, model.output_size])
    assert prediction.size() == expected_shape
    assert prediction.device == model.device
def test_transformer_model():
    """Smoke-test ``TransformerModel`` on one random 20-token sequence.

    Verifies that the output has shape ``(1, 20, model.hidden_size)`` and is
    located on ``model.device``.
    """
    batch_size, seq_len = 1, 20

    # Token ids are drawn from [0, 100); this assumes the model's vocabulary
    # has at least 100 entries -- TODO confirm against the model config.
    input_ids = torch.randint(0, 100, (batch_size, seq_len))
    # All-ones mask of the same shape as input_ids.
    attention_mask = torch.ones((batch_size, seq_len))

    model = transformer_model.TransformerModel()
    hidden_states = model(input_ids, attention_mask)

    expected_shape = torch.Size([batch_size, seq_len, model.hidden_size])
    assert hidden_states.size() == expected_shape
    assert hidden_states.device == model.device
def test_model_save_load():
    """Round-trip a ``TransformerModel`` through save/load and compare weights.

    The checkpoint is written to a temporary directory that is deleted when
    the test finishes, so no ``test_model`` folder is left behind in the
    working directory and repeated runs cannot pick up a stale checkpoint.
    """
    model = transformer_model.TransformerModel()

    with tempfile.TemporaryDirectory() as checkpoint_dir:
        model.save_pretrained(checkpoint_dir)
        loaded_model = transformer_model.TransformerModel.from_pretrained(
            checkpoint_dir
        )

        # Compare inside the ``with`` block so the checkpoint files still
        # exist in case the loader keeps references to them.
        original_params = dict(model.named_parameters())
        loaded_params = dict(loaded_model.named_parameters())

        # zip() would stop at the shorter sequence and silently skip any
        # missing or extra parameter; compare the parameter names explicitly.
        assert original_params.keys() == loaded_params.keys()

        for name, original in original_params.items():
            # torch.equal requires identical shapes and values; Tensor.eq
            # would broadcast and could accept mismatched shapes.
            assert torch.equal(original, loaded_params[name]), name
if __name__ == "__main__":
    # Allow running this module directly (without pytest); tests run in
    # file order and the first failing assertion stops the run.
    for _run_test in (
        test_rl_model,
        test_trading_model,
        test_transformer_model,
        test_model_save_load,
    ):
        _run_test()