"""Smoke tests for the RL, trading and transformer models.

Each forward-pass test checks the output shape and that the output lives on
the model's device. The save/load test round-trips a TransformerModel
through ``save_pretrained`` / ``from_pretrained`` and compares parameters.
"""
import tempfile

import torch

from src.models import rl_model, trading_model, transformer_model


def _assert_output(outputs, expected_shape, model):
    """Assert *outputs* has *expected_shape* and sits on ``model.device``."""
    assert outputs.size() == torch.Size(expected_shape)
    # NOTE(review): relies on each model exposing a ``device`` attribute;
    # plain nn.Module does not provide one, so the model classes must.
    assert outputs.device == model.device


def test_rl_model():
    """RLModel maps a (1, 10) input to (1, model.output_size)."""
    # assumes RLModel's default input width is 10 — TODO confirm
    input_data = torch.rand((1, 10))
    model = rl_model.RLModel()

    outputs = model(input_data)

    _assert_output(outputs, [1, model.output_size], model)


def test_trading_model():
    """TradingModel maps a (1, 10) input to (1, model.output_size)."""
    # assumes TradingModel's default input width is 10 — TODO confirm
    input_data = torch.rand((1, 10))
    model = trading_model.TradingModel()

    outputs = model(input_data)

    _assert_output(outputs, [1, model.output_size], model)


def test_transformer_model():
    """TransformerModel yields (1, 20, hidden_size) for a 20-token input."""
    # Token ids drawn from [0, 100); presumably within the model's vocab.
    input_ids = torch.randint(0, 100, (1, 20))
    attention_mask = torch.ones((1, 20))
    model = transformer_model.TransformerModel()

    outputs = model(input_ids, attention_mask)

    _assert_output(outputs, [1, 20, model.hidden_size], model)


def test_model_save_load():
    """A saved-then-reloaded TransformerModel has identical parameters."""
    model = transformer_model.TransformerModel()

    # Save into a throwaway directory so the test leaves nothing behind in
    # the working directory; load while the directory still exists.
    with tempfile.TemporaryDirectory() as save_dir:
        model.save_pretrained(save_dir)
        loaded_model = transformer_model.TransformerModel.from_pretrained(save_dir)

    original_params = list(model.parameters())
    loaded_params = list(loaded_model.parameters())

    # zip() would silently truncate, so check the counts match first.
    assert len(original_params) == len(loaded_params)
    for p1, p2 in zip(original_params, loaded_params):
        # torch.equal also requires identical shapes (eq() would broadcast).
        assert torch.equal(p1, p2)


if __name__ == "__main__":
    test_rl_model()
    test_trading_model()
    test_transformer_model()
    test_model_save_load()