# auto-trading/tests/test_training.py (Python, 39 lines, 1.1 KiB)

import torch
from src.models import rl_model, transformer_model
from src.agents import Agent
from src.training import train
from src.utils import seed_everything, scale_data
def test_training():
    """Smoke-test that a call to ``train`` updates the agent's weights.

    Builds random input/target tensors, scales them with ``scale_data``,
    combines a ``TransformerModel`` and an ``RLModel`` into an ``Agent``,
    trains it with MSE loss and Adam, and asserts that at least one of the
    agent's parameter tensors changed.
    """
    # Set a seed for reproducibility
    seed_everything(42)

    # Create mock data
    input_data = torch.rand((100, 10))
    target_data = torch.rand((100, 10))

    # Scale the data (the second value returned by scale_data is not used)
    input_data, _ = scale_data(input_data)
    target_data, _ = scale_data(target_data)

    # Instantiate the models and the agent that wraps them
    transformer = transformer_model.TransformerModel()
    rl = rl_model.RLModel()
    agent = Agent(transformer_model=transformer, rl_model=rl)

    # Define loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(agent.parameters())

    # Snapshot every parameter BEFORE training. A snapshot taken after
    # train() would be compared against itself and could never differ.
    initial_weights = [p.detach().clone() for p in agent.parameters()]
    assert initial_weights, "agent exposes no parameters to train"

    # Train the agent
    train(agent, input_data, target_data, criterion, optimizer)

    # Check that model weights have been updated. Comparing all parameter
    # tensors (not just the first) avoids a spurious failure when one
    # particular tensor is frozen or receives no gradient.
    final_weights = [p.detach() for p in agent.parameters()]
    assert any(
        not torch.equal(before, after)
        for before, after in zip(initial_weights, final_weights)
    ), "no model weights were updated by train()"
if __name__ == "__main__":
    # Allow running this smoke test directly as a script, without pytest.
    test_training()