72 lines
1.9 KiB
Python
72 lines
1.9 KiB
Python
|
import torch
|
||
|
from src.models import rl_model, trading_model, transformer_model
|
||
|
|
||
|
def test_rl_model():
    """Smoke-test RLModel: forward pass produces a correctly shaped
    tensor on the model's own device."""
    # A single mock observation of feature width 10.
    batch = torch.rand((1, 10))

    rl_net = rl_model.RLModel()
    prediction = rl_net(batch)

    # Output must be (batch, output_size) and live on the model's device.
    assert prediction.size() == torch.Size([1, rl_net.output_size])
    assert prediction.device == rl_net.device
|
||
|
|
||
|
def test_trading_model():
    """Smoke-test TradingModel: forward pass produces a correctly shaped
    tensor on the model's own device."""
    # One mock sample with 10 input features.
    features = torch.rand((1, 10))

    trader = trading_model.TradingModel()
    result = trader(features)

    # Shape contract: (batch, output_size); device contract: model's device.
    assert result.size() == torch.Size([1, trader.output_size])
    assert result.device == trader.device
|
||
|
|
||
|
def test_transformer_model():
    """Smoke-test TransformerModel: a token batch with a full attention
    mask yields hidden states of shape (batch, seq_len, hidden_size)."""
    seq_len = 20
    # Random token ids in [0, 100) plus an all-ones (no padding) mask.
    token_ids = torch.randint(0, 100, (1, seq_len))
    mask = torch.ones((1, seq_len))

    net = transformer_model.TransformerModel()
    hidden = net(token_ids, mask)

    # One hidden vector per token, on the model's device.
    assert hidden.size() == torch.Size([1, seq_len, net.hidden_size])
    assert hidden.device == net.device
|
||
|
|
||
|
def test_model_save_load():
    """Round-trip a TransformerModel through save_pretrained /
    from_pretrained and verify every parameter tensor is unchanged.

    Fix: the original wrote to a hard-coded 'test_model' directory in the
    CWD and never removed it, polluting the workspace and making the test
    non-isolated. A temporary directory scopes the artifacts to this test.
    """
    import tempfile

    model = transformer_model.TransformerModel()

    # Save/load inside a self-cleaning temporary directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        loaded_model = transformer_model.TransformerModel.from_pretrained(tmp_dir)

    # Parameters are compared pairwise in registration order; torch.equal
    # checks both shape and exact element equality in one call.
    for original, restored in zip(model.parameters(), loaded_model.parameters()):
        assert torch.equal(original, restored)
|
||
|
|
||
|
if __name__ == "__main__":
|
||
|
test_rl_model()
|
||
|
test_trading_model()
|
||
|
test_transformer_model()
|
||
|
test_model_save_load()
|