# auto-trading/tests/test_evaluation.py
#
# 36 lines
# 1002 B
# Python

import torch
from src.models import rl_model, transformer_model
from src.agents import Agent
from src.evaluation import evaluate
from src.utils import seed_everything, scale_data
def test_evaluation():
    """Smoke-test ``evaluate`` end to end on random, scaled input data.

    Builds a ``TransformerModel`` and an ``RLModel`` with their default
    constructors, wraps them in an ``Agent``, runs ``evaluate`` on the
    scaled data, and checks that the result is a dict containing the
    ``"average_reward"`` and ``"total_reward"`` keys.

    Raises:
        AssertionError: if the result is not a dict (message names the
            actual type) or if any expected key is missing (message lists
            the missing keys and the keys actually returned).
    """
    # Set a seed for reproducibility. seed_everything is project-local;
    # presumably it seeds torch (and possibly numpy/random) — confirm.
    seed_everything(42)

    # Mock input: a 100x10 tensor of uniform [0, 1) floats.
    # NOTE(review): presumably rows are time steps and columns are
    # features — confirm against evaluate()'s expectations.
    input_data = torch.rand((100, 10))

    # scale_data returns a pair; only the scaled data is used here.
    # The second element (presumably the fitted scaler) is discarded.
    input_data, _ = scale_data(input_data)

    # Instantiate the models with their default constructors.
    transformer = transformer_model.TransformerModel()
    rl = rl_model.RLModel()

    # The agent is constructed from both models.
    agent = Agent(transformer_model=transformer, rl_model=rl)

    # Run the evaluation.
    evaluation_results = evaluate(agent, input_data)

    # Explicit messages make failures diagnosable from the test report;
    # the previous all(...) form only ever reported "assert False".
    assert isinstance(evaluation_results, dict), (
        f"evaluate() should return a dict, got "
        f"{type(evaluation_results).__name__}"
    )

    # Check that the evaluation results contain the expected keys.
    expected_keys = ["average_reward", "total_reward"]
    missing_keys = [key for key in expected_keys if key not in evaluation_results]
    assert not missing_keys, (
        f"evaluation results missing keys {missing_keys}; "
        f"got keys {list(evaluation_results)}"
    )
    # NOTE(review): the reward values themselves are not checked; consider
    # asserting they are finite numbers once evaluate()'s value types are
    # confirmed (float vs. tensor vs. numpy scalar).
if __name__ == "__main__":
    # Allow running this file directly (python test_evaluation.py) as well
    # as through pytest; a failing check propagates as AssertionError.
    test_evaluation()