# auto-trading/src/utils/utils.py

import random

import numpy as np
import torch
from sklearn.preprocessing import MinMaxScaler
def seed_everything(seed):
    """
    Seed Python's, NumPy's and PyTorch's global random number generators.

    Seeding all three is needed for reproducibility because each library keeps
    its own independent global RNG state.

    Note: seeding alone does not force deterministic GPU kernels; see
    ``torch.use_deterministic_algorithms`` if bit-exact GPU runs are required.

    Parameters:
    seed (int): The seed to use.
    """
    random.seed(seed)
    np.random.seed(seed)
    # torch.manual_seed seeds the generators of all devices (CPU and CUDA).
    torch.manual_seed(seed)
def scale_data(data, feature_range=(0, 1)):
    """
    Scale each column of ``data`` into ``feature_range`` using MinMaxScaler.

    NOTE: the scaler is fit on *all* of ``data``. For time-series / trading
    data, fitting on rows that include the evaluation period leaks future
    min/max information into training. Fit on the training slice only and
    reuse the returned scaler (``scaler.transform``) for later data.

    Parameters:
    data (array-like): 2-D input of shape (n_samples, n_features), as
        required by MinMaxScaler.
    feature_range (tuple): Desired (min, max) of the scaled output.
        Defaults to (0, 1), the previous hard-coded behavior.

    Returns:
    tuple: ``(scaled_data, scaler)``, where ``scaled_data`` is an np.ndarray
        with the same shape as ``data`` and ``scaler`` is the fitted
        MinMaxScaler, kept so predictions can be mapped back with
        ``scaler.inverse_transform``.
    """
    scaler = MinMaxScaler(feature_range=feature_range)
    scaled_data = scaler.fit_transform(data)
    return scaled_data, scaler
def save_model(model, path):
    """
    Persist a PyTorch model's weights to disk.

    Only the ``state_dict`` is written, not the module's class or code. The
    same architecture must therefore be instantiated again before the weights
    can be restored.

    Parameters:
    model (torch.nn.Module): Model whose state_dict is serialized.
    path (str): Destination file handed to ``torch.save``.
    """
    weights = model.state_dict()
    torch.save(weights, path)
def load_model(model, path, map_location="cpu"):
    """
    Restore weights from ``path`` into ``model`` and switch it to eval mode.

    Parameters:
    model (torch.nn.Module): An instance of the same architecture that
        produced the checkpoint. Its parameters are overwritten in place.
    path (str): File containing a state_dict written with ``torch.save``.
    map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
        checkpoint saved on a GPU can still be read on a CPU-only machine.
        ``load_state_dict`` then copies the values into the model's existing
        tensors, on whatever device those live, so the result is unchanged.

    Returns:
    torch.nn.Module: The same ``model`` object, now in eval mode.
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    state_dict = torch.load(path, map_location=map_location)
    model.load_state_dict(state_dict)
    # Inference mode: disables dropout and makes batch-norm use running stats.
    model.eval()
    return model