"""Smoke test for ``evaluate``: runs an agent on random data and checks the result."""

import torch

from src.models import rl_model, transformer_model
from src.agents import Agent
from src.evaluation import evaluate
from src.utils import seed_everything, scale_data


def test_evaluation():
    """Check that ``evaluate`` returns a dict holding the reward summary keys.

    The test seeds the RNG, builds a random ``(100, 10)`` tensor, passes it
    through ``scale_data``, wires a ``TransformerModel`` and an ``RLModel``
    into an ``Agent``, and then inspects what ``evaluate`` hands back.
    """
    # Seed first so the random tensor below is generated after seeding
    # (presumably this seeds torch's RNG -- confirm in src.utils).
    seed_everything(42)

    # Mock input: uniform random values with 100 rows and 10 columns.
    sample_shape = (100, 10)
    raw_batch = torch.rand(sample_shape)

    # scale_data yields a pair; only the first element (the scaled data) is
    # needed here. NOTE(review): the second element is presumably the fitted
    # scaler -- confirm against src.utils.
    scaled_batch, _ = scale_data(raw_batch)

    # Construct the two models in the same order as before (transformer,
    # then RL) in case their constructors draw from the seeded RNG.
    transformer_net = transformer_model.TransformerModel()
    rl_net = rl_model.RLModel()
    agent_under_test = Agent(transformer_model=transformer_net, rl_model=rl_net)

    results = evaluate(agent_under_test, scaled_batch)

    # The results must come back as a dictionary...
    assert isinstance(results, dict)

    # ...and every required metric must be present in it.
    for required_key in ("average_reward", "total_reward"):
        assert required_key in results


if __name__ == "__main__":
    test_evaluation()