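# auto-trading/notebooks/model_prototyping.py
#
# Notebook-style prototyping script (cells are delimited by `# %%`) for
# comparing Transformer configurations and training hyperparameters for the
# trading agent.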

# %% Import required packages
import torch
from src.models.transformer_model import TransformerModel
from src.models.rl_model import RLModel
from models.trading_model import TradingAgent
from src.training.train import train_transformer, train_rl
from src.data.data_preprocessing import load_processed_data
# %% Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# %% Load processed data
data = load_processed_data('./data/processed/processed_data.csv')


def _chronological_split(dataset, train_fraction=0.8):
    """Split *dataset* into ``(train, validation)`` without shuffling.

    The first ``train_fraction`` of rows become the training set and the
    remaining rows the validation set. Positional slicing goes through
    ``.iloc`` when the object has one (pandas) and plain slicing otherwise.

    Raises:
        ValueError: if ``train_fraction`` is not strictly between 0 and 1.
    """
    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction!r}")
    cut = int(len(dataset) * train_fraction)
    rows = getattr(dataset, "iloc", dataset)
    return rows[:cut], rows[cut:]


# Hold out the last 20% of rows so the models are scored on data they were
# not trained on. No shuffling, to avoid look-ahead if rows are time-ordered.
# NOTE(review): assumes `data` supports len() and positional row slicing
# (e.g. a pandas DataFrame) and is ordered in time — confirm against
# load_processed_data.
train_data, validation_data = _chronological_split(data)
# %% Initialize models
# NOTE(review): TradingAgent is imported from `models.trading_model`, while
# every other project import uses the `src.` prefix — confirm the path.
transformer_model = TransformerModel().to(device)
rl_model = RLModel().to(device)
trading_agent = TradingAgent(transformer_model, rl_model)
# %% Test out different model architectures
# Note: This is a simplified example. In a real project, you might want to try out different model architectures and hyperparameters.
# Try out a different Transformer model configuration. Both agents share the
# same `rl_model` instance, so only the Transformer component differs.
transformer_model_2 = TransformerModel(config={'num_layers': 4, 'num_heads': 8}).to(device)
trading_agent_2 = TradingAgent(transformer_model_2, rl_model)
# %% Test out different hyperparameters
# Note: This is a simplified example. In a real project, you might want to try out different hyperparameters and training strategies.
transformer_model_hyperparams_2 = {
    "epochs": 5,
    "batch_size": 64,
    "learning_rate": 0.001,
}
# Train the model with the new hyperparameters on the training split only;
# `validation_data` is kept aside for evaluation.
# NOTE(review): the baseline `transformer_model` is never trained here, so a
# loss comparison against `transformer_model_2` is not like-for-like.
train_transformer(transformer_model_2, train_data, transformer_model_hyperparams_2)
# %% Evaluate the models
# Note: This is a simplified example. In a real project, you might want to evaluate the models on a validation set and compare their performance.
# You might have a function like this one to compute the validation loss.
def compute_validation_loss(model, validation_data):
    """Return the validation loss of *model* on *validation_data*.

    Placeholder: no loss computation is implemented yet, so this always
    returns ``None``.

    Args:
        model: The agent/model to evaluate (this script passes
            ``TradingAgent`` instances).
        validation_data: Held-out data to score the model on.

    Returns:
        ``None`` until a real implementation is added.
    """
    # TODO: implement the actual loss computation. Returning None explicitly
    # keeps the notebook runnable end to end in the meantime.
    return None
# Score the baseline agent and the reconfigured agent, in that order, on the
# same held-out data, then report both losses.
validation_loss, validation_loss_2 = (
    compute_validation_loss(agent, validation_data)
    for agent in (trading_agent, trading_agent_2)
)
print(f'Validation loss for the original model: {validation_loss}')
print(f'Validation loss for the new model: {validation_loss_2}')