diff --git a/data_objects/lib/data_objects/connection.rb b/data_objects/lib/data_objects/connection.rb index 7876731c..e487fa40 100644 --- a/data_objects/lib/data_objects/connection.rb +++ b/data_objects/lib/data_objects/connection.rb @@ -18,22 +18,10 @@ def self.new(uri_s) when :java warn 'JNDI URLs (connection strings) are only for use with JRuby' unless RUBY_PLATFORM =~ /java/ - driver = uri.query.delete('scheme') - driver = uri.query.delete('driver') - - conn_uri = uri.to_s.gsub(/\?$/, '') + conn_uri = uri.to_s.gsub(/\?.*$/, '') when :jdbc warn 'JDBC URLs (connection strings) are only for use with JRuby' unless RUBY_PLATFORM =~ /java/ - path = uri.subscheme - driver = if path.split(':').first == 'sqlite' - 'sqlite3' - elsif path.split(':').first == 'postgresql' - 'postgres' - else - path.split(':').first - end - conn_uri = uri_s # NOTE: for now, do not reformat this JDBC connection # string -- or, in other words, do not let # DataObjects::URI#to_s be called -- as it is not @@ -41,18 +29,10 @@ def self.new(uri_s) # java.sql.DriverManager.getConnection to throw a # 'No suitable driver found for...' exception. else - driver = uri.scheme conn_uri = uri end - # Exceptions to how a driver class is determined for a given URI - driver_class = if driver == 'sqlserver' - 'SqlServer' - else - driver.capitalize - end - - clazz = DataObjects.const_get(driver_class)::Connection + clazz = DataObjects.adapter_name(uri)::Connection unless clazz.method_defined? :close if (uri.scheme.to_sym == :java) clazz.class_eval do diff --git a/data_objects/lib/data_objects/transaction.rb b/data_objects/lib/data_objects/transaction.rb index 487b5499..4eb4cbb6 100644 --- a/data_objects/lib/data_objects/transaction.rb +++ b/data_objects/lib/data_objects/transaction.rb @@ -18,7 +18,7 @@ class Transaction # Instantiate the Transaction subclass that's appropriate for this uri scheme def self.create_for_uri(uri) uri = uri.is_a?(String) ? URI::parse(uri) : uri - DataObjects.const_get(uri.scheme.capitalize)::Transaction.new(uri) + DataObjects.adapter_name(uri)::Transaction.new(uri) end # diff --git a/data_objects/lib/data_objects/uri.rb b/data_objects/lib/data_objects/uri.rb index a1aae206..12fc2622 100644 --- a/data_objects/lib/data_objects/uri.rb +++ b/data_objects/lib/data_objects/uri.rb @@ -2,6 +2,56 @@ module DataObjects + def self.adapter_name(uri) + + adapter = uri.scheme + case adapter + when 'java' + adapter = uri.query['adapter'] + unless adapter + # discover the real adapter + jndi_uri = "#{uri.scheme}:#{uri.path}" + context = javax.naming.InitialContext.new + ds= context.lookup(jndi_uri) + conn = ds.getConnection + begin + metadata = conn.getMetaData + driver_name = metadata.getDriverName + + adapter = case driver_name + when /mysql/i then 'mysql' + when /oracle/i then 'oracle' + when /postgres/i then 'postgres' + when /sqlite/i then 'sqlite3' + when /sqlserver|tds|Microsoft SQL/i then 'sqlserver' + else + nil # not supported + end # case + ensure + conn.close + end + end + when 'jdbc' + path = uri.subscheme + adapter = path.split(':').first + end + + # Exceptions to how a adapter class is determined for a given URI + adapter_class = case adapter + when 'sqlserver' + 'SqlServer' + when /sqlite/ + 'Sqlite3' + when /postgres/ + 'Postgres' + else + adapter.capitalize + end + + const_get(adapter_class) + end + + # A DataObjects URI is of the form scheme://user:password@host:port/path#fragment # # The elements are all optional except scheme and path: diff --git a/do_jdbc/src/main/java/data_objects/drivers/AbstractDriverDefinition.java b/do_jdbc/src/main/java/data_objects/drivers/AbstractDriverDefinition.java index 77ed0c89..59c8171f 100644 --- a/do_jdbc/src/main/java/data_objects/drivers/AbstractDriverDefinition.java +++ b/do_jdbc/src/main/java/data_objects/drivers/AbstractDriverDefinition.java @@ -211,7 +211,7 @@ public URI parseConnectionURI(IRubyObject connection_uri) * @param scheme */ protected void verifyScheme(String scheme) { - if (!this.scheme.equals(scheme)) { + if (!this.scheme.equals(scheme) && !this.jdbcScheme.equals(scheme)) { throw new RuntimeException("scheme mismatch, expected: " + this.scheme + " but got: " + scheme); } diff --git a/do_oracle/ext-java/src/main/java/do_oracle/OracleDriverDefinition.java b/do_oracle/ext-java/src/main/java/do_oracle/OracleDriverDefinition.java index b52753a0..3d6c80bc 100644 --- a/do_oracle/ext-java/src/main/java/do_oracle/OracleDriverDefinition.java +++ b/do_oracle/ext-java/src/main/java/do_oracle/OracleDriverDefinition.java @@ -2,6 +2,8 @@ import java.io.IOException; import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.net.URI; import java.sql.Connection; @@ -10,7 +12,6 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; -import oracle.jdbc.OraclePreparedStatement; import oracle.jdbc.OracleTypes; import java.util.Properties; @@ -137,6 +138,121 @@ public void setPreparedStatementParam(PreparedStatement ps, } } + protected static boolean isMethodEquals(Method method, String methodName, Object... parameters) + { + if (! methodName.equals(method.getName())) { + return false; + } + + if (parameters != null) { + Class []parameterTypes = method.getParameterTypes(); + + if (parameterTypes.length != parameters.length) { + return false; + } + + for(int i = 0; i < parameters.length; i++) { + Object parameter = parameters[i]; + Class parameterClass = parameter.getClass(); + Class parameterType = parameterTypes[i]; + + if (parameterType.isPrimitive()) { + + if (parameterType == Integer.TYPE) { + parameterType = Integer.class; + } + if (parameterType == Float.TYPE) { + parameterType = Float.class; + } + if (parameterType == Double.TYPE) { + parameterType = Double.class; + } + if (parameterType == Byte.TYPE) { + parameterType = Byte.class; + } + if (parameterType == Character.TYPE) { + parameterType = Character.class; + } + if (parameterType == Short.TYPE) { + parameterType = Short.class; + } + if (parameterType == Long.TYPE) { + parameterType = Long.class; + } + } + if (! parameterType.isAssignableFrom(parameterClass)) { + return false; + } + } + } + return true; + } + + protected static Object invoke(Object object, String methodName, Object... parameters) + { + Class klass = object.getClass(); + + Method []methods = klass.getMethods(); + + for( Method method : methods) { + + if (isMethodEquals(method, methodName, parameters)) + { + try { + return method.invoke(object, parameters); + } catch (IllegalAccessException e) { + e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. + } catch (InvocationTargetException e) { + e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. + } + } + } + return null; + } + + protected Statement getRealStatement(Statement ps) { + try + { + // if the DataSource is wrapped by DBCP, + // we could be getting a 'org.apache.tomcat.dbcp.dbcp.DelegatingPreparedStatement' + + // in that case, we can get the real PreparedStatement by calling getDelegate + // use reflection because I don't want to introduce dependency to DBCP + + final Class clsDelgatingStatement = Class.forName("org.apache.tomcat.dbcp.dbcp.DelegatingStatement"); + + final Method methodGetDelegate = clsDelgatingStatement.getMethod("getDelegate"); + Object o = ps; + + while( clsDelgatingStatement.isInstance(o)) { + o = methodGetDelegate.invoke(o); + } + if (o instanceof Statement) { + ps = (Statement)o; + } + } catch (ClassNotFoundException e) { + // ignore + } catch (SecurityException e) { + // ignore + } catch (NoSuchMethodException e) { + // ignore + } catch (IllegalArgumentException e) { + // ignore + } catch (IllegalAccessException e) { + // ignore + } catch (InvocationTargetException e) { + Throwable t = e.getCause(); + if (t instanceof Error) { + throw (Error)t; + } + if (t instanceof RuntimeException) { + throw (RuntimeException)t; + } + } + + return ps; + } + /** * * @param sqlText @@ -147,11 +263,12 @@ public void setPreparedStatementParam(PreparedStatement ps, */ @Override public boolean registerPreparedStatementReturnParam(String sqlText, PreparedStatement ps, int idx) throws SQLException { - OraclePreparedStatement ops = (OraclePreparedStatement) ps; + PreparedStatement ops = (PreparedStatement)getRealStatement(ps); + Pattern p = Pattern.compile("^\\s*INSERT.+RETURNING.+INTO\\s+", Pattern.CASE_INSENSITIVE); Matcher m = p.matcher(sqlText); if (m.find()) { - ops.registerReturnParameter(idx, Types.BIGINT); + invoke(ps, "registerReturnParameter", idx, Types.BIGINT); return true; } return false; @@ -165,8 +282,8 @@ public boolean registerPreparedStatementReturnParam(String sqlText, PreparedStat */ @Override public long getPreparedStatementReturnParam(PreparedStatement ps) throws SQLException { - OraclePreparedStatement ops = (OraclePreparedStatement) ps; - ResultSet rs = ops.getReturnResultSet(); + PreparedStatement ops = (PreparedStatement)getRealStatement(ps); + ResultSet rs = (ResultSet) invoke(ps, "getReturnResultSet"); try { if (rs.next()) { // Assuming that primary key will not be larger as long max value @@ -293,6 +410,7 @@ public void afterConnectionCallback(IRubyObject doConn, Connection conn, Map