# File info (from source viewer): 60 lines, 2.1 KiB, Python
# %% Import required packages

from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import torch

# NOTE(review): every other project import lives under `src.`; confirm this
# path is not meant to be `src.models.trading_model`.
from models.trading_model import TradingAgent
from src.data.data_preprocessing import load_processed_data
from src.evaluation.evaluate import evaluate_trading_agent
from src.models.rl_model import RLModel
from src.models.transformer_model import TransformerModel


# %% Set device

# Run on the GPU whenever PyTorch reports one is usable; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# %% Load processed data

# Relative path: presumably resolved against the current working directory,
# so run the script from the project root (confirm in load_processed_data).
data = load_processed_data('./data/processed/processed_data.csv')


# %% Initialize models

transformer_model = TransformerModel().to(device)
rl_model = RLModel().to(device)
# The agent receives these same module objects; load_state_dict below updates
# them in place. This assumes TradingAgent stores references rather than
# copies (TODO confirm in models/trading_model.py).
trading_agent = TradingAgent(transformer_model, rl_model)


# %% Load model weights

# map_location remaps every stored tensor onto the selected device. Without
# it, torch.load restores tensors onto the device they were saved from, which
# fails on a CPU-only machine when the checkpoint was written from a GPU.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
transformer_model.load_state_dict(
    torch.load('./models/transformer_model.pth', map_location=device)
)
rl_model.load_state_dict(
    torch.load('./models/rl_model.pth', map_location=device)
)

# Switch to inference mode so layers such as Dropout/BatchNorm (if the models
# contain any) behave deterministically during evaluation.
transformer_model.eval()
rl_model.eval()


# %% Evaluate the trading agent

evaluation_results = evaluate_trading_agent(trading_agent, data)


# %% Display evaluation results

# Each pair is (printed label, key into evaluation_results), shown in order.
for label, key in (
    ("Total Profit: ", 'total_profit'),
    ("Total Trades Made: ", 'total_trades'),
    ("Successful Trades: ", 'successful_trades'),
):
    print(label, evaluation_results[key])


# %% Plot profit over time

# Use the object-oriented Axes API on the current Axes, which is the same
# Axes the plt.plot / plt.xlabel shortcuts draw on.
ax = plt.gca()
ax.plot(evaluation_results['profits_over_time'])
ax.set_xlabel('Time')
ax.set_ylabel('Profit')
ax.set_title('Profit Over Time')
plt.show()


# %% Plot trade outcomes

# A trade counts as unsuccessful when it is not among the successful ones.
successful_trades = evaluation_results['successful_trades']
unsuccessful_trades = evaluation_results['total_trades'] - successful_trades

ax = plt.gca()
ax.bar(['Successful Trades', 'Unsuccessful Trades'],
       [successful_trades, unsuccessful_trades])
ax.set_xlabel('Trade Outcome')
ax.set_ylabel('Number of Trades')
ax.set_title('Trade Outcomes')
plt.show()


# %% Save evaluation results

# open() does not create missing parent directories, so ensure ./logs exists
# before writing into it.
Path('./logs').mkdir(parents=True, exist_ok=True)

# One "key: value" line per scalar metric; explicit encoding avoids depending
# on the platform default.
with open('./logs/evaluation_results.txt', 'w', encoding='utf-8') as f:
    for key, value in evaluation_results.items():
        # The per-step profit series is long; it is saved separately as CSV.
        if key != 'profits_over_time':
            f.write(f'{key}: {value}\n')


# %% Save profits over time as a CSV file

# DataFrame.to_csv does not create missing directories; make sure ./logs
# exists so this cell also works when run on its own.
Path('./logs').mkdir(parents=True, exist_ok=True)

# Single-column table ("Profit"), one row per entry of the profit series;
# the index is omitted from the file.
profits_over_time_df = pd.DataFrame(
    evaluation_results['profits_over_time'], columns=['Profit']
)
profits_over_time_df.to_csv('./logs/profits_over_time.csv', index=False)