# %% Import required packages
import torch

from src.models.transformer_model import TransformerModel
from src.models.rl_model import RLModel
# NOTE(review): every other project import here is rooted at ``src.``; this
# one is not. Left unchanged -- confirm whether ``src.models.trading_model``
# was intended.
from models.trading_model import TradingAgent
from src.training.train import train_transformer, train_rl
from src.data.data_preprocessing import load_processed_data

# %% Set device
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# %% Load processed data
data = load_processed_data('./data/processed/processed_data.csv')

# %% Initialize models
transformer_model = TransformerModel().to(device)
rl_model = RLModel().to(device)
trading_agent = TradingAgent(transformer_model, rl_model)

# %% Test out different model architectures
# Note: This is a simplified example. In a real project, you might want to
# try out several model architectures and hyperparameter settings.

# Try out a different Transformer model configuration.
transformer_model_2 = TransformerModel(
    config={'num_layers': 4, 'num_heads': 8},
).to(device)
# Both agents are built around the same ``rl_model`` instance; only the
# Transformer differs between them.
trading_agent_2 = TradingAgent(transformer_model_2, rl_model)

# %% Test out different hyperparameters
# Note: This is a simplified example. In a real project, you might want to
# try out several hyperparameter settings and training strategies.
transformer_model_hyperparams_2 = {
    "epochs": 5,
    "batch_size": 64,
    "learning_rate": 0.001,
}

# Train the model with the new hyperparameters.
# NOTE(review): only ``transformer_model_2`` is passed to ``train_transformer``
# in this script; ``transformer_model`` and ``rl_model`` are never trained
# here, so any later comparison pits a trained model against an untrained
# baseline -- confirm this is intended.
train_transformer(transformer_model_2, data, transformer_model_hyperparams_2)

# %% Evaluate the models
# Note: This is a simplified example. In a real project, you might want to
# evaluate the models on a validation set and compare their performance.
# You might have a function like this one to compute the validation loss.
def compute_validation_loss(model, validation_data):
    """Placeholder for computing a model's validation loss.

    This function is not implemented in this example: both arguments are
    ignored and ``None`` is returned.

    Args:
        model: The model to evaluate (this script passes trading agents).
        validation_data: The data the model should be evaluated on.

    Returns:
        ``None`` until a real implementation is supplied.
    """
    return None


# No held-out split is created in this script, so the full processed dataset
# (``data``, loaded earlier in this script) stands in as the validation set.
# TODO: replace with a proper held-out split before trusting these numbers;
# evaluating on the training data gives an in-sample loss, not a validation
# loss.
validation_data = data

# Compute the validation loss for the original model and the new model.
validation_loss = compute_validation_loss(trading_agent, validation_data)
validation_loss_2 = compute_validation_loss(trading_agent_2, validation_data)

print(f'Validation loss for the original model: {validation_loss}')
print(f'Validation loss for the new model: {validation_loss_2}')