diff --git a/README.md b/README.md
index b96be0e..fe114ce 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,6 @@
 # AutoTradingSystem
-Trial project for deep learning trading model with ChatGPT-4
-
-
-This project is an automatic trading system based on a Transformer and Reinforcement Learning hybrid model.
-
-Trial project for deep learning trading model with ChatGPT-4
+This project is an automatic trading system based on a Transformer and Reinforcement Learning hybrid model. It is built with HuggingFace's Transformers library, with ChatGPT-4 used to assist development.
 
 ## Setup
 
@@ -14,59 +9,64 @@ Trial project for deep learning trading model with ChatGPT-4
 ## Structure
 
-trading-system/
+AutoTradingSystem/
 │
 ├── data/
-│   ├── raw/                        # Raw data files
-│   └── processed/                  # Processed data files
+│   ├── raw/                        # Raw data goes here
+│   └── processed/                  # Processed data goes here
 │
-├── models/                         # Trained models and model checkpoints
-│
-├── logs/                           # Training logs, evaluation results, etc.
-│
-├── notebooks/                      # Jupyter notebooks
-│   ├── data_exploration.ipynb
-│   ├── model_training.ipynb
-│   ├── model_evaluation.ipynb
-│   └── demo.ipynb
+├── notebooks/                      # Jupyter notebooks for exploratory analysis and prototyping
+│   ├── data_exploration.ipynb      # Notebook for exploring the data
+│   ├── data_preprocessing.ipynb    # Notebook for preprocessing the data
+│   ├── model_prototyping.ipynb     # Notebook for prototyping models
+│   ├── model_training.ipynb        # Notebook for training models
+│   ├── model_evaluation.ipynb      # Notebook for evaluating models
+│   └── hyperparameter_tuning.ipynb # Notebook for tuning hyperparameters
 │
 ├── src/
-│   ├── data/                       # Data-related modules
-│   │   ├── __init__.py
-│   │   ├── data_collection.py
-│   │   └── data_preprocessing.py
-│   │
-│   ├── models/                     # Model-related modules
-│   │   ├── __init__.py
-│   │   ├── transformer_model.py
-│   │   ├── rl_model.py
-│   │   └── trading_agent.py
-│   │
-│   ├── training/                   # Training-related modules
-│   │   ├── __init__.py
-│   │   └── train.py
-│   │
-│   ├── evaluation/                 # Evaluation-related modules
-│   │   ├── __init__.py
-│   │   └── evaluate.py
-│   │
-│   ├── utils/                      # Utility modules
-│   │   ├── __init__.py
-│   │   ├── metrics.py
-│   │   └── utils.py
-│   │
-│   └── main.py                     # Main entry point for the project
+│   ├── agents/                     # Agents module for HuggingFace's Transformers
+│   │   ├── __init__.py
+│   │   └── agent_example.py        # An example of a Transformer Agent
+│   │
+│   ├── data/                       # Data handling module
+│   │   ├── __init__.py
+│   │   ├── data_collection.py      # Data collection scripts
+│   │   └── data_preprocessing.py   # Data preprocessing scripts
+│   │
+│   ├── models/                     # Models module
+│   │   ├── __init__.py
+│   │   ├── transformer_model.py    # Transformer model script
+│   │   ├── rl_model.py             # RL model script
+│   │   └── trading_model.py        # Trading model script
+│   │
+│   ├── training/                   # Training module
+│   │   ├── __init__.py
+│   │   └── train.py                # Training scripts
+│   │
+│   ├── evaluation/                 # Evaluation module
+│   │   ├── __init__.py
+│   │   └── evaluate.py             # Evaluation scripts
+│   │
+│   ├── utils/                      # Utility module
+│   │   ├── __init__.py
+│   │   ├── utils.py                # Utility scripts (including the seeding function)
+│   │   └── metrics.py              # Metrics computation scripts
+│   │
+│   └── main.py                     # Main script to run the model
 │
-├── tests/                          # Test-related modules
-│   ├── __init__.py
-│   ├── test_data_collection.py
-│   ├── test_data_preprocessing.py
-│   ├── test_transformer_model.py
-│   ├── test_rl_model.py
-│   ├── test_trading_model.py
-│   └── test_metrics.py
+├── config/                         # Configuration files
+│   ├── transformer.json
+│   ├── rl.json
+│   └── training.json
 │
-├── requirements.txt                # Required Python packages
+├── tests/                          # Test module
+│   ├── __init__.py
+│   ├── test_data.py                # Tests for data
collection and preprocessing +│ ├── test_models.py # Tests for model creation +│ ├── test_training.py # Tests for model training +│ ├── test_evaluation.py # Tests for model evaluation +│ └── test_utils.py # Tests for utility functions │ -└── README.md # Project documentation - +├── .gitignore # Specifies which files should be ignored by Git +├── README.md # Project description and instructions +└── requirements.txt # List of Python dependencies diff --git a/config/rl.json b/config/rl.json new file mode 100644 index 0000000..7e9b5f9 --- /dev/null +++ b/config/rl.json @@ -0,0 +1,13 @@ +{ + "algorithm": "DQN", + "state_size": 100, + "action_size": 3, + "hidden_size": 64, + "learning_rate": 0.001, + "gamma": 0.99, + "epsilon_start": 1.0, + "epsilon_end": 0.01, + "epsilon_decay": 0.995, + "target_update": 10, + "memory_size": 10000 +} diff --git a/config/training.json b/config/training.json new file mode 100644 index 0000000..ae23843 --- /dev/null +++ b/config/training.json @@ -0,0 +1,7 @@ +{ + "epochs": 100, + "batch_size": 64, + "optimizer": "Adam", + "learning_rate": 0.001, + "loss_function": "CrossEntropyLoss" +} diff --git a/config/transformer.json b/config/transformer.json new file mode 100644 index 0000000..c01d17c --- /dev/null +++ b/config/transformer.json @@ -0,0 +1,12 @@ +{ + "model_name": "bert-base-uncased", + "num_layers": 12, + "hidden_size": 768, + "num_attention_heads": 12, + "intermediate_size": 3072, + "hidden_dropout_prob": 0.1, + "attention_probs_dropout_prob": 0.1, + "max_position_embeddings": 512, + "type_vocab_size": 2, + "initializer_range": 0.02 +} diff --git a/notebooks/data_exploration.py b/notebooks/data_exploration.py index cc5475e..efc5765 100644 --- a/notebooks/data_exploration.py +++ b/notebooks/data_exploration.py @@ -5,18 +5,22 @@ import seaborn as sns # %% Load data # Assume that we have a CSV file in the processed data folder -data = pd.read_csv('./data/processed/processed_data.csv') +data = pd.read_csv('../data/processed/processed_data.csv') # %% Display the first few rows of the data +# This gives a snapshot of the data and its structure. print(data.head()) # %% Display data summary +# This gives statistical details of the data like mean, standard deviation, etc. print(data.describe()) # %% Check for missing values +# Missing values can affect the performance of the model and should be handled appropriately. print(data.isnull().sum()) # %% Visualize the closing prices +# Plotting the data helps in understanding the trend and seasonality in the data. plt.figure(figsize=(14, 7)) plt.plot(data['Close']) plt.title('Closing Prices Over Time') @@ -25,18 +29,21 @@ plt.ylabel('Price') plt.show() # %% Display the distribution of daily returns +# This can give an idea about the volatility of the stock. daily_returns = data['Close'].pct_change().dropna() sns.histplot(daily_returns, bins=50, kde=True) plt.title('Distribution of Daily Returns') plt.show() # %% Display correlation between different features +# Correlation can indicate if there are any dependent relationships between the variables. correlation_matrix = data.corr() sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm') plt.title('Correlation Matrix of Features') plt.show() # %% Display a scatter plot of volume vs closing price +# Scatter plot can show the relationship between two variables. 
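+# As a numeric complement to the scatter plot, the linear (Pearson) correlation between the
+# two columns can be printed alongside it (a small optional check, not part of the original cell):
+print('Volume/Close correlation:', data['Volume'].corr(data['Close']))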
plt.scatter(data['Volume'], data['Close']) plt.title('Volume vs Closing Price') plt.xlabel('Volume') @@ -44,8 +51,26 @@ plt.ylabel('Closing Price') plt.show() # %% Display time series decomposition if applicable -# You might need to install and import statsmodels for this -# from statsmodels.tsa.seasonal import seasonal_decompose -# decomposed = seasonal_decompose(data['Close'], model='multiplicative', period=252) # Assume that period is 252 for trading days in a year -# decomposed.plot() -# plt.show() +# Time series decomposition can help in understanding the trend, seasonality, and noise in the data. +# Please note that this requires statsmodels library. +from statsmodels.tsa.seasonal import seasonal_decompose +decomposed = seasonal_decompose(data['Close'], model='multiplicative', period=252) # Assume that period is 252 for trading days in a year +decomposed.plot() +plt.show() + +# %% Display moving averages +# Moving averages can help in understanding the trend in the data over different time periods. +data['Close'].rolling(window=7).mean().plot(label='7 Day Average') +data['Close'].rolling(window=30).mean().plot(label='30 Day Average') +data['Close'].rolling(window=90).mean().plot(label='90 Day Average') +plt.legend() +plt.title('Moving Averages of Closing Prices') +plt.show() + +# %% Display Autocorrelation plot +# Autocorrelation can show if the data is random or if there is a pattern. +# Please note that this requires pandas.plotting library. +from pandas.plotting import autocorrelation_plot +autocorrelation_plot(data['Close']) +plt.title('Autocorrelation of Closing Prices') +plt.show() diff --git a/notebooks/data_preprocessing.py b/notebooks/data_preprocessing.py new file mode 100644 index 0000000..041087f --- /dev/null +++ b/notebooks/data_preprocessing.py @@ -0,0 +1,38 @@ +# %% Import required packages +import pandas as pd +from sklearn.preprocessing import MinMaxScaler +from sklearn.model_selection import train_test_split + +# %% Load data +# Assume that we have a CSV file in the raw data folder +data = pd.read_csv('../data/raw/raw_data.csv') + +# %% Display the first few rows of the data +# It's always a good idea to take a look at the data before starting preprocessing. +print(data.head()) + +# %% Handle missing values +# This will depend on your specific dataset. Here we'll simply remove rows with missing values. +data = data.dropna() + +# %% Convert date column to datetime +# If your dataset includes a date column, it's a good idea to convert it to datetime format for time series analysis. +data['Date'] = pd.to_datetime(data['Date']) + +# %% Set date as index +# For time series analysis, it can be useful to set the date column as the index of the DataFrame. +data = data.set_index('Date') + +# %% Normalize data (optional) +# If your models require normalized data, you can use MinMaxScaler or another normalization technique. +scaler = MinMaxScaler() +data_normalized = pd.DataFrame(scaler.fit_transform(data), columns=data.columns, index=data.index) + +# %% Split data into training and testing sets +# It's important to split your data into training and testing sets to evaluate the performance of your models. +train_data, test_data = train_test_split(data, test_size=0.2, shuffle=False) + +# %% Save processed data +# Finally, save your processed data for further use. 
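+# The relative output directory may not exist on a fresh checkout; creating it first is a safe
+# default (the path below mirrors the '../data/processed/' paths already used in this notebook).
+import os
+os.makedirs('../data/processed', exist_ok=True)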
+train_data.to_csv('../data/processed/train_data.csv') +test_data.to_csv('../data/processed/test_data.csv') diff --git a/notebooks/hyperparameter_tuning.py b/notebooks/hyperparameter_tuning.py new file mode 100644 index 0000000..7eb1834 --- /dev/null +++ b/notebooks/hyperparameter_tuning.py @@ -0,0 +1,51 @@ +# %% Import required packages +import torch +from src.models.transformer_model import TransformerModel +from src.models.rl_model import RLModel +from models.trading_model import TradingAgent +from src.training.train import train_transformer, train_rl +from src.evaluation.evaluate import evaluate_trading_agent +from src.data.data_preprocessing import load_processed_data +from sklearn.model_selection import ParameterGrid +import json + +# %% Set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# %% Load processed data +data = load_processed_data('./data/processed/processed_data.csv') + +# %% Define hyperparameters grid +param_grid = { + 'learning_rate': [0.001, 0.01], + 'batch_size': [32, 64], + 'epochs': [10, 50] +} +param_grid = list(ParameterGrid(param_grid)) + +# %% Initialize models +transformer_model = TransformerModel().to(device) +rl_model = RLModel().to(device) +trading_agent = TradingAgent(transformer_model, rl_model) + +# %% Hyperparameters tuning +results = [] + +for params in param_grid: + # Train Transformer Model + train_transformer(transformer_model, data, params) + # Train RL Model + train_rl(trading_agent, data, params) + # Evaluate the trading agent + evaluation_results = evaluate_trading_agent(trading_agent, data) + # Append results + results.append({ + 'params': params, + 'evaluation_results': evaluation_results + }) + print(f"Params: {params}, Evaluation Results: {evaluation_results}") + +# %% Save tuning results +with open('./logs/hyperparameter_tuning_results.json', 'w') as f: + json.dump(results, f) + diff --git a/notebooks/model_evaluation.py b/notebooks/model_evaluation.py index 6d2228d..2ef4f21 100644 --- a/notebooks/model_evaluation.py +++ b/notebooks/model_evaluation.py @@ -1,8 +1,9 @@ # %% Import required packages import torch +import matplotlib.pyplot as plt from src.models.transformer_model import TransformerModel from src.models.rl_model import RLModel -from src.models.trading_agent import TradingAgent +from models.trading_model import TradingAgent from src.evaluation.evaluate import evaluate_trading_agent from src.data.data_preprocessing import load_processed_data @@ -22,14 +23,37 @@ transformer_model.load_state_dict(torch.load('./models/transformer_model.pth')) rl_model.load_state_dict(torch.load('./models/rl_model.pth')) # %% Evaluate the trading agent -trading_agent_results = evaluate_trading_agent(trading_agent, data) +evaluation_results = evaluate_trading_agent(trading_agent, data) # %% Display evaluation results -print("Total Profit: ", trading_agent_results['total_profit']) -print("Total Trades Made: ", trading_agent_results['total_trades']) -print("Successful Trades: ", trading_agent_results['successful_trades']) +print("Total Profit: ", evaluation_results['total_profit']) +print("Total Trades Made: ", evaluation_results['total_trades']) +print("Successful Trades: ", evaluation_results['successful_trades']) + +# %% Plot profit over time +plt.plot(evaluation_results['profits_over_time']) +plt.xlabel('Time') +plt.ylabel('Profit') +plt.title('Profit Over Time') +plt.show() + +# %% Plot trade outcomes +plt.bar(['Successful Trades', 'Unsuccessful Trades'], + [evaluation_results['successful_trades'], 
evaluation_results['total_trades'] - evaluation_results['successful_trades']]) +plt.xlabel('Trade Outcome') +plt.ylabel('Number of Trades') +plt.title('Trade Outcomes') +plt.show() # %% Save evaluation results with open('./logs/evaluation_results.txt', 'w') as f: - for key, value in trading_agent_results.items(): - f.write(f'{key}: {value}\n') + for key, value in evaluation_results.items(): + if key != 'profits_over_time': # Don't save the profits over time in the text file + f.write(f'{key}: {value}\n') + +# %% Save profits over time as a CSV file +import pandas as pd + +profits_over_time_df = pd.DataFrame(evaluation_results['profits_over_time'], columns=['Profit']) +profits_over_time_df.to_csv('./logs/profits_over_time.csv', index=False) + diff --git a/notebooks/model_prototyping.py b/notebooks/model_prototyping.py new file mode 100644 index 0000000..4a321ac --- /dev/null +++ b/notebooks/model_prototyping.py @@ -0,0 +1,53 @@ +# %% Import required packages +import torch +from src.models.transformer_model import TransformerModel +from src.models.rl_model import RLModel +from models.trading_model import TradingAgent +from src.training.train import train_transformer, train_rl +from src.data.data_preprocessing import load_processed_data + +# %% Set device +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# %% Load processed data +data = load_processed_data('./data/processed/processed_data.csv') + +# %% Initialize models +transformer_model = TransformerModel().to(device) +rl_model = RLModel().to(device) +trading_agent = TradingAgent(transformer_model, rl_model) + +# %% Test out different model architectures +# Note: This is a simplified example. In a real project, you might want to try out different model architectures and hyperparameters. + +# Try out a different Transformer model configuration +transformer_model_2 = TransformerModel(config={'num_layers': 4, 'num_heads': 8}).to(device) +trading_agent_2 = TradingAgent(transformer_model_2, rl_model) + +# %% Test out different hyperparameters +# Note: This is a simplified example. In a real project, you might want to try out different hyperparameters and training strategies. + +transformer_model_hyperparams_2 = { + "epochs": 5, + "batch_size": 64, + "learning_rate": 0.001, +} + +# Train the model with the new hyperparameters +train_transformer(transformer_model_2, data, transformer_model_hyperparams_2) + +# %% Evaluate the models +# Note: This is a simplified example. In a real project, you might want to evaluate the models on a validation set and compare their performance. 
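+
+# 'validation_data' is used below but never created in this script; a minimal sketch of a
+# chronological hold-out split (assuming 'data' is an ordered DataFrame or array of samples):
+split_index = int(len(data) * 0.8)
+validation_data = data[split_index:]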
+ +# You might have a function like this one to compute the validation loss +def compute_validation_loss(model, validation_data): + # Compute the validation loss for the model + # This function is not implemented in this example + pass + +# Compute the validation loss for the original model and the new model +validation_loss = compute_validation_loss(trading_agent, validation_data) +validation_loss_2 = compute_validation_loss(trading_agent_2, validation_data) + +print(f'Validation loss for the original model: {validation_loss}') +print(f'Validation loss for the new model: {validation_loss_2}') diff --git a/notebooks/model_training.py b/notebooks/model_training.py index e7a1dbf..9cdb074 100644 --- a/notebooks/model_training.py +++ b/notebooks/model_training.py @@ -1,8 +1,10 @@ # %% Import required packages import torch +from torch.optim import Adam +from torch.nn import MSELoss from src.models.transformer_model import TransformerModel from src.models.rl_model import RLModel -from src.models.trading_agent import TradingAgent +from models.trading_model import TradingAgent from src.training.train import train_transformer, train_rl from src.data.data_preprocessing import load_processed_data @@ -10,13 +12,18 @@ from src.data.data_preprocessing import load_processed_data device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # %% Load processed data -data = load_processed_data('./data/processed/processed_data.csv') +train_data = load_processed_data('../data/processed/train_data.csv') +test_data = load_processed_data('../data/processed/test_data.csv') # %% Initialize models transformer_model = TransformerModel().to(device) rl_model = RLModel().to(device) trading_agent = TradingAgent(transformer_model, rl_model) +# %% Set up the loss function and optimizer for Transformer model +criterion = MSELoss() +optimizer = Adam(transformer_model.parameters(), lr=0.001) + # %% Train Transformer Model # Set the appropriate hyperparameters transformer_model_hyperparams = { @@ -24,10 +31,15 @@ transformer_model_hyperparams = { "batch_size": 32, "learning_rate": 0.001, } -train_transformer(transformer_model, data, transformer_model_hyperparams) +train_transformer(transformer_model, train_data, criterion, optimizer, transformer_model_hyperparams) + +# %% Evaluate Transformer Model on Test Data +# After training, it's a good practice to evaluate your model on a separate test set. +test_loss = evaluate_transformer(transformer_model, test_data, criterion) +print('Test Loss:', test_loss) # %% Save Transformer Model -torch.save(transformer_model.state_dict(), './models/transformer_model.pth') +torch.save(transformer_model.state_dict(), '../models/transformer_model.pth') # %% Train RL Model # Set the appropriate hyperparameters @@ -40,7 +52,12 @@ rl_model_hyperparams = { "epsilon_end": 0.01, # minimum exploration rate "epsilon_decay": 0.995, # exponential decay rate for exploration probability } -train_rl(trading_agent, data, rl_model_hyperparams) +train_rl(trading_agent, train_data, rl_model_hyperparams) + +# %% Evaluate RL Model on Test Data +# After training, it's a good practice to evaluate your model on a separate test set. 
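+# Note: evaluate_rl is not defined or imported anywhere in this diff. A minimal placeholder
+# sketch is given here, assuming test_data is a DataFrame with a 'Close' column and that the
+# agent exposes make_decision(stock_prices, obs) as in src/models/trading_model.py; the reward
+# definition below is purely illustrative.
+def evaluate_rl(agent, data):
+    total_reward = 0.0
+    for i in range(len(data) - 1):
+        obs = data.iloc[i].values
+        action = agent.make_decision(obs, obs)
+        price_change = data['Close'].iloc[i + 1] - data['Close'].iloc[i]
+        if action == 1:    # buy
+            total_reward += price_change
+        elif action == 2:  # sell
+            total_reward -= price_change
+    return total_reward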
+test_reward = evaluate_rl(trading_agent, test_data) +print('Test Reward:', test_reward) # %% Save RL Model -torch.save(rl_model.state_dict(), './models/rl_model.pth') +torch.save(rl_model.state_dict(), '../models/rl_model.pth') diff --git a/notebooks/demo.py b/src/agents/__init__.py similarity index 100% rename from notebooks/demo.py rename to src/agents/__init__.py diff --git a/src/agents/agent_example.py b/src/agents/agent_example.py new file mode 100644 index 0000000..1b262bc --- /dev/null +++ b/src/agents/agent_example.py @@ -0,0 +1,74 @@ +from src.models.trading_model import TradingModel +from src.training import train_model +from src.utils import save_model, load_model + +class ExampleAgent: + def __init__(self, config): + """ + Initializes the ExampleAgent with a configuration dictionary. + + Parameters: + config (dict): The configuration dictionary containing the parameters for agent and models. + """ + self.transformer_config = config['transformer'] + self.rl_config = config['rl'] + self.transformer_model = TransformerModel(self.transformer_config) + self.rl_model = RLModel(self.rl_config) + self.trading_agent = TradingModel(self.transformer_model, self.rl_model) + + def train(self, training_data): + """ + Trains the agent on the provided training data. + + Parameters: + training_data (np.array): The training data for the agent. + """ + # Training the transformer model + self.transformer_model = train_model(self.transformer_model, training_data, self.transformer_config) + + # Training the RL model + self.rl_model = train_model(self.rl_model, training_data, self.rl_config) + + # Update the trading_agent with the trained models + self.trading_agent = TradingModel(self.transformer_model, self.rl_model) + + def act(self, obs): + """ + Makes a decision based on the provided observation. + + Parameters: + obs (np.array): The observation for the agent. + + Returns: + int: The action chosen by the agent (0: hold, 1: buy, 2: sell). + """ + return self.trading_agent.make_decision(obs) + + def save_models(self, path): + """ + Saves the models to the specified path. + + Parameters: + path (str): The path to save the models. + """ + # Save the transformer model + save_model(self.transformer_model, path + "/transformer_model.pth") + + # Save the RL model + save_model(self.rl_model, path + "/rl_model.pth") + + def load_models(self, path): + """ + Loads the models from the specified path. + + Parameters: + path (str): The path to load the models from. + """ + # Load the transformer model + self.transformer_model = load_model(TransformerModel(self.transformer_config), path + "/transformer_model.pth") + + # Load the RL model + self.rl_model = load_model(RLModel(self.rl_config), path + "/rl_model.pth") + + # Update the trading_agent with the loaded models + self.trading_agent = TradingModel(self.transformer_model, self.rl_model) diff --git a/src/data/data_collection.py b/src/data/data_collection.py index 99e9e5a..0ab4de0 100644 --- a/src/data/data_collection.py +++ b/src/data/data_collection.py @@ -1,19 +1,45 @@ +import requests +import json +import os +import pandas as pd +from pandas_datareader import data as pdr +from datetime import datetime, timedelta import yfinance as yf -def collect_data(tickers, start_date, end_date): - """ - Collects data for the given tickers and date range. +yf.pdr_override() # Override pandas_datareader's get_data_yahoo() method - Parameters: - tickers (list of str): List of ticker symbols. - start_date (str): Start date in format 'YYYY-MM-DD'. 
- end_date (str): End date in format 'YYYY-MM-DD'. +class DataCollector: + def __init__(self, tickers, start_date, end_date): + self.tickers = tickers + self.start_date = start_date + self.end_date = end_date - Returns: - dict: Dictionary where the keys are ticker symbols and the values are pandas DataFrames of the price data. - """ - data = {} - for ticker in tickers: - df = yf.download(ticker, start=start_date, end=end_date) - data[ticker] = df - return data + def fetch_data_from_yahoo(self): + data = pdr.get_data_yahoo(self.tickers, start=self.start_date, end=self.end_date) + return data + + def fetch_data_from_binance(self, symbol, interval='1d', limit=500): + url = f'https://api.binance.com/api/v3/klines?symbol={symbol}&interval={interval}&limit={limit}' + response = requests.get(url) + response.raise_for_status() + data = response.json() + return data + + def save_data(self, data, filename): + if isinstance(data, pd.DataFrame): + data.to_csv(filename) + else: + with open(filename, 'w') as f: + json.dump(data, f) + +if __name__ == "__main__": + tickers = ['AAPL', 'GOOGL', 'TSLA'] + data_collector = DataCollector(tickers, '2020-01-01', '2021-12-31') + + # Fetch and save stock data from Yahoo Finance + stock_data = data_collector.fetch_data_from_yahoo() + data_collector.save_data(stock_data, 'stock_data.csv') + + # Fetch and save cryptocurrency data from Binance + crypto_data = data_collector.fetch_data_from_binance('BTCUSDT') + data_collector.save_data(crypto_data, 'crypto_data.json') diff --git a/src/data/data_preprocessing.py b/src/data/data_preprocessing.py index 9e7c5da..617ad9d 100644 --- a/src/data/data_preprocessing.py +++ b/src/data/data_preprocessing.py @@ -1,16 +1,39 @@ +import pandas as pd from sklearn.preprocessing import MinMaxScaler -def preprocess_data(data): - """ - Preprocesses the collected data. +class DataPreprocessor: + def __init__(self, filename, file_format='csv'): + if file_format == 'csv': + self.data = pd.read_csv(filename) + elif file_format == 'json': + self.data = pd.read_json(filename) + else: + raise ValueError(f"Unsupported file format: {file_format}") - Parameters: - data (dict): The data collected from collect_data function. Keys are tickers and values are pandas DataFrames. + def drop_nulls(self): + self.data.dropna(inplace=True) - Returns: - dict: Preprocessed data where the 'Close' prices have been scaled to be between 0 and 1. 
- """ - scaler = MinMaxScaler() - for ticker in data: - data[ticker]['Close'] = scaler.fit_transform(data[ticker][['Close']]) - return data + def scale_data(self, columns): + scaler = MinMaxScaler() + self.data[columns] = scaler.fit_transform(self.data[columns]) + + def preprocess_data(self): + # Call preprocessing steps + self.drop_nulls() + + # Assume that we want to scale all numerical columns + numerical_cols = self.data.select_dtypes(include=['float64', 'int']).columns.tolist() + self.scale_data(numerical_cols) + + def save_data(self, filename, file_format='csv'): + if file_format == 'csv': + self.data.to_csv(filename, index=False) + elif file_format == 'json': + self.data.to_json(filename) + else: + raise ValueError(f"Unsupported file format: {file_format}") + +if __name__ == "__main__": + data_preprocessor = DataPreprocessor('stock_data.csv') + data_preprocessor.preprocess_data() + data_preprocessor.save_data('processed_data.csv') diff --git a/src/evaluation/evaluate.py b/src/evaluation/evaluate.py index 0590868..00f2fe3 100644 --- a/src/evaluation/evaluate.py +++ b/src/evaluation/evaluate.py @@ -3,6 +3,7 @@ from torch.utils.data import DataLoader from src.models.transformer_model import TransformerModel from src.models.rl_model import RLModel from src.data.data_preprocessing import Dataset +from sklearn.metrics import accuracy_score def evaluate_transformer_model(transformer_model, test_data): """ @@ -26,8 +27,10 @@ def evaluate_transformer_model(transformer_model, test_data): transformer_model.eval() # Evaluation loop + total_loss = 0 + total_correct = 0 + total_samples = 0 with torch.no_grad(): - total_loss = 0 for i, (inputs, targets) in enumerate(dataloader): inputs, targets = inputs.to(device), targets.to(device) @@ -36,12 +39,17 @@ def evaluate_transformer_model(transformer_model, test_data): # Compute loss loss = criterion(outputs, targets) - total_loss += loss.item() - # Compute average loss + # Compute accuracy + _, predicted = torch.max(outputs, 1) + total_correct += (predicted == targets).sum().item() + total_samples += targets.size(0) + + # Compute average loss and accuracy average_loss = total_loss / len(dataloader) - print(f'Average loss: {average_loss}') + accuracy = total_correct / total_samples + print(f'Average loss: {average_loss}, Accuracy: {accuracy}') def evaluate_rl_model(rl_model, env, episodes): """ @@ -52,17 +60,18 @@ def evaluate_rl_model(rl_model, env, episodes): env (gym.Env): The Gym environment to use for evaluation. episodes (int): The number of episodes to evaluate for. 
""" - total_rewards = 0 + total_rewards = [] for i_episode in range(episodes): state = env.reset() done = False + episode_reward = 0 while not done: action = rl_model.predict(state) state, reward, done, _ = env.step(action) - total_rewards += reward - - print(f'Episode: {i_episode+1}, Reward: {reward}') + episode_reward += reward + total_rewards.append(episode_reward) + print(f'Episode: {i_episode+1}, Total reward: {episode_reward}') # Compute average reward - average_reward = total_rewards / episodes - print(f'Average reward: {average_reward}') + average_reward = sum(total_rewards) / episodes + print(f'Average total reward: {average_reward}') diff --git a/src/main.py b/src/main.py index 8de2283..c6db174 100644 --- a/src/main.py +++ b/src/main.py @@ -1,40 +1,52 @@ -import argparse from src.data import data_collection, data_preprocessing from src.models import transformer_model, rl_model, trading_model from src.training import train from src.evaluation import evaluate from src.utils import utils, metrics -def main(args): - # Set seed for reproducibility - utils.seed_everything(args.seed) +def main(config): + """ + The main function that drives the training and evaluation of the models. - # Data Collection - raw_data = data_collection.collect_data(args.data_source) + Parameters: + config (dict): The configuration dictionary. + """ + try: + # Set seed for reproducibility + utils.seed_everything(config['seed']) - # Data Preprocessing - processed_data = data_preprocessing.preprocess_data(raw_data) + # Data Collection - Retrieve the data from the specified source + raw_data = data_collection.collect_data(config['data_source']) - # Model Creation - transformer = transformer_model.TransformerModel(args.transformer_config) - rl_agent = rl_model.RLModel(args.rl_config) - trading_agent = trading_model.TradingAgent(transformer, rl_agent) + # Data Preprocessing - Process the raw data for the models + processed_data = data_preprocessing.preprocess_data(raw_data) - # Model Training - train.train(trading_agent, processed_data, args.training_config) + # Model Creation - Create the Transformer, RL, and Trading models + transformer = transformer_model.TransformerModel(config['transformer_config']) + rl_agent = rl_model.RLModel(config['rl_config']) + trading_agent = trading_model.TradingAgent(transformer, rl_agent) - # Model Evaluation - evaluation_results = evaluate.evaluate(trading_agent, processed_data, metrics) + # Model Training - Train the trading agent on the processed data + train.train(trading_agent, processed_data, config['training_config']) - print(evaluation_results) + # Model Evaluation - Evaluate the trained agent and print the results + evaluation_results = evaluate.evaluate(trading_agent, processed_data, metrics) + + print(evaluation_results) + except Exception as e: + print(f"An error occurred during the execution: {e}") if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility.") - parser.add_argument("--data_source", type=str, default="data/raw/", help="Data source for the trading data.") - parser.add_argument("--transformer_config", type=str, default="config/transformer.json", help="Path to the Transformer model configuration file.") - parser.add_argument("--rl_config", type=str, default="config/rl.json", help="Path to the RL model configuration file.") - parser.add_argument("--training_config", type=str, default="config/training.json", help="Path to the training configuration file.") - args = 
parser.parse_args() + try: + # Configuration dictionary + config = { + 'seed': 42, + 'data_source': 'data/raw/', + 'transformer_config': 'config/transformer.json', + 'rl_config': 'config/rl.json', + 'training_config': 'config/training.json' + } - main(args) + main(config) + except Exception as e: + print(f"An error occurred during the setup or during main execution: {e}") diff --git a/src/models/rl_model.py b/src/models/rl_model.py index 14ee909..f5e534e 100644 --- a/src/models/rl_model.py +++ b/src/models/rl_model.py @@ -1,16 +1,24 @@ -from stable_baselines3 import PPO +from stable_baselines3 import PPO, A2C, DDPG from stable_baselines3.common.envs import DummyVecEnv class RLModel: - def __init__(self, env): + def __init__(self, env, algorithm='PPO'): """ - Initializes the RLModel with a given environment. + Initializes the RLModel with a given environment and algorithm. Parameters: env (gym.Env): The Gym environment to use for training. + algorithm (str): The RL algorithm to use for training. Should be one of ['PPO', 'A2C', 'DDPG']. """ self.env = DummyVecEnv([lambda: env]) # The environment must be vectorized - self.model = PPO('MlpPolicy', self.env, verbose=1) + if algorithm == 'PPO': + self.model = PPO('MlpPolicy', self.env, verbose=1) + elif algorithm == 'A2C': + self.model = A2C('MlpPolicy', self.env, verbose=1) + elif algorithm == 'DDPG': + self.model = DDPG('MlpPolicy', self.env, verbose=1) + else: + raise ValueError(f'Unknown algorithm {algorithm}. Should be one of ["PPO", "A2C", "DDPG"]') def train(self, timesteps): """ @@ -33,3 +41,21 @@ class RLModel: """ action, _states = self.model.predict(obs) return action + + def save(self, path): + """ + Saves the model at the given path. + + Parameters: + path (str): The path to save the model at. + """ + self.model.save(path) + + def load(self, path): + """ + Loads the model from the given path. + + Parameters: + path (str): The path to load the model from. + """ + self.model = self.model.load(path) diff --git a/src/models/trading_agent.py b/src/models/trading_model.py similarity index 57% rename from src/models/trading_agent.py rename to src/models/trading_model.py index 7577cfb..026bc6e 100644 --- a/src/models/trading_agent.py +++ b/src/models/trading_model.py @@ -1,4 +1,4 @@ -class TradingAgent: +class TradingModel: def __init__(self, transformer_model, rl_model): """ Initializes the TradingAgent with the Transformer and RL models. @@ -10,22 +10,25 @@ class TradingAgent: self.transformer_model = transformer_model self.rl_model = rl_model - def make_decision(self, text, obs): + def make_decision(self, stock_prices, obs): """ - Makes a trading decision based on the given text and observations. + Makes a trading decision based on the given stock prices and observations. Parameters: - text (str): The text to feed to the Transformer model. + stock_prices (np.array): The stock price sequence to feed to the Transformer model. obs (np.array): The observations to feed to the RL model. Returns: int: The action chosen by the agent (0: hold, 1: buy, 2: sell). 
""" # Get embeddings from transformer model - embeddings = self.transformer_model.get_embeddings(text) + embeddings = self.transformer_model.get_embeddings(stock_prices) - # Combine embeddings with observations - combined_input = np.concatenate((embeddings.detach().numpy(), obs)) + # Compute a single vector representation of the embeddings + vector_representation = embeddings.mean(dim=0) + + # Combine vector representation with observations + combined_input = np.concatenate((vector_representation.detach().numpy(), obs)) # Get action from RL model action = self.rl_model.predict(combined_input) diff --git a/src/models/transformer_model.py b/src/models/transformer_model.py index 8200809..ac7a3b6 100644 --- a/src/models/transformer_model.py +++ b/src/models/transformer_model.py @@ -1,20 +1,46 @@ -from transformers import BertModel, BertTokenizer +from transformers import AutoModel, AutoTokenizer class TransformerModel: - def __init__(self, pretrained_model_name='bert-base-uncased'): - self.tokenizer = BertTokenizer.from_pretrained(pretrained_model_name) - self.model = BertModel.from_pretrained(pretrained_model_name) + def __init__(self, pretrained_model_name='bert-base-uncased', max_length=512, mode='text'): + self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name) + self.model = AutoModel.from_pretrained(pretrained_model_name) + self.max_length = max_length + self.mode = mode - def get_embeddings(self, text): + def preprocess_data(self, data): + """ + Preprocesses the input data, converting it into tensors that can be fed into the model. + + Parameters: + data (str, list[str], or pd.Series): Data to preprocess. Can be a single string, a list of strings, or a pandas Series of numerical values. + + Returns: + torch.Tensor: Preprocessed tensors ready for model input. + """ + if self.mode == 'text': + if isinstance(data, str): + data = [data] + inputs = self.tokenizer(data, return_tensors='pt', truncation=True, padding=True, max_length=self.max_length) + elif self.mode == 'time_series': + # Ensure data is a pandas Series + if not isinstance(data, pd.Series): + data = pd.Series(data) + inputs = torch.tensor(data.values).float().unsqueeze(0) + else: + raise ValueError(f"Invalid mode: {self.mode}. Mode must be either 'text' or 'time_series'.") + + return inputs + + def get_embeddings(self, data): """ Returns the embeddings generated by the transformer model. Parameters: - text (str): Text to get embeddings for. + data (str, list[str], or pd.Series): Data to get embeddings for. Can be a single string, a list of strings, or a pandas Series of numerical values. Returns: - torch.Tensor: Embeddings for the input text. + torch.Tensor: Embeddings for the input data. 
""" - inputs = self.tokenizer(text, return_tensors='pt', truncation=True, padding=True) + inputs = self.preprocess_data(data) outputs = self.model(**inputs) return outputs.last_hidden_state diff --git a/src/training/train.py b/src/training/train.py index 5509ed2..9e390df 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -22,8 +22,8 @@ def train_transformer_model(transformer_model, train_data, epochs, learning_rate transformer_model.to(device) # Define loss function and optimizer - criterion = torch.nn.CrossEntropyLoss() - optimizer = torch.optim.Adam(transformer_model.parameters(), lr=learning_rate) + criterion = torch.nn.MSELoss() # Change to mean squared error loss for regression + optimizer = torch.optim.AdamW(transformer_model.parameters(), lr=learning_rate) # Use AdamW optimizer # Training loop for epoch in range(epochs): diff --git a/src/utils/metrics.py b/src/utils/metrics.py index e2c962c..b0ed86f 100644 --- a/src/utils/metrics.py +++ b/src/utils/metrics.py @@ -1,4 +1,5 @@ -from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score +from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix +import numpy as np def compute_classification_metrics(y_true, y_pred): """ @@ -15,8 +16,9 @@ def compute_classification_metrics(y_true, y_pred): precision = precision_score(y_true, y_pred, average='weighted', zero_division=0) recall = recall_score(y_true, y_pred, average='weighted', zero_division=0) f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0) + conf_matrix = confusion_matrix(y_true, y_pred) - return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1} + return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'confusion_matrix': conf_matrix} def compute_reward_metrics(total_rewards, num_episodes): """ @@ -32,5 +34,6 @@ def compute_reward_metrics(total_rewards, num_episodes): average_reward = sum(total_rewards) / num_episodes max_reward = max(total_rewards) min_reward = min(total_rewards) + std_dev_reward = np.std(total_rewards) - return {'average_reward': average_reward, 'max_reward': max_reward, 'min_reward': min_reward} + return {'average_reward': average_reward, 'max_reward': max_reward, 'min_reward': min_reward, 'std_dev_reward': std_dev_reward} diff --git a/src/utils/utils.py b/src/utils/utils.py index aa91f2a..ed52b96 100644 --- a/src/utils/utils.py +++ b/src/utils/utils.py @@ -1,3 +1,4 @@ +import json import numpy as np import torch from sklearn.preprocessing import MinMaxScaler @@ -28,25 +29,30 @@ def scale_data(data): def save_model(model, path): """ - Save a PyTorch model. + Save a PyTorch model along with its configuration. Parameters: model (torch.nn.Module): The model to save. path (str): The path where to save the model. """ - torch.save(model.state_dict(), path) + torch.save({ + 'model_state_dict': model.state_dict(), + 'config': model.config + }, path) -def load_model(model, path): +def load_model(model_class, path): """ - Load a PyTorch model. + Load a PyTorch model along with its configuration. Parameters: - model (torch.nn.Module): The model to load. + model_class (class): The class of the model to load. path (str): The path from where to load the model. Returns: - torch.nn.Module: The loaded model. + model_class: The loaded model. 
""" - model.load_state_dict(torch.load(path)) + checkpoint = torch.load(path) + model = model_class(checkpoint['config']) + model.load_state_dict(checkpoint['model_state_dict']) model.eval() return model diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..244b512 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,30 @@ +import pandas as pd +import numpy as np +from src.data import data_collection, data_preprocessing + +def test_collect_data(): + url = "https://github.com/plotly/datasets/raw/master/tesla-stock-price.csv" + data = data_collection.collect_data(url) + + assert isinstance(data, pd.DataFrame) # Assert that a DataFrame is returned + assert data.shape[0] > 0 # Assert that the DataFrame is not empty + assert set(data.columns) == set(['Date', 'Open', 'High', 'Low', 'Close', 'Volume']) # Assert the columns are as expected + +def test_preprocess_data(): + url = "https://github.com/plotly/datasets/raw/master/tesla-stock-price.csv" + raw_data = data_collection.collect_data(url) + processed_data = data_preprocessing.preprocess_data(raw_data) + + assert isinstance(processed_data, pd.DataFrame) + assert processed_data.shape[0] > 0 # Assert that the dataframe is not empty + # Assert that the processed data columns are as expected + assert set(processed_data.columns) == set(['Processed_Open', 'Processed_High', 'Processed_Low', 'Processed_Close', 'Processed_Volume']) + assert pd.api.types.is_numeric_dtype(processed_data["Processed_Open"]) # Assert that the 'Processed_Open' column is numeric + assert pd.api.types.is_numeric_dtype(processed_data["Processed_High"]) # Assert that the 'Processed_High' column is numeric + assert pd.api.types.is_numeric_dtype(processed_data["Processed_Low"]) # Assert that the 'Processed_Low' column is numeric + assert pd.api.types.is_numeric_dtype(processed_data["Processed_Close"]) # Assert that the 'Processed_Close' column is numeric + assert pd.api.types.is_numeric_dtype(processed_data["Processed_Volume"]) # Assert that the 'Processed_Volume' column is numeric + +if __name__ == "__main__": + test_collect_data() + test_preprocess_data() diff --git a/tests/test_data_collection.py b/tests/test_data_collection.py deleted file mode 100644 index a0be47d..0000000 --- a/tests/test_data_collection.py +++ /dev/null @@ -1,13 +0,0 @@ -import pytest -from src.data import data_collection - -def test_collect_data(): - # Test the collect_data function - data = data_collection.collect_data('path_to_test_data') - - # Check that the data has the expected shape - assert data.shape == (expected_number_of_rows, expected_number_of_columns) - - # Check that the data has the expected columns - expected_columns = ['column1', 'column2', 'column3'] - assert all(column in data.columns for column in expected_columns) diff --git a/tests/test_data_preprocessing.py b/tests/test_data_preprocessing.py deleted file mode 100644 index c8bd26f..0000000 --- a/tests/test_data_preprocessing.py +++ /dev/null @@ -1,26 +0,0 @@ -import pytest -import pandas as pd -from src.data import data_preprocessing - -def test_preprocess_data(): - # create a mock data - raw_data = pd.DataFrame({ - 'Open': [1.0, 2.0, 3.0, 4.0, 5.0], - 'High': [1.1, 2.1, 3.1, 4.1, 5.1], - 'Low': [0.9, 1.9, 2.9, 3.9, 4.9], - 'Close': [1.0, 2.0, 3.0, 4.0, 5.0], - 'Volume': [1000, 2000, 3000, 4000, 5000] - }) - - # perform preprocessing - processed_data = data_preprocessing.preprocess_data(raw_data) - - # check that the data has the expected columns - expected_columns = ['Open', 'High', 'Low', 'Close', 
'Volume'] - assert all(column in processed_data.columns for column in expected_columns) - - # check the shape of the data - assert processed_data.shape == raw_data.shape - - # check that values are normalized (within a certain range, e.g. -1.0 to 1.0) - assert all(-1.0 <= value <= 1.0 for value in processed_data.values.flatten()) diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py new file mode 100644 index 0000000..234b2f2 --- /dev/null +++ b/tests/test_evaluation.py @@ -0,0 +1,35 @@ +import torch +from src.models import rl_model, transformer_model +from src.agents import Agent +from src.evaluation import evaluate +from src.utils import seed_everything, scale_data + +def test_evaluation(): + # Set a seed for reproducibility + seed_everything(42) + + # Create mock data + input_data = torch.rand((100, 10)) + + # Scale the data + input_data, _ = scale_data(input_data) + + # Instantiate the models + transformer = transformer_model.TransformerModel() + rl = rl_model.RLModel() + + # Instantiate the agent + agent = Agent(transformer_model=transformer, rl_model=rl) + + # Run the evaluation + evaluation_results = evaluate(agent, input_data) + + # Check the type of evaluation results + assert isinstance(evaluation_results, dict) + + # Check that the evaluation results contain expected keys + expected_keys = ["average_reward", "total_reward"] + assert all(key in evaluation_results for key in expected_keys) + +if __name__ == "__main__": + test_evaluation() diff --git a/tests/test_metrics.py b/tests/test_metrics.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..dc5047c --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,71 @@ +import torch +from src.models import rl_model, trading_model, transformer_model + +def test_rl_model(): + # Create mock input data + input_data = torch.rand((1, 10)) + + # Instantiate the model + model = rl_model.RLModel() + + # Forward pass + outputs = model(input_data) + + # Check output dimensions + assert outputs.size() == torch.Size([1, model.output_size]) + + # Check that the model is on the correct device + assert outputs.device == model.device + +def test_trading_model(): + # Create mock input data + input_data = torch.rand((1, 10)) + + # Instantiate the model + model = trading_model.TradingModel() + + # Forward pass + outputs = model(input_data) + + # Check output dimensions + assert outputs.size() == torch.Size([1, model.output_size]) + + # Check that the model is on the correct device + assert outputs.device == model.device + +def test_transformer_model(): + # Create mock input data + input_ids = torch.randint(0, 100, (1, 20)) + attention_mask = torch.ones((1, 20)) + + # Instantiate the model + model = transformer_model.TransformerModel() + + # Forward pass + outputs = model(input_ids, attention_mask) + + # Check output dimensions + assert outputs.size() == torch.Size([1, 20, model.hidden_size]) + + # Check that the model is on the correct device + assert outputs.device == model.device + +def test_model_save_load(): + # Instantiate the model + model = transformer_model.TransformerModel() + + # Save the model + model.save_pretrained('test_model') + + # Load the model + loaded_model = transformer_model.TransformerModel.from_pretrained('test_model') + + # Check that the loaded model has the same parameters as the original model + for p1, p2 in zip(model.parameters(), loaded_model.parameters()): + assert torch.all(p1.eq(p2)) + +if __name__ == "__main__": + 
test_rl_model() + test_trading_model() + test_transformer_model() + test_model_save_load() diff --git a/tests/test_rl_model.py b/tests/test_rl_model.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_trading_model.py b/tests/test_trading_model.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/test_training.py b/tests/test_training.py new file mode 100644 index 0000000..de0a320 --- /dev/null +++ b/tests/test_training.py @@ -0,0 +1,38 @@ +import torch +from src.models import rl_model, transformer_model +from src.agents import Agent +from src.training import train +from src.utils import seed_everything, scale_data + +def test_training(): + # Set a seed for reproducibility + seed_everything(42) + + # Create mock data + input_data = torch.rand((100, 10)) + target_data = torch.rand((100, 10)) + + # Scale the data + input_data, _ = scale_data(input_data) + target_data, _ = scale_data(target_data) + + # Instantiate the models + transformer = transformer_model.TransformerModel() + rl = rl_model.RLModel() + + # Instantiate the agent + agent = Agent(transformer_model=transformer, rl_model=rl) + + # Define loss function and optimizer + criterion = torch.nn.MSELoss() + optimizer = torch.optim.Adam(agent.parameters()) + + # Train the agent + train(agent, input_data, target_data, criterion, optimizer) + + # Check that model weights have been updated + initial_weights = list(agent.parameters())[0].detach().clone() + assert not torch.all(torch.eq(initial_weights, list(agent.parameters())[0])) + +if __name__ == "__main__": + test_training() diff --git a/tests/test_transformer_model.py b/tests/test_transformer_model.py deleted file mode 100644 index e6dca11..0000000 --- a/tests/test_transformer_model.py +++ /dev/null @@ -1,34 +0,0 @@ -import pytest -import torch -from src.models import transformer_model - -def test_transformer_model(): - # Create mock input data - input_ids = torch.randint(0, 100, (1, 20)) - attention_mask = torch.ones((1, 20)) - - # Instantiate the model - model = transformer_model.TransformerModel() - - # Forward pass - outputs = model(input_ids, attention_mask) - - # Check output dimensions - assert outputs.size() == torch.Size([1, 20, model.hidden_size]) - - # Check that the model is on the correct device - assert outputs.device == model.device - -def test_model_save_load(): - # Instantiate the model - model = transformer_model.TransformerModel() - - # Save the model - model.save_pretrained('test_model') - - # Load the model - loaded_model = transformer_model.TransformerModel.from_pretrained('test_model') - - # Check that the loaded model has the same parameters as the original model - for p1, p2 in zip(model.parameters(), loaded_model.parameters()): - assert torch.all(p1.eq(p2)) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..0fb4098 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,66 @@ +import torch +import numpy as np +import os +from sklearn.preprocessing import MinMaxScaler +from src.models import transformer_model +from src.utils import seed_everything, scale_data, save_model, load_model + +def test_seed_everything(): + # Set a seed + seed_everything(42) + + # Generate some random numbers + np_random_nums = np.random.rand(5) + torch_random_nums = torch.rand(5) + + # Set the seed again + seed_everything(42) + + # Generate some more random numbers + np_random_nums_2 = np.random.rand(5) + torch_random_nums_2 = torch.rand(5) + + # Check that the sets of random numbers are equal + assert 
np.all(np_random_nums == np_random_nums_2)
+    assert torch.all(torch_random_nums == torch_random_nums_2)
+
+def test_scale_data():
+    # Create some mock data
+    data = np.random.rand(100, 10)
+
+    # Scale the data
+    scaled_data, scaler = scale_data(data)
+
+    # Check that the scaled data has the correct shape
+    assert scaled_data.shape == data.shape
+
+    # Check that the scaled data is actually scaled
+    assert np.all(0 <= scaled_data) and np.all(scaled_data <= 1)
+
+    # Check that the scaler is a MinMaxScaler
+    assert isinstance(scaler, MinMaxScaler)
+
+def test_save_and_load_model():
+    # Instantiate the model
+    model = transformer_model.TransformerModel()
+
+    # Save the model
+    save_model(model, 'test_model.pt')
+
+    # Check that the model file was created
+    assert os.path.isfile('test_model.pt')
+
+    # Load the model (load_model expects the model class, not an instance)
+    loaded_model = load_model(transformer_model.TransformerModel, 'test_model.pt')
+
+    # Check that the loaded model has the same parameters as the original model
+    for p1, p2 in zip(model.parameters(), loaded_model.parameters()):
+        assert torch.all(p1.eq(p2))
+
+    # Remove the model file
+    os.remove('test_model.pt')
+
+if __name__ == "__main__":
+    test_seed_everything()
+    test_scale_data()
+    test_save_and_load_model()
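+
+# For reference: seed_everything (exercised above and called in src/main.py) is not shown in
+# this diff. A typical implementation in src/utils/utils.py (an assumption, not necessarily the
+# project's actual code) would be:
+#
+#     import os, random
+#     import numpy as np
+#     import torch
+#
+#     def seed_everything(seed: int) -> None:
+#         random.seed(seed)
+#         os.environ['PYTHONHASHSEED'] = str(seed)
+#         np.random.seed(seed)
+#         torch.manual_seed(seed)
+#         torch.cuda.manual_seed_all(seed)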