36 lines
1002 B
Python
36 lines
1002 B
Python
import torch
|
|
from src.models import rl_model, transformer_model
|
|
from src.agents import Agent
|
|
from src.evaluation import evaluate
|
|
from src.utils import seed_everything, scale_data
|
|
|
|
def test_evaluation():
    """Smoke-test ``evaluate`` on an agent built from default-constructed models.

    The test calls ``seed_everything(42)`` for reproducibility, builds a
    random ``(100, 10)`` tensor as mock input, passes it through
    ``scale_data``, and evaluates an ``Agent`` wrapping a ``TransformerModel``
    and an ``RLModel``. It passes when the evaluation result is a ``dict``
    that contains both the ``"average_reward"`` and ``"total_reward"`` keys.
    """
    seed_everything(42)

    # Mock data; only the first value returned by scale_data is used here.
    scaled_inputs, _ = scale_data(torch.rand((100, 10)))

    # Build the transformer before the RL model, in the same order as before,
    # in case model construction draws from the seeded random state.
    transformer = transformer_model.TransformerModel()
    rl = rl_model.RLModel()
    agent = Agent(transformer_model=transformer, rl_model=rl)

    results = evaluate(agent, scaled_inputs)

    # The result must be a plain mapping of metrics.
    assert isinstance(results, dict)

    # Every expected metric key has to be present in the result.
    missing_keys = [
        key for key in ("average_reward", "total_reward") if key not in results
    ]
    assert not missing_keys
|
|
|
|
if __name__ == "__main__":
    # Allow running this check directly as a script, without a test runner.
    test_evaluation()
|