# %% Import required packages
# Evaluation script: loads the processed data and trained model weights,
# runs the trading agent through `evaluate_trading_agent`, prints and plots
# the resulting metrics, and writes them under ./logs.
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch

from src.models.transformer_model import TransformerModel
from src.models.rl_model import RLModel
# NOTE(review): this import omits the `src.` prefix used by the sibling
# imports -- confirm it resolves on the intended PYTHONPATH.
from models.trading_model import TradingAgent
from src.evaluation.evaluate import evaluate_trading_agent
from src.data.data_preprocessing import load_processed_data

# %% Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# %% Load processed data
data = load_processed_data('./data/processed/processed_data.csv')

# %% Initialize models
transformer_model = TransformerModel().to(device)
rl_model = RLModel().to(device)
trading_agent = TradingAgent(transformer_model, rl_model)

# %% Load model weights
# map_location remaps the saved tensors onto the selected device, so a
# checkpoint written on a GPU machine still loads on a CPU-only host.
transformer_model.load_state_dict(
    torch.load('./models/transformer_model.pth', map_location=device)
)
rl_model.load_state_dict(
    torch.load('./models/rl_model.pth', map_location=device)
)

# Switch to inference mode so any dropout / batch-norm layers behave
# deterministically during evaluation.
transformer_model.eval()
rl_model.eval()

# %% Evaluate the trading agent
evaluation_results = evaluate_trading_agent(trading_agent, data)

# %% Display evaluation results
print("Total Profit: ", evaluation_results['total_profit'])
print("Total Trades Made: ", evaluation_results['total_trades'])
print("Successful Trades: ", evaluation_results['successful_trades'])

# %% Plot profit over time
plt.plot(evaluation_results['profits_over_time'])
plt.xlabel('Time')
plt.ylabel('Profit')
plt.title('Profit Over Time')
plt.show()

# %% Plot trade outcomes
successful_trades = evaluation_results['successful_trades']
unsuccessful_trades = evaluation_results['total_trades'] - successful_trades
plt.bar(
    ['Successful Trades', 'Unsuccessful Trades'],
    [successful_trades, unsuccessful_trades],
)
plt.xlabel('Trade Outcome')
plt.ylabel('Number of Trades')
plt.title('Trade Outcomes')
plt.show()

# %% Save evaluation results
# Create the output directory up front; both writes below would otherwise
# raise FileNotFoundError on a fresh checkout.
os.makedirs('./logs', exist_ok=True)

with open('./logs/evaluation_results.txt', 'w', encoding='utf-8') as f:
    for key, value in evaluation_results.items():
        # The profit series is written separately as a CSV below, so it is
        # left out of the text summary.
        if key != 'profits_over_time':
            f.write(f'{key}: {value}\n')

# %% Save profits over time as a CSV file
profits_over_time_df = pd.DataFrame(
    evaluation_results['profits_over_time'], columns=['Profit']
)
profits_over_time_df.to_csv('./logs/profits_over_time.csv', index=False)