# Tests for src.data.data_preprocessing.
# Listing header from the source view: 27 lines, 929 B, Python.
import pandas as pd
import pytest

from src.data import data_preprocessing
def test_preprocess_data():
    """Check that ``preprocess_data`` keeps the OHLCV columns and shape and scales values.

    Builds a small five-row frame with Open/High/Low/Close/Volume columns,
    runs it through ``data_preprocessing.preprocess_data`` and asserts that:

    * every input column is still present in the output,
    * the output has the same (rows, columns) shape as the input had, and
    * every output value lies in the closed interval [-1.0, 1.0].
    """
    # Small mock frame: five rows of monotonically increasing values.
    raw_data = pd.DataFrame({
        'Open': [1.0, 2.0, 3.0, 4.0, 5.0],
        'High': [1.1, 2.1, 3.1, 4.1, 5.1],
        'Low': [0.9, 1.9, 2.9, 3.9, 4.9],
        'Close': [1.0, 2.0, 3.0, 4.0, 5.0],
        'Volume': [1000, 2000, 3000, 4000, 5000],
    })
    # Record the expected shape *before* the call: if preprocess_data modified
    # raw_data in place (e.g. dropping rows), reading raw_data.shape afterwards
    # would compare the output against the mutated input and could pass spuriously.
    expected_shape = raw_data.shape

    processed_data = data_preprocessing.preprocess_data(raw_data)

    # Every input column must survive preprocessing; report which ones did not.
    expected_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
    missing = [col for col in expected_columns if col not in processed_data.columns]
    assert not missing, f"missing columns after preprocessing: {missing}"

    assert processed_data.shape == expected_shape, (
        f"expected shape {expected_shape}, got {processed_data.shape}"
    )

    # Values should be normalized into [-1.0, 1.0]. NaN compares False against
    # both bounds, so any NaN in the output also fails this check.
    values = processed_data.to_numpy(dtype=float)
    in_range = (values >= -1.0) & (values <= 1.0)
    assert in_range.all(), (
        f"values outside [-1.0, 1.0]: {values[~in_range].tolist()}"
    )