39 lines
1.1 KiB
Python
39 lines
1.1 KiB
Python
import torch
|
|
from src.models import rl_model, transformer_model
|
|
from src.agents import Agent
|
|
from src.training import train
|
|
from src.utils import seed_everything, scale_data
|
|
|
|
def test_training():
    """Smoke-test that ``train`` updates the agent's parameters.

    Builds random mock input/target tensors, scales each with
    ``scale_data``, wires a ``TransformerModel`` and an ``RLModel`` into an
    ``Agent``, then runs ``train`` with an MSE loss and an Adam optimizer.
    The test passes if at least one of the agent's parameter tensors differs
    from its pre-training snapshot.
    """
    # Set a seed for reproducibility
    seed_everything(42)

    # Create mock data
    input_data = torch.rand((100, 10))
    target_data = torch.rand((100, 10))

    # Scale the data; the second value returned by scale_data is not needed here.
    input_data, _ = scale_data(input_data)
    target_data, _ = scale_data(target_data)

    # Instantiate the models
    transformer = transformer_model.TransformerModel()
    rl = rl_model.RLModel()

    # Instantiate the agent
    agent = Agent(transformer_model=transformer, rl_model=rl)

    # Define loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(agent.parameters())

    # Snapshot every parameter BEFORE training. Taking the snapshot after
    # train() would compare each tensor with a copy of itself, so the
    # "weights changed" assertion could never pass.
    initial_params = [p.detach().clone() for p in agent.parameters()]

    # Train the agent
    train(agent, input_data, target_data, criterion, optimizer)

    # Check that model weights have been updated. Compare every tensor rather
    # than only the first one, which might not receive any gradient.
    changed = any(
        not torch.equal(before, after.detach())
        for before, after in zip(initial_params, agent.parameters())
    )
    assert changed, "no agent parameter changed during training"
|
|
|
|
# Running this module directly (rather than through pytest) executes the test.
if "__main__" == __name__:
    test_training()
|