67 lines
1.9 KiB
Python
67 lines
1.9 KiB
Python
import os
import tempfile

import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler

from src.models import transformer_model
from src.utils import seed_everything, scale_data, save_model, load_model
|
|
|
|
def test_seed_everything():
    """Re-seeding with the same value must replay identical NumPy and torch draws.

    ``seed_everything(42)`` is called twice. After each call, five values are
    drawn from ``np.random.rand`` and then five from ``torch.rand``. The two
    rounds of draws are expected to match element-wise.
    """
    def _draw_after_seeding():
        # Reset via seed_everything, then sample NumPy first and torch second
        # (same order as each round must follow to be comparable).
        seed_everything(42)
        numpy_draw = np.random.rand(5)
        torch_draw = torch.rand(5)
        return numpy_draw, torch_draw

    first_np, first_torch = _draw_after_seeding()
    second_np, second_torch = _draw_after_seeding()

    # Every element of the replayed draws must equal the original draws.
    assert (first_np == second_np).all()
    assert (first_torch == second_torch).all()
|
|
|
|
def test_scale_data():
    """``scale_data`` should return min-max scaled data and a MinMaxScaler.

    The input comes from a seeded generator and deliberately spans a range much
    wider than [0, 1]. Otherwise the range check would pass even if
    ``scale_data`` returned its input unchanged, because ``np.random.rand``
    already yields values in [0, 1).
    """
    # Seeded for reproducibility; loc/scale push values well outside [0, 1].
    rng = np.random.default_rng(0)
    data = rng.normal(loc=5.0, scale=3.0, size=(100, 10))

    scaled_data, scaler = scale_data(data)

    # Scaling must not change the array's shape.
    assert scaled_data.shape == data.shape

    # Min-max scaling onto [0, 1] sends the smallest input to 0 and the largest
    # to 1. Comparisons use a small tolerance because (x - min) * scale may land
    # a rounding error just outside [0, 1].
    tol = 1e-9
    assert np.all(scaled_data >= -tol) and np.all(scaled_data <= 1 + tol)
    assert np.isclose(scaled_data.min(), 0.0, atol=tol)
    assert np.isclose(scaled_data.max(), 1.0, atol=tol)

    # The returned scaler must be a MinMaxScaler.
    assert isinstance(scaler, MinMaxScaler)
|
|
|
|
def test_save_and_load_model():
    """A model saved with ``save_model`` should load back with identical parameters.

    The checkpoint is written inside a temporary directory. The test therefore
    never leaves ``test_model.pt`` in the working directory, even when an
    assertion fails part-way through. (Previously ``os.remove`` was skipped on
    failure.)
    """
    model = transformer_model.TransformerModel()

    with tempfile.TemporaryDirectory() as tmp_dir:
        model_path = os.path.join(tmp_dir, 'test_model.pt')

        save_model(model, model_path)

        # The checkpoint file must exist after saving.
        assert os.path.isfile(model_path)

        # Load into a separately constructed instance.
        loaded_model = load_model(transformer_model.TransformerModel(), model_path)

    # Compare parameters by name rather than zip()-ing parameters(). zip() would
    # silently stop at the shorter side if the parameter counts differed.
    original_params = dict(model.named_parameters())
    loaded_params = dict(loaded_model.named_parameters())
    assert original_params.keys() == loaded_params.keys()

    for name, param in original_params.items():
        # torch.equal also requires identical shapes. p1.eq(p2) would broadcast
        # and could pass on mismatched shapes.
        assert torch.equal(param, loaded_params[name]), name
|
|
|
|
if __name__ == "__main__":
    # Run the tests in order without pytest; the first failing assertion
    # raises and stops the remaining tests.
    for run_test in (test_seed_everything, test_scale_data, test_save_and_load_model):
        run_test()
|