import torch

from src.agents import Agent
from src.models import rl_model, transformer_model
from src.training import train
from src.utils import scale_data, seed_everything


def test_training():
    """Smoke-test that one call to ``train`` changes the agent's weights.

    Builds random input/target tensors of shape (100, 10), scales them,
    wires a ``TransformerModel`` and an ``RLModel`` into an ``Agent``, runs
    ``train`` with an MSE loss and an Adam optimizer, and asserts that the
    agent's first parameter tensor differs from a snapshot taken *before*
    training.

    Raises:
        AssertionError: If the first parameter is unchanged after training.
    """
    # Set a seed for reproducibility
    seed_everything(42)

    # Create mock data
    input_data = torch.rand((100, 10))
    target_data = torch.rand((100, 10))

    # Scale the data; the second value returned by scale_data is not needed.
    input_data, _ = scale_data(input_data)
    target_data, _ = scale_data(target_data)

    # Instantiate the models
    transformer = transformer_model.TransformerModel()
    rl = rl_model.RLModel()

    # Instantiate the agent
    agent = Agent(transformer_model=transformer, rl_model=rl)

    # Define loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(agent.parameters())

    # Snapshot the first parameter BEFORE training. Taking the snapshot after
    # training would compare the updated weights with themselves. detach() +
    # clone() gives an independent copy that in-place updates cannot touch.
    initial_weights = list(agent.parameters())[0].detach().clone()

    # Train the agent
    # NOTE(review): presumably train() updates the agent in place through the
    # optimizer -- confirm against src.training.train.
    train(agent, input_data, target_data, criterion, optimizer)

    # Check that model weights have been updated
    trained_weights = list(agent.parameters())[0].detach()
    assert not torch.equal(initial_weights, trained_weights), (
        "first parameter tensor is unchanged after training"
    )


if __name__ == "__main__":
    test_training()