# File-viewer metadata captured with this listing: 54 lines, 2.2 KiB, Python
# %% Import required packages
|
|
import torch
|
|
from src.models.transformer_model import TransformerModel
|
|
from src.models.rl_model import RLModel
|
|
from models.trading_model import TradingAgent
|
|
from src.training.train import train_transformer, train_rl
|
|
from src.data.data_preprocessing import load_processed_data
|
|
|
|
# %% Set device

# Run on a CUDA GPU when PyTorch reports one is available; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
|
# %% Load processed data

# Relative path to the preprocessed dataset; presumably resolved against the
# current working directory by load_processed_data — confirm.
processed_data_path = './data/processed/processed_data.csv'
data = load_processed_data(processed_data_path)
|
|
|
|
# %% Initialize models

# Baseline agent: a Transformer model and an RL model, each moved to `device`,
# combined into one TradingAgent.
# NOTE(review): TradingAgent is imported from `models.trading_model`, while the
# other project modules come from `src.*` — confirm that import path resolves.
transformer_model = TransformerModel()
transformer_model = transformer_model.to(device)

rl_model = RLModel()
rl_model = rl_model.to(device)

trading_agent = TradingAgent(transformer_model, rl_model)
|
|
|
|
# %% Test out different model architectures

# Note: This is a simplified example. A real project would sweep over several
# architectures and hyperparameters rather than a single alternative.

# Alternative Transformer configuration (num_layers=4, num_heads=8).
transformer_config_2 = {'num_layers': 4, 'num_heads': 8}
transformer_model_2 = TransformerModel(config=transformer_config_2).to(device)

# NOTE(review): this agent receives the same `rl_model` object as `trading_agent`.
# Unless TradingAgent copies it, both agents share one RL model — confirm that is
# intended when comparing them.
trading_agent_2 = TradingAgent(transformer_model_2, rl_model)
|
|
|
|
# %% Test out different hyperparameters

# Note: This is a simplified example. In a real project, you might want to try out different hyperparameters and training strategies.

transformer_model_hyperparams_2 = {
    "epochs": 5,
    "batch_size": 64,
    "learning_rate": 0.001,
}

# The evaluation cell reads `validation_data`. Define it here from rows held out
# of training, so the models are scored on data they have not been fit to.
# The split is chronological (the last VALIDATION_FRACTION of rows, no shuffling).
# This assumes the rows are time-ordered market data — TODO confirm.
# It also assumes `data` supports len() and positional slicing (e.g. a pandas
# DataFrame or a tensor) — verify against load_processed_data.
VALIDATION_FRACTION = 0.2
_split_index = len(data) - int(len(data) * VALIDATION_FRACTION)
# Both splits must be non-empty, or training/evaluation would silently see no rows.
if not 0 < _split_index < len(data):
    raise ValueError(
        f"Need more rows to hold out a validation split (got {len(data)})"
    )
train_data = data[:_split_index]
validation_data = data[_split_index:]

# Train the model with the new hyperparameters, on the training split only.
train_transformer(transformer_model_2, train_data, transformer_model_hyperparams_2)
|
|
|
|
# %% Evaluate the models
|
|
# Note: This is a simplified example. In a real project, you might want to evaluate the models on a validation set and compare their performance.
|
|
|
|
# You might have a function like this one to compute the validation loss
|
|
def compute_validation_loss(model, validation_data):
    """Placeholder for computing ``model``'s loss on ``validation_data``.

    Not implemented in this example: both arguments are ignored and the
    function always returns ``None``.
    """
    # TODO: evaluate `model` on `validation_data` and return the resulting loss.
    return None
|
|
|
|
# Score the original agent first, then the new one, on `validation_data`; report both afterwards.
# NOTE(review): relies on a `validation_data` global defined earlier in the script.
# Also, as shown, only `transformer_model_2` is passed to train_transformer and
# train_rl is never called. The original agent's Transformer therefore appears
# untrained here, so the comparison may not be like-for-like — confirm.
validation_loss, validation_loss_2 = [
    compute_validation_loss(agent, validation_data)
    for agent in (trading_agent, trading_agent_2)
]

print(f'Validation loss for the original model: {validation_loss}')
print(f'Validation loss for the new model: {validation_loss_2}')
|