diff --git a/plugins/generic/enumeration.py b/plugins/generic/enumeration.py index 2b1bd37fc..fc837efbb 100644 --- a/plugins/generic/enumeration.py +++ b/plugins/generic/enumeration.py @@ -749,8 +749,7 @@ class Enumeration: else: return tables - if Backend.getIdentifiedDbms() == DBMS.MYSQL: - conf.db = self.__safeMySQLIdentificatorNaming(conf.db) + conf.db = self.__safeSQLIdentificatorNaming(conf.db) if bruteForce: resumeAvailable = False @@ -933,9 +932,8 @@ class Enumeration: logger.error(errMsg) bruteForce = True - if Backend.getIdentifiedDbms() == DBMS.MYSQL: - conf.tbl = self.__safeMySQLIdentificatorNaming(conf.tbl) - conf.db = self.__safeMySQLIdentificatorNaming(conf.db) + conf.tbl = self.__safeSQLIdentificatorNaming(conf.tbl) + conf.db = self.__safeSQLIdentificatorNaming(conf.db) if bruteForce: resumeAvailable = False @@ -1008,10 +1006,7 @@ class Enumeration: columns = {} for columnData in value: - name = columnData[0] - - if Backend.getIdentifiedDbms() == DBMS.MYSQL: - name = self.__safeMySQLIdentificatorNaming(name) + name = self.__safeSQLIdentificatorNaming(columnData[0]) if len(columnData) == 1: columns[name] = "" @@ -1087,8 +1082,7 @@ class Enumeration: query = agent.limitQuery(index, query, field) column = inject.getValue(query, inband=False, error=False) - if Backend.getIdentifiedDbms() == DBMS.MYSQL: - column = self.__safeMySQLIdentificatorNaming(column) + column = self.__safeSQLIdentificatorNaming(column) if not onlyColNames: if Backend.getIdentifiedDbms() in ( DBMS.MYSQL, DBMS.PGSQL ): @@ -1213,13 +1207,16 @@ class Enumeration: return entries, lengths - def __safeMySQLIdentificatorNaming(self, value): + def __safeSQLIdentificatorNaming(self, value): """ - Returns an safe representation of identificator name for MySQL + Returns an safe representation of SQL identificator name """ retVal = value if isinstance(value, basestring) and not re.match(r"\A[A-Za-z0-9_]+\Z", value): - retVal = "`%s`" % value + if Backend.getIdentifiedDbms() == DBMS.MYSQL: + retVal = "`%s`" % value + elif Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.ORACLE, DBMS.PGSQL): + retVal = "\"%s\"" % value return retVal def dumpTable(self): @@ -1254,9 +1251,8 @@ class Enumeration: rootQuery = queries[Backend.getIdentifiedDbms()].dump_table - if Backend.getIdentifiedDbms() == DBMS.MYSQL: - conf.tbl = self.__safeMySQLIdentificatorNaming(conf.tbl) - conf.db = self.__safeMySQLIdentificatorNaming(conf.db) + conf.tbl = self.__safeSQLIdentificatorNaming(conf.tbl) + conf.db = self.__safeSQLIdentificatorNaming(conf.db) if conf.col: colList = conf.col.split(",")