Database TDD Part 26: Prepared Statements

by Jeff Langr

December 13, 2005

Prepared statement support is essential in any RDBMS persistence layer. The question is: how do I build this support safely? Even if it didn’t break the whole notion of encapsulating all this JDBC code, I wouldn’t want to trust clients with PreparedStatement objects.

As a general strategy, I can reverse the technique I use for retrieving information. Queries return rows in the form of a list of maps. Each map is a row, with column names and values being the key-value pairs in the map. The same strategy can work for inserting data using a prepared statement: clients will populate a list of map-rows. JDBC code can then manage the entire process from start to finish.

I’ll start in the simplest place, SQL generation.

SqlGeneratorTest

    public void testPreparedInsert() {
       String sql = new SqlGenerator().createPreparedInsert(TABLE, COLUMNS);
       assertEquals("insert into t (a,b) values (?,?)", sql);
    }

SqlGenerator

    public String createPreparedInsert(String table, Column[] columns) {
       return String.format("insert into %s (%s) values (%s)", table,
             createColumnList(columns), createPlaceholderList(columns));
    }
    
    // Produces "?,?,..." with one placeholder per column
    private String createPlaceholderList(Column[] columns) {
       final Transformer questions = new Transformer() {
          public String transform(Object ignored) {
             return "?";
          }
       };
       return StringUtil.commaDelimit(columns, questions);
    }

Nothing earth-shattering there. These methods are getting easier to write as I add more functionality.
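The Transformer and StringUtil.commaDelimit helpers aren't shown in this installment. For comparison, here is a minimal standalone sketch of the same SQL generation. It assumes plain String column names instead of Column objects and uses the JDK's String.join and Collections.nCopies (Java 8+) in place of commaDelimit; the class name PreparedInsertSketch is hypothetical.

```java
import java.util.Collections;

// Hypothetical standalone sketch; not the SqlGenerator from this series.
class PreparedInsertSketch {
    // Builds "insert into <table> (<c1>,<c2>,...) values (?,?,...)"
    // with exactly one "?" placeholder per column.
    static String createPreparedInsert(String table, String... columnNames) {
        String columnList = String.join(",", columnNames);
        String placeholders =
            String.join(",", Collections.nCopies(columnNames.length, "?"));
        return String.format("insert into %s (%s) values (%s)",
            table, columnList, placeholders);
    }

    public static void main(String[] args) {
        // prints: insert into t (a,b) values (?,?)
        System.out.println(createPreparedInsert("t", "a", "b"));
    }
}
```

Generating the placeholders from the column count keeps the "?" list and the column list from drifting apart.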

The new JDBC code:

JdbcAccessTest

    public void testExecutePreparedStatement() {
      List<Map<String, Object>> rows = new ArrayList<Map<String, Object>>();
      Map<String, Object> row1 = new HashMap<String, Object>();
      row1.put(COLUMN_NAME, VALUE1);
      Map<String, Object> row2 = new HashMap<String, Object>();
      row2.put(COLUMN_NAME, VALUE2);
    
      rows.add(row1);
      rows.add(row2);
    
      String sql = String.format("insert into %s values(?)", TABLE);
      access.executeAll(sql, COLUMNS, rows);
    
      assertResults(access.executeQuery(createSelectSql(), COLUMNS));
    }
    
    private void assertResults(List<Map<String, Object>> rows) {
      assertEquals(2, rows.size());
      assertContains(rows, VALUE1);
      assertContains(rows, VALUE2);
    }

JdbcAccess

    private PreparedStatement preparedStatement;
    ...
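    public void executeAll(String sql, Column[] columns,
          List<Map<String, Object>> rows) {
       try {
          createPreparedStatement(sql);
    
          // Bind each map-row's values in column order (JDBC parameter
          // indexes start at 1), then run the insert once per row.
          for (Map<String, Object> row : rows) {
             int i = 0;
             for (Column column : columns)
                preparedStatement.setObject(++i, row.get(column.getName()));
             preparedStatement.execute();
          }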
    
          connection.close();
       }
       catch (SQLException e) {
          throw new JdbcException(sql, e);
       }
    }
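In the listing above, connection.close() runs only on the happy path: if setObject or execute throws, the SQLException is wrapped and rethrown before the connection is closed. A minimal sketch of the same loop using Java 7's try-with-resources is below. It is hypothetical in several ways: it takes an already-open Connection as a parameter rather than using the (unshown) createPreparedStatement helper, uses String column names in place of Column, and lets SQLException propagate instead of wrapping it in JdbcException.

```java
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;

// Hypothetical sketch, not the series' JdbcAccess: the caller passes in an
// open Connection, and plain String column names stand in for Column.
class ExecuteAllSketch {
    // Binds each map-row's values in column order (JDBC parameter indexes
    // start at 1) and executes the insert once per row. try-with-resources
    // closes the statement, then the connection, even if a SQLException
    // is thrown partway through.
    static void executeAll(Connection connection, String sql,
            String[] columnNames, List<Map<String, Object>> rows)
            throws SQLException {
        try (Connection c = connection;
             PreparedStatement statement = c.prepareStatement(sql)) {
            for (Map<String, Object> row : rows) {
                int i = 0;
                for (String name : columnNames)
                    statement.setObject(++i, row.get(name));
                statement.execute();
            }
        }
    }
}
```

Like the original, this still executes one statement per row; the only change is that cleanup no longer depends on every row succeeding.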

OK, so here’s the problem: I’m introducing a construct solely to meet performance needs, yet I now force clients of JdbcAccess to populate a list of map-rows in order to use a PreparedStatement. Yuk. Granted, database performance concerns far outweigh in-memory ones. But it seems unfortunate to add performance overhead at the same time I’m trying to reduce it. I’ll take a look at this in either the next or a future installment.

Now back to the bane of this suite, the stub/mock-friendly PersisterTest. I finally got fed up with how hard the tests in this class were to understand. I spent about 15 minutes incrementally refactoring it so that the mock JdbcAccess definitions appear within each test method (for the most part). This makes it clear which behavior of JdbcAccess is getting mocked and which is not; I think it makes the tests much easier to follow.

PersisterTest

    package persistence;
    
    import java.util.*;
    
    import junit.framework.*;
    import persistence.types.*;
    import sql.*;
    
    public class PersisterTest extends TestCase {
       private static final String TABLE = "x";
    
       private static final Persistable RETURN_OBJECT1 = new Persistable() {
          public Object get(String key) {
             return null;
          }
       };
       private static final Persistable RETURN_OBJECT2 = new Persistable() {
          public Object get(String key) {
             return null;
          }
       };
    
       private static final String BAD_KEY = "not found";
       private static final String COLUMN1 = "a";
       private static final String COLUMN2 = "b";
       private static final String ROW1_VALUE1 = "a1";
       private static final String ROW1_VALUE1_PATTERN = "a%";
       private static final String ROW1_VALUE2 = "a2";
       private static final String ROW2_VALUE1 = "b1";
       private static final String ROW2_VALUE2 = "b2";
       private static final Column[] COLUMNS = { new StringColumn(COLUMN1),
             new StringColumn(COLUMN2) };
    
       private String lastSql;
       private PersistableMetadata metadata;
    
       private Persister persister;
       private Persistable persistable;
    
       protected void setUp() {
          metadata = new PersistableMetadata() {
             public String getTable() {
                return TABLE;
             }
    
             public String getKeyColumn() {
                return COLUMNS[0].getName();
             }
    
             public Column[] getColumns() {
                return COLUMNS;
             }
    
             public Persistable create(Map row) {
                if (row.get(COLUMN1).equals(ROW1_VALUE1)
                      && row.get(COLUMN2).equals(ROW1_VALUE2))
                   return RETURN_OBJECT1;
                if (row.get(COLUMN1).equals(ROW2_VALUE1)
                      && row.get(COLUMN2).equals(ROW2_VALUE2))
                   return RETURN_OBJECT2;
                return null;
             }
          };
    
          persistable = new Persistable() {
             public Object get(String key) {
                if (key.equals(COLUMN1))
                   return ROW1_VALUE1;
                if (key.equals(COLUMN2))
                   return ROW1_VALUE2;
                return null;
             }
          };
       }
    
       public void testSave() {
          createPersister(new JdbcAccess() {
             public void execute(String sql) {
                lastSql = sql;
             }
          });
    
          persister.save(persistable);
          assertLastSql(String.format("insert into %s (%s,%s) values ('%s','%s')",
                TABLE, COLUMN1, COLUMN2, ROW1_VALUE1, ROW1_VALUE2));
       }
    
       public void testSaveAll() {
          createPersister(new JdbcAccess() {
             public void executeAll(String sql, Column[] columns,
                   List<Map<String, Object>> rows) {
                lastSql = sql;
                Assert.assertEquals(COLUMNS, columns);
                Assert.assertEquals(1, rows.size());
                Map row = rows.get(0);
                Assert.assertEquals(ROW1_VALUE1, row.get(COLUMN1));
             }
          });
    
          Collection collection = new ArrayList();
          collection.add(persistable);
          persister.saveAll(collection);
          assertEquals(String.format("insert into %s (%s,%s) values (?,?)", TABLE,
                COLUMN1, COLUMN2), lastSql);
       }
    
       public void testFindBy() {
          createPersister(new JdbcAccess() {
             public Map executeQueryExpectingOneRow(String sql,
                   Column[] columns) {
                Assert.assertEquals(COLUMNS, columns);
                lastSql = sql;
                return createRow1();
             }
          });
    
          final String key = ROW1_VALUE1;
          assertEquals(RETURN_OBJECT1, persister.find(key));
          assertLastSql(String.format("select %s,%s from %s where %s='%s'",
                COLUMN1, COLUMN2, TABLE, COLUMN1, key));
       }
    
       public void testFindNotFound() {
          createPersister(new JdbcAccess() {
             public Map executeQueryExpectingOneRow(String sql,
                   Column[] columns) {
                lastSql = sql;
                return null;
             }
          });
          assertNull(persister.find(BAD_KEY));
          assertLastSql(String.format("select %s,%s from %s where %s='%s'",
                COLUMN1, COLUMN2, TABLE, COLUMN1, BAD_KEY));
       }
    
       public void testGetAll() {
          createMockedExecuteQueryPersister();
          assertQueryResults(persister.getAll());
          assertEquals(String.format("select %s,%s from %s", COLUMN1, COLUMN2,
                TABLE), lastSql);
       }
    
       public void testFindMatches() {
          createMockedExecuteQueryPersister();
          assertQueryResults(persister.findMatches(COLUMNS[0], ROW1_VALUE1_PATTERN));
          assertLastSql(String.format("select %s,%s from %s where %s like '%s'",
                COLUMN1, COLUMN2, TABLE, COLUMN1, ROW1_VALUE1_PATTERN));
       }
    
       public void testFindWithCriteria() {
          createMockedExecuteQueryPersister();
          Criteria criteria = new EqualsCriteria(COLUMNS[0], ROW1_VALUE1);
          assertQueryResults(persister.find(criteria));
          assertLastSql(String.format("select %s,%s from %s where %s='%s'",
                COLUMN1, COLUMN2, TABLE, COLUMN1, ROW1_VALUE1));
       }
    
       private void createMockedExecuteQueryPersister() {
          createPersister(new JdbcAccess() {
             public List<Map<String, Object>> executeQuery(String sql,
                   Column[] columns) {
                lastSql = sql;
                List<Map<String, Object>> results = new ArrayList<Map<String, Object>>();
                results.add(createRow1());
                results.add(createRow2());
                return results;
             }
          });
       }
    
       private void createPersister(JdbcAccess accessMock) {
          persister = new Persister(metadata, accessMock);
       }
    
       private void assertLastSql(String sql) {
          assertEquals(sql, lastSql);
       }
    
       private void assertQueryResults(List results) {
          assertEquals(2, results.size());
          assertTrue(results.contains(RETURN_OBJECT1));
          assertTrue(results.contains(RETURN_OBJECT2));
       }
    
       protected Map<String, Object> createRow1() {
          return createRow(ROW1_VALUE1, ROW1_VALUE2);
       }
    
       protected Map<String, Object> createRow2() {
          return createRow(ROW2_VALUE1, ROW2_VALUE2);
       }
    
       protected Map<String, Object> createRow(String value1, String value2) {
          Map<String, Object> row = new HashMap<String, Object>();
          row.put(COLUMN1, value1);
          row.put(COLUMN2, value2);
          return row;
       }
    }
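PersisterTest now relies on a hand-rolled mocking technique: each test hands Persister an anonymous subclass of JdbcAccess that overrides only the method that test exercises and captures the SQL it receives. Stripped of the persistence details, a sketch of the idea looks like the following; SketchAccess and SketchSaver are hypothetical stand-ins for JdbcAccess and Persister.

```java
// Hypothetical stand-ins for JdbcAccess and Persister, trimmed to one method each.
class SketchAccess {
    void execute(String sql) {
        throw new UnsupportedOperationException("would hit a live database");
    }
}

class SketchSaver {
    private final SketchAccess access;

    SketchSaver(SketchAccess access) {
        this.access = access;
    }

    void save(String table, String value) {
        access.execute(String.format("insert into %s values ('%s')", table, value));
    }
}

class InlineMockDemo {
    // Plays the role of a test method: the mock is defined right here and
    // overrides only execute, so it is obvious which behavior is faked.
    static String captureSaveSql() {
        final StringBuilder lastSql = new StringBuilder();
        SketchSaver saver = new SketchSaver(new SketchAccess() {
            @Override
            void execute(String sql) {
                lastSql.append(sql);
            }
        });
        saver.save("x", "a1");
        return lastSql.toString();
    }

    public static void main(String[] args) {
        System.out.println(captureSaveSql()); // insert into x values ('a1')
    }
}
```

Because an anonymous class can only capture (effectively) final locals, the sketch collects the SQL in a StringBuilder; PersisterTest sidesteps that by storing it in its lastSql field.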

Persister

    public void saveAll(Collection<T> collection) {
       String sql = new SqlGenerator().createPreparedInsert(metadata.getTable(),
             metadata.getColumns());
       access.executeAll(sql, metadata.getColumns(), createInsertRows(collection));
    }
    
    // One map-row (column name -> value) per persistable
    private List<Map<String, Object>> createInsertRows(Collection<T> collection) {
       List<Map<String, Object>> rows = new ArrayList<Map<String, Object>>();
       for (T persistable : collection)
          rows.add(createInsertRow(persistable));
       return rows;
    }
    
    private Map<String, Object> createInsertRow(T persistable) {
       Map<String, Object> row = new HashMap<String, Object>();
       for (Column column : metadata.getColumns()) {
          Object value = persistable.get(column.getName());
          row.put(column.getName(), value);
       }
       return row;
    }

After creating the saveAll code in Persister, I did a little refactoring on the save method:

Persister

    public void save(T persistable) {
       String sql = new SqlGenerator().createInsert(metadata.getTable(),
             metadata.getColumns(), extractValues(persistable, metadata.getColumns()));
       access.execute(sql);
    }
    
    private Object[] extractValues(T persistable, Column[] columns) {
       Object[] values = new Object[columns.length];
       for (int i = 0; i < columns.length; i++)
          values[i] = persistable.get(columns[i].getName());
       return values;
    }

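extractValues and createInsertRow do nearly the same job: each walks the metadata's columns and pulls the corresponding value from the persistable. That's a good similarity between the two save methods, and I want to try to reconcile them in the near future.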
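The two save methods share a step: each walks the metadata's columns and asks the persistable for each value (extractValues collects the results into an array, createInsertRow into a map). One possible reconciliation is to build the map-row from the array extraction, so the column walk lives in one place. This sketch assumes a hypothetical SimplePersistable interface with just the get(String) method seen in this installment, and String column names instead of Column.

```java
import java.util.*;

// Hypothetical minimal stand-in for the series' Persistable.
interface SimplePersistable {
    Object get(String key);
}

class ValueExtraction {
    // The single place that walks the columns and pulls values, in column order.
    static Object[] extractValues(SimplePersistable p, String[] columns) {
        Object[] values = new Object[columns.length];
        for (int i = 0; i < columns.length; i++)
            values[i] = p.get(columns[i]);
        return values;
    }

    // The map-row that saveAll needs, built from the same extraction.
    // LinkedHashMap preserves column order, which makes rows easier to read.
    static Map<String, Object> createInsertRow(SimplePersistable p, String[] columns) {
        Object[] values = extractValues(p, columns);
        Map<String, Object> row = new LinkedHashMap<String, Object>();
        for (int i = 0; i < columns.length; i++)
            row.put(columns[i], values[i]);
        return row;
    }

    public static void main(String[] args) {
        SimplePersistable p = new SimplePersistable() {
            public Object get(String key) {
                return key.equals("a") ? "a1" : key.equals("b") ? "a2" : null;
            }
        };
        // prints: {a=a1, b=a2}
        System.out.println(createInsertRow(p, new String[] { "a", "b" }));
    }
}
```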

To test all this code out I wrote the following (live) database test.

CustomerAccessTest

    package domain;
    
    import java.util.*;
    
    import persistence.*;
    
    import junit.framework.*;
    
    public class CustomerAccessTest extends TestCase {
       private CustomerAccess access;
    
       protected void setUp() {
          access = new CustomerAccess();
          JdbcAccess jdbc = new JdbcAccess();
          jdbc.execute("truncate table " + access.getTable());
       }
    
       public void testPersist() {
          final String name = "a";
          final String id = "1";
          final int amount = 100;
    
          Customer customer = new Customer(id, name);
          customer.charge(amount);
    
          access.save(customer);
          Customer retrievedCustomer = access.find(id);
          assertEquals(id, retrievedCustomer.getId());
          assertEquals(name, retrievedCustomer.getName());
          assertEquals(amount, retrievedCustomer.getBalance());
       }
    
       public void testPersistLots() {
          final int count = 10;
          Collection<Customer> customers = new ArrayList<Customer>();
          for (int i = 0; i < count; i++) {
             String id = "" + i;
             Customer customer = new Customer(id, "a");
             customer.charge(i);
             customers.add(customer);
          }
    
          access.saveAll(customers);
    
          for (int i = 0; i < count; i++) {
             String id = "" + i;
             Customer retrievedCustomer = access.find(id);
             assertEquals(i, retrievedCustomer.getBalance());
          }
       }
    }

In doing so I recognized that the customer table needed to be cleared out before each test, so I added the setUp method. Here's the saveAll implementation in the DataAccess superclass. (If you're looking at older code, note that I recognized and corrected a deficiency with my declaration of the parameterized type.)

DataAccess

    abstract public class DataAccess<T extends Persistable> implements
          PersistableMetadata<T> {
       ...
       public void saveAll(Collection<T> collection) {
          new Persister<T>(this).saveAll(collection);
       }
       ...
    }
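The parameterized type on DataAccess is what lets saveAll accept a collection of a concrete domain type and pass it, type-checked, to a Persister of the same type. A stripped-down sketch of that shape follows. Every type in it (Entity, EntityMetadata, SketchPersister, SketchDataAccess, SketchCustomer, SketchCustomerAccess) is a hypothetical stand-in, and the series' actual declarations may differ; SketchPersister returns a description string instead of touching a database so the example runs on its own.

```java
import java.util.*;

// Hypothetical, simplified stand-ins; the series' real types differ.
interface Entity {
    Object get(String key);
}

interface EntityMetadata<T extends Entity> {
    String getTable();
}

class SketchPersister<T extends Entity> {
    private final EntityMetadata<T> metadata;

    SketchPersister(EntityMetadata<T> metadata) {
        this.metadata = metadata;
    }

    // Describes the work instead of hitting a database.
    String saveAll(Collection<T> collection) {
        return "insert " + collection.size() + " rows into " + metadata.getTable();
    }
}

// A subclass fixes T once; saveAll then accepts only a Collection<T>.
abstract class SketchDataAccess<T extends Entity> implements EntityMetadata<T> {
    String saveAll(Collection<T> collection) {
        return new SketchPersister<T>(this).saveAll(collection);
    }
}

class SketchCustomer implements Entity {
    private final String id;

    SketchCustomer(String id) {
        this.id = id;
    }

    public Object get(String key) {
        return key.equals("id") ? id : null;
    }
}

class SketchCustomerAccess extends SketchDataAccess<SketchCustomer> {
    public String getTable() {
        return "customer";
    }
}

class GenericAccessDemo {
    public static void main(String[] args) {
        List<SketchCustomer> customers =
            Arrays.asList(new SketchCustomer("1"), new SketchCustomer("2"));
        // prints: insert 2 rows into customer
        System.out.println(new SketchCustomerAccess().saveAll(customers));
        // new SketchCustomerAccess().saveAll(Arrays.asList("x")); // would not compile
    }
}
```

The commented-out call shows the payoff: once the subclass fixes the type parameter, passing saveAll the wrong element type is a compile error rather than a runtime surprise.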

I note that I've just added another "live" persistence test, this time to CustomerAccessTest. Tests like this will start to increase the time it takes to run my complete suite. Still, I'm at a very comfortable ~5 seconds. The next time I feel compelled to add such a live "confidence" test, I'll revisit what I want to do about this potential bloat in execution time. Maybe it's not a concern: I'm not writing these tests for every possible DataAccess subclass. There are a few missing tests that I might add, but I don't expect them to severely increase test execution time.

I'm still working non-traditionally, by the way, by building this backward (inside-out, some might call it). This is partly because the need for PreparedStatement support is artificial: I was too lazy to dream up a feature and work it down from the application level. It's also because sometimes working backward is the easiest way for me to approach a problem.


About the Author

Jeff Langr has been building software for 40 years and writing about it heavily for 20. You can find out more about Jeff, learn from the many helpful articles and books he's written, or read any of the more than 1,000 posts he has published on his blogs (including Agile in a Flash) and elsewhere.