-- ------------------------------------------------------------------------------ --
--                                TradeSkillMaster                                --
--                          https://tradeskillmaster.com                          --
--    All Rights Reserved - Detailed license information included with addon.     --
-- ------------------------------------------------------------------------------ --

--- Database Query Class.
-- This class represents a database query which is used for reading data out of a @{Database} in a structured and
-- efficient manner.
-- @classmod DatabaseQuery

local _, TSM = ...
local Query = TSM.Init("Util.DatabaseClasses.Query")
local Constants = TSM.Include("Util.DatabaseClasses.Constants")
local Util = TSM.Include("Util.DatabaseClasses.Util")
local QueryResultRow = TSM.Include("Util.DatabaseClasses.QueryResultRow")
local QueryClause = TSM.Include("Util.DatabaseClasses.QueryClause")
local ObjectPool = TSM.Include("Util.ObjectPool")
local TempTable = TSM.Include("Util.TempTable")
local Table = TSM.Include("Util.Table")
local Math = TSM.Include("Util.Math")
local LibTSMClass = TSM.Include("LibTSMClass")
local DatabaseQuery = LibTSMClass.DefineClass("DatabaseQuery")
local private = {
	objectPool = nil,
}



-- ============================================================================
-- Module Loading
-- ============================================================================

Query:OnModuleLoad(function()
	-- pool of recycled query objects which Query.Get() hands out
	private.objectPool = ObjectPool.New("DATABASE_QUERIES", DatabaseQuery, 1)
end)



-- ============================================================================
-- Module Functions
-- ============================================================================

--- Gets a query object from the pool and attaches it to a database.
-- @tparam DatabaseTable db The database to query
-- @treturn DatabaseQuery The acquired database query object
function Query.Get(db)
	local query = private.objectPool:Get()
	query:_Acquire(db)
	return query
end



-- ============================================================================
-- Class Meta Methods
-- ============================================================================

function DatabaseQuery.__init(self)
	-- database binding and filter clause tree
	self._db = nil
	self._rootClause = nil
	self._currentClause = nil

	-- ordering / distinct / select state
	self._orderBy = {}
	self._orderByAscending = {}
	self._distinct = nil
	self._select = {}

	-- update callback state
	self._updateCallback = nil
	self._updateCallbackContext = nil
	self._updatesPaused = 0
	self._queuedUpdate = false

	-- iteration and result state
	self._iteratorState = "IDLE"
	self._result = {}
	self._resultRowLookup = {}
	self._iterDistinctUsed = {}
	self._tempResultRow = nil
	self._tempVirtualResultRow = nil
	self._autoRelease = false
	self._resultIsStale = false
	self._sortValueCache = {}
	self._resultDependencies = {}

	-- join state
	self._joinTypes = {}
	self._joinDBs = {}
	self._joinFields = {}
	self._aggregateJoinFields = {}
	self._aggregateJoinQueries = {}

	-- virtual field definitions, keyed by virtual field name
	self._virtualFieldFunc = {}
	self._virtualFieldArgField = {}
	self._virtualFieldType = {}
	self._virtualFieldDefault = {}

	-- sort comparators bound to this query (the private sort functions are looked up at call time)
	self._genericSortWrapper = function(lhs, rhs)
		return private.DatabaseQuerySortGeneric(self, lhs, rhs)
	end
	self._singleSortWrapper = function(lhs, rhs)
		return private.DatabaseQuerySortSingle(self, lhs, rhs, self._orderByAscending[1])
	end
	self._secondarySortWrapper = function(lhs, rhs)
		return private.DatabaseQuerySortSingle(self, lhs, rhs, self._orderByAscending[2])
	end
end

function DatabaseQuery._Acquire(self, db)
	self._db = db
	db:_RegisterQuery(self)
	-- every query starts out with an implicit root AND clause
	self._rootClause = QueryClause.Get(self):And()
	self._currentClause = self._rootClause
	-- temporary result row bound to this query
	self._tempResultRow = QueryResultRow.Get()
	self._tempResultRow:_Acquire(db, self)
end

function DatabaseQuery._Release(self)
	assert(self._iteratorState == "IDLE")
	-- detach from the database
	local db = self._db
	db:_RemoveQuery(self)
	self._db = nil
	self._rootClause:_Release()
	self._rootClause = nil
	self._currentClause = nil
	self._updateCallback = nil
	self._updateCallbackContext = nil
	self._updatesPaused = 0
	self._queuedUpdate = false
	wipe(self._iterDistinctUsed)
	self._tempResultRow:Release()
	self._tempResultRow = nil
	local virtualRow = self._tempVirtualResultRow
	if virtualRow then
		virtualRow:Release()
		self._tempVirtualResultRow = nil
	end
	self._autoRelease = false
	self:_WipeResults()
	-- each Reset*() method returns the query, so they are chained
	self:ResetOrderBy()
		:ResetDistinct()
		:ResetSelect()
		:ResetJoins()
		:ResetVirtualFields()
	self._resultIsStale = false
	wipe(self._resultDependencies)
end



-- ============================================================================
-- Public Class Methods
-- ============================================================================

--- Releases the database query.
-- The query object is returned to the pool, so it must not be accessed after calling this method.
-- @tparam DatabaseQuery self The database query object
-- @tparam[opt=false] boolean abortIterator Forcibly reset any in-progress iterator before releasing
function DatabaseQuery.Release(self, abortIterator)
	if abortIterator then
		-- _Release() asserts that the iterator is idle, so force it there
		self._iteratorState = "IDLE"
	end
	self:_Release()
	private.objectPool:Recycle(self)
end

--- Adds a virtual field to the query.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the new virtual field
-- @tparam string fieldType The type of the virtual field ("number", "string", or "boolean")
-- @tparam function func A function which takes a row and returns the value of the virtual field
-- @tparam[opt=nil] string argField The field to pass into the function (otherwise passes the entire row)
-- @param[opt=nil] defaultValue The default value to use if the function returns nil (must be of type `fieldType`)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.VirtualField(self, field, fieldType, func, argField, defaultValue)
	if self:_GetFieldType(field) or self._virtualFieldFunc[field] then
		error("Field already exists: "..tostring(field))
	end
	if type(func) ~= "function" then
		error("Invalid func: "..tostring(func))
	end
	if not (fieldType == "number" or fieldType == "string" or fieldType == "boolean") then
		error("Field type must be string, number, or boolean")
	end
	if argField and not self:_GetFieldType(argField) then
		error("Arg field doesn't exist: "..tostring(argField))
	end
	if defaultValue ~= nil and type(defaultValue) ~= fieldType then
		error("Invalid defaultValue type: "..tostring(defaultValue))
	end
	self._virtualFieldFunc[field] = func
	self._virtualFieldArgField[field] = argField
	self._virtualFieldType[field] = fieldType
	self._virtualFieldDefault[field] = defaultValue
	self._resultIsStale = true
	return self
end

--- Where a field equals a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (or `Constants.BOUND_QUERY_PARAM` / `Constants.OTHER_FIELD_QUERY_PARAM`)
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Equal(self, field, value, otherField)
	local comparesOtherField = value == Constants.OTHER_FIELD_QUERY_PARAM
	local isBoundParam = value == Constants.BOUND_QUERY_PARAM
	if comparesOtherField then
		-- both fields must exist and share the same type
		local lhsType = self:_GetFieldType(field)
		assert(lhsType and lhsType == self:_GetFieldType(otherField))
	elseif not isBoundParam then
		-- a literal value must match the field's type
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause():Equal(field, value, otherField)
	return self
end

--- Where a field does not equal a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (or `Constants.BOUND_QUERY_PARAM` / `Constants.OTHER_FIELD_QUERY_PARAM`)
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.NotEqual(self, field, value, otherField)
	local comparesOtherField = value == Constants.OTHER_FIELD_QUERY_PARAM
	local isBoundParam = value == Constants.BOUND_QUERY_PARAM
	if comparesOtherField then
		-- both fields must exist and share the same type
		local lhsType = self:_GetFieldType(field)
		assert(lhsType and lhsType == self:_GetFieldType(otherField))
	elseif not isBoundParam then
		-- a literal value must match the field's type
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause():NotEqual(field, value, otherField)
	return self
end

--- Where a field is less than a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (or `Constants.BOUND_QUERY_PARAM` / `Constants.OTHER_FIELD_QUERY_PARAM`)
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.LessThan(self, field, value, otherField)
	local comparesOtherField = value == Constants.OTHER_FIELD_QUERY_PARAM
	local isBoundParam = value == Constants.BOUND_QUERY_PARAM
	if comparesOtherField then
		-- both fields must exist and share the same type
		local lhsType = self:_GetFieldType(field)
		assert(lhsType and lhsType == self:_GetFieldType(otherField))
	elseif not isBoundParam then
		-- a literal value must match the field's type
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause():LessThan(field, value, otherField)
	return self
end

--- Where a field is less than or equal to a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (or `Constants.BOUND_QUERY_PARAM` / `Constants.OTHER_FIELD_QUERY_PARAM`)
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.LessThanOrEqual(self, field, value, otherField)
	local comparesOtherField = value == Constants.OTHER_FIELD_QUERY_PARAM
	local isBoundParam = value == Constants.BOUND_QUERY_PARAM
	if comparesOtherField then
		-- both fields must exist and share the same type
		local lhsType = self:_GetFieldType(field)
		assert(lhsType and lhsType == self:_GetFieldType(otherField))
	elseif not isBoundParam then
		-- a literal value must match the field's type
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause():LessThanOrEqual(field, value, otherField)
	return self
end

--- Where a field is greater than a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (or `Constants.BOUND_QUERY_PARAM` / `Constants.OTHER_FIELD_QUERY_PARAM`)
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.GreaterThan(self, field, value, otherField)
	local comparesOtherField = value == Constants.OTHER_FIELD_QUERY_PARAM
	local isBoundParam = value == Constants.BOUND_QUERY_PARAM
	if comparesOtherField then
		-- both fields must exist and share the same type
		local lhsType = self:_GetFieldType(field)
		assert(lhsType and lhsType == self:_GetFieldType(otherField))
	elseif not isBoundParam then
		-- a literal value must match the field's type
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause():GreaterThan(field, value, otherField)
	return self
end

--- Where a field is greater than or equal to a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @param value The value to compare to (or `Constants.BOUND_QUERY_PARAM` / `Constants.OTHER_FIELD_QUERY_PARAM`)
-- @tparam[opt=nil] string otherField The name of the other field to compare with
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.GreaterThanOrEqual(self, field, value, otherField)
	local comparesOtherField = value == Constants.OTHER_FIELD_QUERY_PARAM
	local isBoundParam = value == Constants.BOUND_QUERY_PARAM
	if comparesOtherField then
		-- both fields must exist and share the same type
		local lhsType = self:_GetFieldType(field)
		assert(lhsType and lhsType == self:_GetFieldType(otherField))
	elseif not isBoundParam then
		-- a literal value must match the field's type
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause():GreaterThanOrEqual(field, value, otherField)
	return self
end

--- Where a string field matches a pattern.
-- The pattern is lowercased before it is passed to the clause.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam string value The pattern to match (bound parameters are not supported)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Matches(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(self:_GetFieldType(field) == "string" and type(value) == "string")
	local lowerPattern = strlower(value)
	self:_NewClause():Matches(field, lowerPattern)
	return self
end

--- Where a string field contains a substring.
-- The substring is lowercased before it is passed to the clause.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam string value The substring to match (bound parameters are not supported)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Contains(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(self:_GetFieldType(field) == "string" and type(value) == "string")
	local lowerSubstring = strlower(value)
	self:_NewClause():Contains(field, lowerSubstring)
	return self
end

--- Where a string field starts with a substring.
-- The prefix is lowercased before it is passed to the clause.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam string value The substring to match (bound parameters are not supported)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.StartsWith(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(self:_GetFieldType(field) == "string" and type(value) == "string")
	local lowerPrefix = strlower(value)
	self:_NewClause():StartsWith(field, lowerPrefix)
	return self
end

--- Where a foreign field (obtained via a left join) is nil.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field (must come from a LEFT join)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.IsNil(self, field)
	local joinType = self:_GetJoinType(field)
	assert(joinType == "LEFT", "Must be a left join")
	local clause = self:_NewClause()
	clause:IsNil(field)
	return self
end

--- Where a foreign field (obtained via a left join) is not nil.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field (must come from a LEFT join)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.IsNotNil(self, field)
	local joinType = self:_GetJoinType(field)
	assert(joinType == "LEFT", "Must be a left join")
	local clause = self:_NewClause()
	clause:IsNotNil(field)
	return self
end

--- A custom query clause.
-- @tparam DatabaseQuery self The database query object
-- @tparam function func The function which gets passed the row being evaluated and returns true/false if the query
-- should include it
-- @param[opt] arg An argument to pass to the function
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Custom(self, func, arg)
	assert(type(func) == "function")
	local clause = self:_NewClause()
	clause:Custom(func, arg)
	return self
end

--- Where the hash of a row equals a value.
-- @tparam DatabaseQuery self The database query object
-- @tparam table fields An ordered list of fields to hash (each must be a number or string field)
-- @tparam number value The hash value to compare to (bound parameters are not supported)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.HashEqual(self, fields, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(fields) == "table")
	for _, hashField in ipairs(fields) do
		local hashFieldType = self:_GetFieldType(hashField)
		if not hashFieldType then
			error(format("Field %s doesn't exist", tostring(hashField)))
		end
		if hashFieldType ~= "number" and hashFieldType ~= "string" then
			error(format("Cannot hash field of type %s", hashFieldType))
		end
	end
	local clause = self:_NewClause()
	clause:HashEqual(fields, value)
	return self
end

--- Where a field exists as a key within a table.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam table value The table to check against (bound parameters are not supported)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.InTable(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(value) == "table")
	local clause = self:_NewClause()
	clause:InTable(field, value)
	return self
end

--- Where a field does not exist as a key within a table.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field
-- @tparam table value The table to check against (bound parameters are not supported)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.NotInTable(self, field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(value) == "table")
	local clause = self:_NewClause()
	clause:NotInTable(field, value)
	return self
end

--- Starts a nested AND clause.
-- All of the clauses following this (until the matching @{DatabaseQuery.End}) must be true for the AND clause to be
-- true.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.And(self)
	local clause = self:_NewClause()
	-- subsequent clauses are added beneath this one until End() is called
	self._currentClause = clause:And()
	return self
end

--- Starts a nested OR clause.
-- At least one of the clauses following this (until the matching @{DatabaseQuery.End}) must be true for the OR clause
-- to be true.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Or(self)
	local clause = self:_NewClause()
	-- subsequent clauses are added beneath this one until End() is called
	self._currentClause = clause:Or()
	return self
end

--- Ends a nested AND/OR clause.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.End(self)
	assert(self._currentClause ~= self._rootClause, "No current clause to end")
	local parentClause = self._currentClause:_GetParent()
	self._currentClause = parentClause
	assert(parentClause)
	return self
end

--- Performs a left join with another table.
-- @tparam DatabaseQuery self The database query object
-- @tparam DatabaseTable db The database table to join with
-- @tparam string field The field to join on
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.LeftJoin(self, db, field)
	local joinType = "LEFT"
	self:_JoinHelper(db, field, joinType)
	return self
end

--- Performs an inner join with another table.
-- @tparam DatabaseQuery self The database query object
-- @tparam DatabaseTable db The database table to join with
-- @tparam string field The field to join on
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.InnerJoin(self, db, field)
	local joinType = "INNER"
	self:_JoinHelper(db, field, joinType)
	return self
end

--- Performs an aggregate join with another table with a summed field.
-- A sub-query on `db` filtering `field` against a bound parameter is created and handed to the join.
-- @tparam DatabaseQuery self The database query object
-- @tparam DatabaseTable db The database table to join with
-- @tparam string field The name of the field in the other table to join on
-- @tparam string sumField The name of the field in the other table to sum
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.AggregateJoinSummed(self, db, field, sumField)
	local aggregateQuery = db:NewQuery():Equal(field, Constants.BOUND_QUERY_PARAM)
	self:_JoinHelper(db, field, "AGGREGATE_SUM", sumField, aggregateQuery)
	return self
end

--- Order the results by a field.
-- This may be called multiple times to provide additional ordering constraints. The priority of the ordering will be
-- descending as this method is called additional times (meaning the first OrderBy will have highest priority).
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field to order by (must be a number, string, or boolean field)
-- @tparam boolean ascending Whether to order in ascending order (descending otherwise)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.OrderBy(self, field, ascending)
	assert(type(ascending) == "boolean")
	local fieldType = self:_GetFieldType(field)
	if not fieldType then
		error(format("Field %s doesn't exist", tostring(field)))
	end
	if fieldType ~= "number" and fieldType ~= "string" and fieldType ~= "boolean" then
		error(format("Cannot order by field of type %s", tostring(fieldType)))
	end
	-- append to the parallel field / direction lists
	self._orderBy[#self._orderBy + 1] = field
	self._orderByAscending[#self._orderByAscending + 1] = ascending
	self._resultIsStale = true
	return self
end

--- Only return distinct results based on a field.
-- This method can be used to ensure that only the first row for each distinct value of the field is returned. -- @tparam DatabaseQuery self The database query object -- @tparam string field The field to ensure is distinct in the results -- @treturn DatabaseQuery The database query object function DatabaseQuery.Distinct(self, field) assert(self:_GetFieldType(field), format("Field %s doesn't exist within local DB", tostring(field))) self._distinct = field self._resultIsStale = true return self end --- Select specific fields in the result. -- @tparam DatabaseQuery self The database query object -- @tparam vararg ... The fields to select -- @treturn DatabaseQuery The database query object function DatabaseQuery.Select(self, ...) assert(#self._select == 0) local numFields = select("#", ...) assert(numFields > 0, "Must select at least 1 field") -- DatabaseRow.GetFields() only supports 10 fields, so we can only support 10 here as well assert(numFields <= 10, "Select() only supports up to 10 fields") for i = 1, numFields do local field = select(i, ...) tinsert(self._select, field) end self._resultIsStale = true return self end --- Binds parameters to a prepared query. -- The number of arguments should match the number of Constants.BOUND_QUERY_PARAM values in the query's clauses. -- @tparam DatabaseQuery self The database query object -- @tparam vararg ... The fields to select -- @treturn DatabaseQuery The database query object function DatabaseQuery.BindParams(self, ...) local numFields = select("#", ...) assert(self._rootClause:_BindParams(...) == numFields, "Invalid number of bound parameters") self._resultIsStale = true return self end --- Set an update callback. -- This callback gets called whenever any rows in the underlying database change. 
-- @tparam DatabaseQuery self The database query object
-- @tparam function func The callback function which is called with (self, changedUUID, context)
-- @param[opt=nil] context A context argument which is passed as the third argument to the callback function
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.SetUpdateCallback(self, func, context)
	self._updateCallbackContext = context
	self._updateCallback = func
	return self
end

--- Pauses or unpauses callbacks for query updates.
-- Pauses are counted, so every pause must be balanced by an unpause.
-- @tparam DatabaseQuery self The database query object
-- @tparam boolean paused Whether or not updates should be paused
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.SetUpdatesPaused(self, paused)
	local delta = -1
	if paused then
		delta = 1
	end
	self._updatesPaused = self._updatesPaused + delta
	assert(self._updatesPaused >= 0)
	-- once fully unpaused, run the queued update (if any)
	if self._queuedUpdate and self._updatesPaused == 0 then
		self:_DoUpdateCallback()
	end
	return self
end

--- Results iterator.
-- Note that the iterator must run to completion (don't use `break` or `return` to escape it early).
-- @tparam DatabaseQuery self The database query object
-- @tparam[opt] boolean canAbort Allow the iterator to be aborted if the underlying data is updated which must
-- be handled by the caller by calling `IsIteratorAborted()` at the end of each iteration loop (not allowed when an
-- update callback is set)
-- @return An iterator for the results of the query
function DatabaseQuery.Iterator(self, canAbort)
	self:_Execute()
	assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	assert(not canAbort or not self._updateCallback)
	if canAbort then
		self._iteratorState = "IN_PROGRESS_CAN_ABORT"
	else
		self._iteratorState = "IN_PROGRESS"
	end
	self._autoRelease = false
	return private.QueryResultIterator, self, 0
end

--- Iterates through the results as uuids.
-- @tparam DatabaseQuery self The database query object
-- @return An iterator for the results of the query as UUIDs
function DatabaseQuery.UUIDIterator(self)
	self:_Execute()
	local clausesClosed = self._rootClause and self._currentClause == self._rootClause
	assert(clausesClosed, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	self._iteratorState = "IN_PROGRESS"
	self._autoRelease = false
	return private.QueryResultAsUUIDIterator, self, 0
end

--- Results iterator which releases upon completion.
-- Note that the iterator must run to completion (don't use `break` or `return` to escape it early).
-- @tparam DatabaseQuery self The database query object
-- @return An iterator for the results of the query
function DatabaseQuery.IteratorAndRelease(self)
	self:_Execute()
	local clausesClosed = self._rootClause and self._currentClause == self._rootClause
	assert(clausesClosed, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	self._iteratorState = "IN_PROGRESS"
	-- flag the query to be released once iteration completes
	self._autoRelease = true
	return private.QueryResultIterator, self, 0
end

--- Check if the abortable iterator has been aborted.
-- Must only be called while an abortable iterator is running; a pending abort is consumed by this call.
-- @tparam DatabaseQuery self The database query object
-- @treturn boolean Whether or not the iterator has been aborted
function DatabaseQuery.IsIteratorAborted(self)
	local state = self._iteratorState
	if state == "IN_PROGRESS_CAN_ABORT" then
		return false
	end
	if state == "PENDING_ABORT" then
		self._iteratorState = "ABORTED"
		return true
	end
	error("Invalid iterator state: "..tostring(state))
end

--- Populates a table with the results.
-- The query must have a select clause with exactly one or two fields. With one field, the table will be populated
-- as a list. With two fields, the first field must be unique in the results and is used as the key for the table,
-- with the second field being the value.
-- @tparam DatabaseQuery self The database query object
-- @tparam table tbl The table to store the result in
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.AsTable(self, tbl)
	self:_Execute()
	local numSelected = #self._select
	if numSelected == 1 then
		-- list mode: append the selected field of every row
		local field = self._select[1]
		for _, uuid in ipairs(self._result) do
			tinsert(tbl, self:_GetResultRowData(uuid, field))
		end
	elseif numSelected == 2 then
		-- map mode: the first selected field is the key and the second is the value
		local keyField, valueField = self._select[1], self._select[2]
		for _, uuid in ipairs(self._result) do
			local key = self:_GetResultRowData(uuid, keyField)
			-- compare against nil rather than truthiness so that a key already holding `false` is still
			-- detected as a duplicate
			if key == nil or tbl[key] ~= nil then
				error("Key is nil or not distinct")
			end
			tbl[key] = self:_GetResultRowData(uuid, valueField)
		end
	else
		error("Invalid select clause")
	end
	return self
end

--- Get the number of resulting rows.
-- @tparam DatabaseQuery self The database query object
-- @treturn number The number of rows
function DatabaseQuery.Count(self)
	self:_Execute()
	return #self._result
end

--- Get the number of resulting rows and release.
-- @tparam DatabaseQuery self The database query object
-- @treturn number The number of rows
function DatabaseQuery.CountAndRelease(self)
	local count = self:Count()
	self:Release()
	return count
end

--- Get a single result.
-- This method will assert that there is exactly one result from the query and return it.
-- @tparam DatabaseQuery self The database query object
-- @return The result row or the selected fields
function DatabaseQuery.GetSingleResult(self)
	self:_Execute()
	assert(self:Count() == 1)
	return self:GetFirstResult()
end

--- Get a single result and release.
-- This method will assert that there is exactly one result from the query and return it. The query must have a
-- select clause; every selected field is returned.
-- @tparam DatabaseQuery self The database query object
-- @return The selected fields of the single result
function DatabaseQuery.GetSingleResultAndRelease(self)
	-- a select clause is required (presumably because a result row object would not outlive the release)
	assert(#self._select > 0)
	-- forward all selected values through the release; storing the result in a single local would silently drop
	-- every selected field after the first
	return self:_PassThroughReleaseHelper(self:GetSingleResult())
end

--- Get the first result.
-- Note that this method internally iterates over all the results.
-- @tparam DatabaseQuery self The database query object
-- @return The result row or the selected fields (nothing if there are no results)
function DatabaseQuery.GetFirstResult(self)
	self:_Execute()
	assert(self._iteratorState == "IDLE")
	if self:Count() == 0 then
		return
	end
	local firstUUID = self._result[1]
	if not self._resultRowLookup[firstUUID] then
		-- create the result row object on demand
		self:_CreateResultRow(firstUUID)
	end
	local row = self._resultRowLookup[firstUUID]
	if #self._select == 0 then
		return row
	end
	return row:GetFields(unpack(self._select))
end

--- Get the first result and release.
-- Note that this method internally iterates over all the results.
-- @tparam DatabaseQuery self The database query object
-- @return The selected fields, or a clone of the result row if nothing is selected (nothing if there are no results)
function DatabaseQuery.GetFirstResultAndRelease(self)
	self:_Execute()
	assert(self._iteratorState == "IDLE")
	if self:Count() == 0 then
		self:Release()
		return
	end
	local firstUUID = self._result[1]
	if not self._resultRowLookup[firstUUID] then
		self:_CreateResultRow(firstUUID)
	end
	local row = self._resultRowLookup[firstUUID]
	if #self._select > 0 then
		-- the helper releases the query and passes every selected value through
		return self:_PassThroughReleaseHelper(row:GetFields(unpack(self._select)))
	end
	-- clone before releasing (presumably the original row is recycled along with the query)
	local rowCopy = row:Clone()
	self:Release()
	return rowCopy
end

--- Gets the minimum value of a specific field within the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The minimum value or nil if there are no results
function DatabaseQuery.Min(self, field)
	self:_Execute()
	local lowest = nil
	for _, uuid in ipairs(self._result) do
		lowest = min(lowest or math.huge, self:_GetResultRowData(uuid, field))
	end
	return lowest
end

--- Gets the maximum value of a specific field within the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The maximum value or nil if there are no results
function DatabaseQuery.Max(self, field)
	self:_Execute()
	local result = nil
	for _, uuid in ipairs(self._result) do
		local value = self:_GetResultRowData(uuid, field)
		result = max(result or -math.huge, value)
	end
	return result
end

--- Gets the summed value of a specific field within the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn number The summed value (0 if there are no results)
function DatabaseQuery.Sum(self, field)
	self:_Execute()
	local result = 0
	for _, uuid in ipairs(self._result) do
		result = result + self:_GetResultRowData(uuid, field)
	end
	return result
end

--- Gets the summed value of a specific field for each group within the query results.
-- Sums are accumulated into `result` keyed by the value of `groupField`; existing entries are added to.
-- @tparam DatabaseQuery self The database query object
-- @tparam string groupField The field to group by
-- @tparam string sumField The field to sum
-- @tparam table result The table to accumulate the per-group sums into
function DatabaseQuery.GroupedSum(self, groupField, sumField, result)
	self:_Execute()
	for _, uuid in ipairs(self._result) do
		local group = self:_GetResultRowData(uuid, groupField)
		local value = self:_GetResultRowData(uuid, sumField)
		result[group] = (result[group] or 0) + value
	end
end

--- Gets the summed value of a specific field within the query results and releases the query.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn number The summed value (0 if there are no results)
function DatabaseQuery.SumAndRelease(self, field)
	-- reuse Sum() so the two methods cannot drift apart
	local result = self:Sum(field)
	self:Release()
	return result
end

--- Gets the average value of a specific field within the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @treturn ?number The average value or nil if there are no results
function DatabaseQuery.Avg(self, field)
	local total = self:Sum(field)
	local numRows = self:Count()
	if numRows > 0 then
		return total / numRows
	end
	return nil
end

--- Gets the sum of the products of two fields within the query results.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field1 The first field within the results
-- @tparam string field2 The second field within the results
-- @treturn number The summed value (0 if there are no results)
function DatabaseQuery.SumOfProduct(self, field1, field2)
	self:_Execute()
	local total = 0
	for _, uuid in ipairs(self._result) do
		local lhs = self:_GetResultRowData(uuid, field1)
		local rhs = self:_GetResultRowData(uuid, field2)
		total = total + lhs * rhs
	end
	return total
end

--- Joins the string values of a field with a given separator.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The field within the results
-- @tparam string sep The separator (can be any number of characters, including an empty string)
-- @treturn string The joined string
function DatabaseQuery.JoinedString(self, field, sep)
	self:_Execute()
	local pieces = TempTable.Acquire()
	for _, uuid in ipairs(self._result) do
		tinsert(pieces, self:_GetResultRowData(uuid, field))
	end
	local joined = table.concat(pieces, sep)
	TempTable.Release(pieces)
	return joined
end

--- Calculates the hash of the query results.
-- Note that either `fields` must be specified or the query must have a select column with at most 2 fields.
-- @tparam DatabaseQuery self The database query object
-- @tparam[opt=nil] table fields The fields from each row to hash (otherwise uses the selected fields)
-- @treturn ?number The hash value or nil if there are no results
function DatabaseQuery.Hash(self, fields)
	self:_Execute()
	if fields then
		-- hash every requested field of every row, in result order
		local hash = nil
		for _, uuid in ipairs(self._result) do
			for _, field in ipairs(fields) do
				hash = Math.CalculateHash(self:_GetResultRowData(uuid, field), hash)
			end
		end
		return hash
	end

	local keyField, valueField, extraField = unpack(self._select)
	assert(keyField and not extraField)
	-- collect the selected values, then sort them so the hash doesn't depend on result order
	-- NOTE(review): keys and values are sorted together in one list, so the hash does not capture which value
	-- belongs to which key
	local sortedValues = TempTable.Acquire()
	for _, uuid in ipairs(self._result) do
		tinsert(sortedValues, self:_GetResultRowData(uuid, keyField))
		if valueField then
			tinsert(sortedValues, self:_GetResultRowData(uuid, valueField))
		end
	end
	Table.Sort(sortedValues)
	local hash = nil
	for _, value in ipairs(sortedValues) do
		hash = Math.CalculateHash(value, hash)
	end
	TempTable.Release(sortedValues)
	return hash
end

--- Calculates the hash of the query results, grouping by a field.
-- Each row's hash is folded into `result[groupValue]`, so existing entries are extended rather than replaced.
-- @tparam DatabaseQuery self The database query object
-- @tparam table fields The fields from each row to hash
-- @tparam string groupField The field to group by
-- @tparam table result The result table
function DatabaseQuery.GroupedHash(self, fields, groupField, result)
	self:_Execute()
	local numFields = #fields
	for rowIndex = 1, #self._result do
		local uuid = self._result[rowIndex]
		local groupValue = self:_GetResultRowData(uuid, groupField)
		local rowHash = nil
		for fieldIndex = 1, numFields do
			rowHash = Math.CalculateHash(self:_GetResultRowData(uuid, fields[fieldIndex]), rowHash)
		end
		result[groupValue] = Math.CalculateHash(rowHash, result[groupValue])
	end
end

--- Calculates the hash of the query results and release.
-- Note that either `fields` must be specified or the query must have a select column with at most 2 fields.
-- @tparam DatabaseQuery self The database query object
-- @tparam[opt=nil] table fields The fields from each row to hash (otherwise uses the selected fields)
-- @treturn ?number The hash value or nil if there are no results
function DatabaseQuery.HashAndRelease(self, fields)
	-- the hash is computed first, then the query is released and the hash passed through
	return self:_PassThroughReleaseHelper(self:Hash(fields))
end

--- Deletes all the result rows from the database and releases the query.
-- @tparam DatabaseQuery self The database query object
-- @treturn number The number of rows deleted (the value of `:Count()` before the delete)
function DatabaseQuery.DeleteAndRelease(self)
	local numDeleted = self:Count()
	self._db:BulkDelete(self._result)
	return self:_PassThroughReleaseHelper(numDeleted)
end

--- Resets the database query.
-- Clears the distinct, select, order by, join, filter and virtual field state and wipes any cached results.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.Reset(self)
	-- each of these reset methods returns the query, so they can be chained
	self:ResetDistinct()
		:ResetSelect()
		:ResetOrderBy()
		:ResetJoins()
		:ResetFilters()
		:ResetVirtualFields()
	self:_WipeResults()
	self._resultIsStale = true
	return self
end

--- Resets any virtual fields added to the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetVirtualFields(self)
	-- the virtual field definitions are spread across four parallel lookup tables
	wipe(self._virtualFieldFunc)
	wipe(self._virtualFieldArgField)
	wipe(self._virtualFieldType)
	wipe(self._virtualFieldDefault)
	self._resultIsStale = true
	return self
end

--- Resets any filtering clauses of the database query.
-- The existing root clause is released and replaced by a fresh, empty AND clause.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetFilters(self)
	self._rootClause:_Release()
	local newRoot = QueryClause.Get(self):And()
	self._rootClause = newRoot
	self._currentClause = newRoot
	self._resultIsStale = true
	return self
end

--- Resets any ordering clauses of the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetOrderBy(self)
	-- field names and their ascending flags are stored in parallel arrays
	wipe(self._orderBy)
	wipe(self._orderByAscending)
	self._resultIsStale = true
	return self
end

--- Resets any joins of the database query.
-- Unregisters this query from every joined DB and releases any aggregate join queries.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetJoins(self)
	local joinDBs = self._joinDBs
	for i = 1, #joinDBs do
		joinDBs[i]:_RemoveQuery(self)
	end
	wipe(self._joinTypes)
	wipe(joinDBs)
	wipe(self._joinFields)
	wipe(self._aggregateJoinFields)
	local aggregateQueries = self._aggregateJoinQueries
	for i = 1, #aggregateQueries do
		local aggregateQuery = aggregateQueries[i]
		-- entries are `false` for joins which have no aggregate query
		if aggregateQuery then
			aggregateQuery:Release()
		end
	end
	wipe(aggregateQueries)
	self._resultIsStale = true
	return self
end

--- Resets any distinct clauses of the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetDistinct(self)
	self._distinct = nil
	self._resultIsStale = true
	return self
end

--- Resets any select clauses of the database query.
-- @tparam DatabaseQuery self The database query object
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.ResetSelect(self)
	wipe(self._select)
	self._resultIsStale = true
	return self
end

--- Gets info on a specific order by clause.
-- Raises an assertion error if there is no order by clause at `index`.
-- @tparam DatabaseQuery self The database query object
-- @tparam number index The index of the order by clause
-- @treturn ?string The field name
-- @treturn ?boolean Whether or not the sort is ascending
function DatabaseQuery.GetOrderBy(self, index)
	local orderByField = self._orderBy[index]
	assert(orderByField)
	return orderByField, self._orderByAscending[index]
end

--- Gets info on the last order by clause.
-- @tparam DatabaseQuery self The database query object
-- @treturn ?string The field name
-- @treturn ?boolean Whether or not the sort is ascending
function DatabaseQuery.GetLastOrderBy(self)
	return self._orderBy[#self._orderBy], self._orderByAscending[#self._orderByAscending]
end

--- Updates the last order by clause.
-- @tparam DatabaseQuery self The database query object
-- @tparam string field The name of the field to order by
-- @tparam boolean ascending Whether to order in ascending order (descending otherwise)
-- @treturn DatabaseQuery The database query object
function DatabaseQuery.UpdateLastOrderBy(self, field, ascending)
	assert(#self._orderBy > 0)
	-- drop the existing last clause and append the replacement via OrderBy()
	tremove(self._orderBy)
	tremove(self._orderByAscending)
	self:OrderBy(field, ascending)
	return self
end

--- Get a result row by its UUID.
-- The row object is created on first access and cached in the result row lookup. The UUID must be part of the
-- current result (`_CreateResultRow()` asserts this).
-- @tparam DatabaseQuery self The database query object
-- @tparam number uuid The UUID of the row to get
-- @return QueryResultRow The result row
function DatabaseQuery.GetResultRowByUUID(self, uuid)
	if not self._resultRowLookup[uuid] then
		self:_CreateResultRow(uuid)
	end
	return self._resultRowLookup[uuid]
end



-- ============================================================================
-- Private Class Methods
-- ============================================================================

-- Returns the join type of the first joined DB which contains the given field (or nil if none do).
function DatabaseQuery._GetJoinType(self, field)
	for i, db in ipairs(self._joinDBs) do
		if db:_GetFieldType(field) then
			return self._joinTypes[i]
		end
	end
end

-- Returns the type of a field, checking virtual fields, then the local DB, then each joined DB in order. The aggregate
-- field of an AGGREGATE_SUM join is reported as a number. Returns nil if the field isn't found anywhere.
function DatabaseQuery._GetFieldType(self, field)
	local fieldType = self._virtualFieldType[field] or self._db:_GetFieldType(field)
	if fieldType then
		return fieldType
	end
	for i, db in ipairs(self._joinDBs) do
		if field == self._aggregateJoinFields[i] then
			if self._joinTypes[i] == "AGGREGATE_SUM" then
				fieldType = "number"
			else
				error("Unknown aggregate join type: "..tostring(self._joinTypes[i]))
			end
		else
			fieldType = db:_GetFieldType(field)
		end
		if fieldType then
			return fieldType
		end
	end
end

-- Handles a change to the data backing this query (presumably invoked by the DBs this query is registered with).
-- If the result depends on all fields, no changed fields are given, or any dependency changed, the result is marked
-- stale. Otherwise only the cached values of the changed fields are cleared from existing result rows. Unless the
-- result was already stale, an IN_PROGRESS_CAN_ABORT iteration is moved to PENDING_ABORT.
function DatabaseQuery._MarkResultStale(self, changedFields)
	assert(self._iteratorState == "IDLE" or self._iteratorState == "IN_PROGRESS_CAN_ABORT" or self._iteratorState == "PENDING_ABORT")
	if self._resultIsStale then
		-- already marked stale
		return
	end
	if self._resultDependencies._all or not changedFields then
		-- either the result depends on all fields or we weren't given a table of changed fields
		self._resultIsStale = true
		if self._iteratorState == "IN_PROGRESS_CAN_ABORT" then
			self._iteratorState = "PENDING_ABORT"
		end
		return
	end
	-- check if any of the fields our result is based on changed
	for field in pairs(changedFields) do
		if self._resultDependencies[field] then
			self._resultIsStale = true
			if self._iteratorState == "IN_PROGRESS_CAN_ABORT" then
				self._iteratorState = "PENDING_ABORT"
			end
			return
		end
	end
	-- clear the cached values for the changed fields
	for _, row in pairs(self._resultRowLookup) do
		if row ~= false then
			for field in pairs(changedFields) do
				rawset(row, field, nil)
			end
		end
	end
	if self._iteratorState == "IN_PROGRESS_CAN_ABORT" then
		self._iteratorState = "PENDING_ABORT"
	end
end

-- Invokes the registered update callback, or just records a queued update while updates are paused. The callback
-- receives the affected local UUID when one can be determined, and nil otherwise (stale result, no UUID, or a joined-DB
-- UUID which can't be translated to exactly one local row).
function DatabaseQuery._DoUpdateCallback(self, uuid)
	if not self._updateCallback then
		assert(self._iteratorState == "IDLE" or self._iteratorState == "PENDING_ABORT")
		return
	end
	-- can't have an update callback on an abortable iterator
	assert(self._iteratorState == "IDLE")
	if self._updatesPaused > 0 then
		self._queuedUpdate = true
	else
		self._queuedUpdate = false
		if self._resultIsStale or not uuid then
			self:_updateCallback(nil, self._updateCallbackContext)
		elseif self._db:_ContainsUUID(uuid) then
			self:_updateCallback(uuid, self._updateCallbackContext)
		else
			-- the UUID is from a joined DB, so see if we can easily translate it to a local UUID
			local localUUID = nil
			for i = 1, #self._joinDBs do
				local joinDB = self._joinDBs[i]
				if not self._aggregateJoinFields[i] and joinDB:_ContainsUUID(uuid) then
					if localUUID then
						-- found more than once, so bail
						localUUID = nil
						break
					end
					local joinField = self._joinFields[i]
					local joinValue = joinDB:GetRowFieldByUUID(uuid, joinField)
					if self._db:_IsUnique(joinField) then
						localUUID = self._db:_GetUniqueRow(joinField, joinValue)
					elseif self._db:_IsIndex(joinField) then
						local lowIndex, highIndex = self._db:_GetIndexListMatchingIndexRange(joinField, Util.ToIndexValue(joinValue))
						if not lowIndex or not highIndex or lowIndex ~= highIndex then
							-- can't use this index to find a single local UUID
							break
						end
						localUUID = self._db:_GetAllRowsByIndex(joinField)[lowIndex]
					end
				end
			end
			self:_updateCallback(localUUID, self._updateCallbackContext)
		end
	end
end

-- Creates a new clause, inserts it as a sub-clause of the current clause, and marks the result stale.
function DatabaseQuery._NewClause(self)
	self._resultIsStale = true
	local newClause = QueryClause.Get(self, self._currentClause)
	self._currentClause:_InsertSubClause(newClause)
	return newClause
end

-- Releases any result row objects which were created and clears the result list and the row lookup.
function DatabaseQuery._WipeResults(self)
	for _, row in pairs(self._resultRowLookup) do
		if row ~= false then
			row:Release()
		end
	end
	wipe(self._result)
	wipe(self._resultRowLookup)
end

-- Re-runs the query if the result is stale (or `force` is set). It collects the matching UUIDs into `self._result`,
-- using an index where one applies, then sorts them by the order by clauses and records which fields the result
-- depends on.
function DatabaseQuery._Execute(self, force)
	if not self._resultIsStale and not force then
		return
	end
	assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	assert(not next(self._iterDistinctUsed))

	-- clear the current result
	self:_WipeResults()

	-- get all the rows which we need to iterate over
	local firstOrderBy = self._orderBy[1]
	local skipFirstOrderBy = false
	local sortNeeded = firstOrderBy and true or false
	local indexType, indexField, indexArg1, indexArg2, indexArg3 = self:_GetQueryIndexInfo()
	self._result._queryOptimizationResult = indexType
	self._result._queryOptimizationField = indexField
	if indexType == "EMPTY" then
		sortNeeded = false
	elseif indexType == "UNIQUE" then
		-- we are looking for a unique row
		local indexValue = indexArg1
		local uuid = self._db:_GetUniqueRow(indexField, indexValue)
		if uuid and self:_ResultShouldIncludeRow(uuid, false, #self._joinDBs, self._distinct) then
			tinsert(self._result, uuid)
			self._resultRowLookup[uuid] = false
		end
		sortNeeded = false
	elseif indexType == "INDEX" then
		-- we're querying on an index, so use that index to populate the result
		local firstIndex, lastIndex, isStrict = indexArg1, indexArg2, indexArg3
		local isAscending = true
		if firstOrderBy and indexField == firstOrderBy then
			-- we're also ordering by this field so can skip the first OrderBy field
			self._result._queryOptimizationResult = "INDEX_AND_ORDER_BY"
			skipFirstOrderBy = true
			sortNeeded = #self._orderBy > 1
			isAscending = self._orderByAscending[1]
		end
		local indexList = self._db:_GetAllRowsByIndex(indexField)
		self:_AddResultRowsFromIndex(indexList, isStrict, firstIndex, lastIndex, isAscending, indexField)
	elseif indexType == "NONE" then
		if firstOrderBy and self._db:_IsIndex(firstOrderBy) then
			-- we're ordering on an index, so use that index to iterate through all the rows in order to skip the first OrderBy field
			self._result._queryOptimizationResult = "ORDER_BY"
			self._result._queryOptimizationField = firstOrderBy
			skipFirstOrderBy = true
			sortNeeded = #self._orderBy > 1
			local isAscending = self._orderByAscending[1]
			local indexList = self._db:_GetAllRowsByIndex(firstOrderBy)
			self:_AddResultRowsFromIndex(indexList, false, 1, #indexList, isAscending)
		else
			-- no optimizations
			self:_AddResultRowsCheckAll()
		end
	elseif indexType == "TRIGRAM" then
		local indexValue = indexArg1
		local uuids = TempTable.Acquire()
		self._db:_GetTrigramIndexMatchingRows(indexValue, uuids)
		self:_AddResultRowsFromIndex(uuids, false, 1, #uuids, true)
		TempTable.Release(uuids)
	else
		error("Invalid index type: "..tostring(indexType))
	end
	wipe(self._iterDistinctUsed)

	-- sort the results if necessary
	if sortNeeded then
		if #self._orderBy == 1 then
			assert(not skipFirstOrderBy)
			assert(not next(self._sortValueCache))
			for _, uuid in ipairs(self._result) do
				self._sortValueCache[uuid] = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[1]))
			end
			Table.Sort(self._result, self._singleSortWrapper)
			wipe(self._sortValueCache)
		elseif skipFirstOrderBy and #self._orderBy == 2 then
			-- the result is already ordered by the first orderBy field, so iterate through it
			-- and sort each group of results where the first orderBy field is the same
			assert(not next(self._sortValueCache))
			local group = TempTable.Acquire()
			local subsetLen = 0
			local currentSortValue = nil
			for i = 1, #self._result do
				local uuid = self._result[i]
				local sortValue = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[1]))
				self._sortValueCache[uuid] = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[2]))
				if sortValue ~= currentSortValue then
					-- the first sort value changed, so we're now in a new group
					if subsetLen > 1 then
						-- sort the previous group
						Table.Sort(group, self._secondarySortWrapper)
						-- update the corresponding results (the previous group occupies indexes i - subsetLen to i - 1)
						local offset = i - subsetLen - 1
						for j = 1, subsetLen do
							self._result[offset + j] = group[j]
						end
					end
					subsetLen = 0
					wipe(group)
					currentSortValue = sortValue
				end
				subsetLen = subsetLen + 1
				group[subsetLen] = uuid
			end
			if subsetLen > 1 then
				-- sort the last group
				Table.Sort(group, self._secondarySortWrapper)
				-- update the corresponding results (the last group occupies the final subsetLen indexes)
				local offset = #self._result - subsetLen
				for i = 1, subsetLen do
					self._result[offset + i] = group[i]
				end
			end
			TempTable.Release(group)
			wipe(self._sortValueCache)
		else
			Table.Sort(self._result, self._genericSortWrapper)
		end
	end

	-- update the dependencies
	wipe(self._resultDependencies)
	if next(self._virtualFieldFunc) then
		self._resultDependencies._all = true
	else
		for i = 1, #self._joinFields do
			self._resultDependencies[self._joinFields[i]] = true
		end
		for i = 1, #self._orderBy do
			self._resultDependencies[self._orderBy[i]] = true
		end
		if self._distinct then
			self._resultDependencies[self._distinct] = true
		end
		for i = 1, #self._select do
			self._resultDependencies[self._select[i]] = true
		end
		for field in self._db:FieldIterator() do
			if self._rootClause:_UsesField(field) then
				self._resultDependencies[field] = true
			end
		end
	end

	self._resultIsStale = false
end

-- Picks how to look up the candidate rows for this query. Returns one of:
--   "UNIQUE", field, value - a unique index is constrained to a single value
--   "EMPTY", field - an index shows that no rows can match
--   "INDEX", field, firstIndex, lastIndex, isStrict - the index range with the fewest (estimated) candidate rows
--   "TRIGRAM", field, value - the trigram index can be used
--   "NONE" - no index can be used
function DatabaseQuery._GetQueryIndexInfo(self)
	-- try to find the index with the least result rows
	local indexField, indexFirstIndex, indexLastIndex, indexIsStrict = nil, nil, nil, false
	local bestIndexDiff = math.huge
	for _, field in ipairs(self._db:_GetIndexAndUniqueList()) do
		local valueMin, valueMax = self:_IndexValueHelper(strsplit(Constants.DB_INDEX_FIELD_SEP, field))
		if valueMin == nil and valueMax == nil then
			-- continue
		elseif self._db:_IsUnique(field) and valueMin == valueMax then
			-- unique indexes result in a single row, at which point the benefit of trying to find something better (EMPTY) is negligible
			return "UNIQUE", field, valueMin
		elseif self._db:_IsIndex(field) then
			-- check how many rows this index results in
			local indexList = self._db:_GetAllRowsByIndex(field)
			local firstIndex = valueMin and self._db:_IndexListBinarySearch(field, valueMin, true) or min(1, #indexList)
			local lastIndex = valueMax and self._db:_IndexListBinarySearch(field, valueMax, false) or #indexList
			local indexDiff = lastIndex - firstIndex
			if indexDiff < 0 then
				-- there are no results within this index, so this is as good as it gets
				return "EMPTY", field
			else
				-- NOTE: string indexes can't be strict since they are case-insensitive
				local isStrict = type(valueMin) ~= "string" and type(valueMax) ~= "string" and self._rootClause:_IsStrictIndex(field, valueMin, valueMax)
				if isStrict then
					-- rough estimate that being able to skip the query makes each row cost 1/4 as much
					indexDiff = floor(indexDiff / 4)
				end
				if indexDiff < bestIndexDiff then
					-- this is our new best index
					indexField = field
					indexFirstIndex = firstIndex
					indexLastIndex = lastIndex
					indexIsStrict = isStrict
					bestIndexDiff = indexDiff
				end
			end
		end
	end
	if indexField then
		return "INDEX", indexField, indexFirstIndex, indexLastIndex, indexIsStrict
	end

	-- try the trigram index
	local trigramIndexField = self._db:_GetTrigramIndexField()
	if trigramIndexField then
		local trigramIndexValue = self._rootClause:_GetTrigramIndexValue(trigramIndexField)
		if trigramIndexValue then
			return "TRIGRAM", trigramIndexField, trigramIndexValue
		end
	end

	return "NONE"
end

-- Appends the UUIDs at positions firstIndex..lastIndex of `indexList` (walked from lastIndex down if not ascending)
-- to the result, keeping only rows which pass `_ResultShouldIncludeRow()`. When `skipQuery` is set and there are no
-- joins or distinct clause, every row in the range is added without further checks.
function DatabaseQuery._AddResultRowsFromIndex(self, indexList, skipQuery, firstIndex, lastIndex, isAscending, indexField)
	local numJoinDBs = #self._joinDBs
	local distinct = self._distinct
	local result = self._result
	local resultIndex = #self._result + 1
	local resultRowLookup = self._resultRowLookup
	for i = isAscending and firstIndex or lastIndex, isAscending and lastIndex or firstIndex, isAscending and 1 or -1 do
		local uuid = indexList[i]
		if skipQuery and numJoinDBs == 0 and not distinct then
			-- fast path where there's no further filtering so we add all rows
			result[resultIndex] = uuid
			resultIndex = resultIndex + 1
			resultRowLookup[uuid] = false
		elseif self:_ResultShouldIncludeRow(uuid, skipQuery, numJoinDBs, distinct, indexField) then
			result[resultIndex] = uuid
			resultIndex = resultIndex + 1
			resultRowLookup[uuid] = false
		end
	end
end

-- Appends every row of the DB which passes `_ResultShouldIncludeRow()` to the result (no index optimization).
function DatabaseQuery._AddResultRowsCheckAll(self)
	local numJoinDBs = #self._joinDBs
	local distinct = self._distinct
	local result = self._result
	local resultIndex = #self._result + 1
	local resultRowLookup = self._resultRowLookup
	for _, uuid in self._db:_UUIDIterator() do
		if self:_ResultShouldIncludeRow(uuid, false, numJoinDBs, distinct) then
			result[resultIndex] = uuid
			resultIndex = resultIndex + 1
			resultRowLookup[uuid] = false
		end
	end
end

-- Returns whether a row belongs in the result. The row must have a matching row in every INNER-joined DB and, unless
-- `skipQuery` is set, satisfy the root clause. With a distinct field set, only the first row seen for each distinct
-- value is included. `ignoreField` is passed through to the root clause's `_IsTrue()`.
function DatabaseQuery._ResultShouldIncludeRow(self, uuid, skipQuery, numJoinDBs, distinct, ignoreField)
	for i = 1, numJoinDBs do
		if self._joinTypes[i] == "INNER" then
			local joinField = self._joinFields[i]
			if not self._joinDBs[i]:_GetUniqueRow(joinField, self._db:GetRowFieldByUUID(uuid, joinField)) then
				return false
			end
		end
	end
	if not skipQuery then
		self._tempResultRow:_SetUUID(uuid)
		if not self._rootClause:_IsTrue(self._tempResultRow, ignoreField) then
			return false
		end
	end
	if distinct then
		local distinctValue = self:_GetResultRowData(uuid, distinct)
		if self._iterDistinctUsed[distinctValue] then
			return false
		end
		self._iterDistinctUsed[distinctValue] = true
	end
	return true
end

-- Creates and caches the result row object for a UUID which is in the result but has no row object yet.
function DatabaseQuery._CreateResultRow(self, uuid)
	assert(self._resultRowLookup[uuid] == false)
	local row = QueryResultRow.Get()
	row:_Acquire(self._db, self)
	row:_SetUUID(uuid)
	self._resultRowLookup[uuid] = row
	return row
end

-- Gets the min/max index values which the root clause imposes on an index over the given field(s). For a
-- multi-field index, every field must have both a min and a max, and the per-field values are joined with
-- Constants.DB_INDEX_VALUE_SEP. Returns nothing if the index can't be used.
function DatabaseQuery._IndexValueHelper(self, ...)
	local num = select("#", ...)
	local valueMin, valueMax = nil, nil
	for i = 1, num do
		local fieldPart = select(i, ...)
		local partValueMin, partValueMax = self._rootClause:_GetIndexValue(fieldPart)
		if partValueMin == nil and partValueMax == nil then
			return
		end
		if num > 1 and (partValueMin == nil or partValueMax == nil) then
			-- only use multi-field indexes if there's both a min and max value
			return
		end
		if i > 1 then
			valueMin = valueMin .. Constants.DB_INDEX_VALUE_SEP .. partValueMin
			valueMax = valueMax .. Constants.DB_INDEX_VALUE_SEP .. partValueMax
		else
			valueMin = partValueMin
			valueMax = partValueMax
		end
	end
	return valueMin, valueMax
end

-- Releases the query and returns all of its extra arguments unchanged.
function DatabaseQuery._PassThroughReleaseHelper(self, ...)
	self:Release()
	return ...
end

-- Gets the value of a field for a result row. This covers virtual fields (computed from another field or from the row
-- itself), local fields, and fields of joined DBs, including the aggregate field of an AGGREGATE_SUM join. For an
-- INNER/LEFT join with no matching foreign row, nothing is returned.
function DatabaseQuery._GetResultRowData(self, uuid, field)
	if self._virtualFieldFunc[field] then
		local argField = self._virtualFieldArgField[field]
		local argValue = nil
		if argField then
			argValue = self:_GetResultRowData(uuid, argField)
		else
			-- no argument field, so the virtual field function is passed a (reused) result row for this UUID
			if not self._tempVirtualResultRow then
				self._tempVirtualResultRow = QueryResultRow.Get()
				self._tempVirtualResultRow:_Acquire(self._db, self)
			end
			self._tempVirtualResultRow:_SetUUID(uuid)
			argValue = self._tempVirtualResultRow
		end
		local value = self._virtualFieldFunc[field](argValue)
		if value == nil then
			value = self._virtualFieldDefault[field]
		end
		if type(value) ~= self._virtualFieldType[field] then
			error(format("Virtual field value not the correct type (%s, %s, %s)", tostring(argValue), tostring(value), field))
		end
		return value
	elseif #self._joinDBs == 0 or self._db:_GetFieldType(field) then
		-- this is a local field
		return self._db:GetRowFieldByUUID(uuid, field)
	else
		-- this is a foreign field
		local joinDB, joinField, joinType, aggregateJoinField, aggregateJoinQuery = nil, nil, nil, nil, nil
		for i = 1, #self._joinDBs do
			local testDB = self._joinDBs[i]
			local testAggregateJoinField = self._aggregateJoinFields[i]
			if field == testAggregateJoinField or (not testAggregateJoinField and testDB:_GetFieldType(field)) then
				if joinDB then
					error("Multiple joined DBs have this field", 2)
				end
				joinDB = testDB
				joinField = self._joinFields[i]
				joinType = self._joinTypes[i]
				aggregateJoinField = testAggregateJoinField
				aggregateJoinQuery = self._aggregateJoinQueries[i]
			end
		end
		if not joinDB then
			error("Invalid field: "..tostring(field), 2)
		end
		-- NOTE: the error messages below must be built with `..`; using `+` on these strings would attempt arithmetic
		-- and raise an unrelated "attempt to perform arithmetic on a string value" error instead
		if joinType == "AGGREGATE_SUM" then
			if not aggregateJoinField or not aggregateJoinQuery then
				error("Missing aggregate join context: "..tostring(aggregateJoinField)..", "..tostring(aggregateJoinQuery))
			end
			aggregateJoinQuery:BindParams(self:_GetResultRowData(uuid, joinField))
			return aggregateJoinQuery:Sum(aggregateJoinField)
		elseif joinType == "INNER" or joinType == "LEFT" then
			if aggregateJoinField or aggregateJoinQuery then
				error("Unexpected aggregate join context: "..tostring(aggregateJoinField)..", "..tostring(aggregateJoinQuery))
			end
			local foreignUUID = joinDB:_GetUniqueRow(joinField, self:_GetResultRowData(uuid, joinField))
			if foreignUUID then
				return joinDB:GetRowFieldByUUID(foreignUUID, field)
			end
		else
			error("Unknown join type: "..tostring(joinType))
		end
	end
end

-- Validates and registers a join with another DB on `field`. For aggregate joins, `aggregateField` and `aggregateQuery`
-- are required. Otherwise `field` must be unique in the foreign DB, and no foreign field or virtual field may conflict
-- with the other side. This query registers itself with the joined DB and the join is appended to the parallel join
-- arrays.
function DatabaseQuery._JoinHelper(self, db, field, joinType, aggregateField, aggregateQuery)
	assert(type(field) == "string")
	local localFieldType = self._virtualFieldType[field] or self._db:_GetFieldType(field)
	local foreignFieldType = db:_GetFieldType(field)
	assert(localFieldType, "Local field doesn't exist: "..tostring(field))
	assert(foreignFieldType, "Foreign field doesn't exist: "..tostring(field))
	assert(localFieldType == foreignFieldType, format("Field types don't match (%s, %s)", tostring(localFieldType), tostring(foreignFieldType)))
	assert(not Table.KeyByValue(self._joinDBs, db), "Already joining with this DB")
	if aggregateField then
		assert(type(aggregateField) == "string")
		assert(aggregateQuery.__class == DatabaseQuery)
		assert(strmatch(joinType, "^AGGREGATE_"))
		assert(not self._db:_GetFieldType(aggregateField), "Local DB contains aggregate field: "..tostring(aggregateField))
		assert(db:_GetFieldType(aggregateField), "Foreign DB does not contains aggregate field: "..tostring(aggregateField))
	else
		assert(db:_IsUnique(field), "Field must be unique in foreign DB")
		assert(not strmatch(joinType, "^AGGREGATE_"))
		assert(not aggregateQuery)
		for foreignField in db:FieldIterator() do
			if foreignField ~= field then
				assert(not self._db:_GetFieldType(foreignField), "Foreign field conflicts with local DB: "..tostring(foreignField))
			end
		end
		for virtualField in pairs(self._virtualFieldFunc) do
			if virtualField ~= field then
				assert(not db:_GetFieldType(virtualField), "Virtual field conflicts with foreign DB: "..tostring(virtualField))
			end
		end
	end
	db:_RegisterQuery(self)
	tinsert(self._joinTypes, joinType)
	tinsert(self._joinDBs, db)
	tinsert(self._joinFields, field)
	-- `false` placeholders keep these arrays aligned with the other join arrays
	tinsert(self._aggregateJoinFields, aggregateField or false)
	tinsert(self._aggregateJoinQueries, aggregateQuery or false)
	self._resultIsStale = true
end



-- ============================================================================
-- Private Helper Functions
-- ============================================================================

-- Sort comparator using the values precomputed in `_sortValueCache`. Nil values sort to the end, and ties are broken
-- by UUID so the order is deterministic.
function private.DatabaseQuerySortSingle(self, aUUID, bUUID, isAscending)
	local aValue = self._sortValueCache[aUUID]
	local bValue = self._sortValueCache[bUUID]
	if aValue == bValue then
		-- make the sort stable
		return aUUID > bUUID
	elseif aValue == nil then
		-- sort nil to the end
		return false
	elseif bValue == nil then
		-- sort nil to the end
		return true
	elseif isAscending then
		return aValue < bValue
	else
		return aValue > bValue
	end
end

-- Sort comparator which compares rows by each order by clause in turn. Nil values sort to the end, and ties on every
-- clause are broken by UUID.
function private.DatabaseQuerySortGeneric(self, aUUID, bUUID)
	for i = 1, #self._orderBy do
		local orderByField = self._orderBy[i]
		local aValue = Util.ToIndexValue(self:_GetResultRowData(aUUID, orderByField))
		local bValue = Util.ToIndexValue(self:_GetResultRowData(bUUID, orderByField))
		if aValue == bValue then
			-- continue looping
		elseif aValue == nil then
			-- sort nil to the end
			return false
		elseif bValue == nil then
			-- sort nil to the end
			return true
		elseif self._orderByAscending[i] then
			return aValue < bValue
		else
			return aValue > bValue
		end
	end
	-- make the sort stable
	return aUUID > bUUID
end

-- Stateless iterator yielding (index, uuid) over the result. At the end it returns the query to IDLE and releases it if
-- auto-release is set.
function private.QueryResultAsUUIDIterator(self, index)
	index = index + 1
	local uuid = self._result[index]
	if not uuid then
		assert(self._iteratorState == "IN_PROGRESS")
		self._iteratorState = "IDLE"
		if self._autoRelease then
			self:Release()
		end
		return
	end
	return index, uuid
end

-- Stateless iterator over the result. It yields (index, row) when there is no select clause, and otherwise
-- (index, selected field values...). An ABORTED iteration ends immediately. At the end it returns the query to IDLE
-- and releases it if auto-release is set.
function private.QueryResultIterator(self, index)
	index = index + 1
	local uuid = self._result[index]
	if self._iteratorState == "ABORTED" then
		uuid = nil
	elseif self._iteratorState ~= "IN_PROGRESS" and self._iteratorState ~= "IN_PROGRESS_CAN_ABORT" then
		error("Invalid iteratorState: "..tostring(self._iteratorState))
	end
	if not uuid then
		assert(self._iteratorState == "IN_PROGRESS" or self._iteratorState == "IN_PROGRESS_CAN_ABORT" or self._iteratorState == "ABORTED")
		self._iteratorState = "IDLE"
		if self._autoRelease then
			self:Release()
		end
		return
	end
	local numSelectFields = #self._select
	if numSelectFields == 0 then
		local row = self._resultRowLookup[uuid]
		if not row then
			row = self:_CreateResultRow(uuid)
		end
		return index, row
	elseif #self._joinDBs == 0 and numSelectFields <= 5 then
		-- as an optimization, we don't need to create a result row
		if numSelectFields == 1 then
			return index, self:_GetResultRowData(uuid, self._select[1])
		elseif numSelectFields == 2 then
			return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2])
		elseif numSelectFields == 3 then
			return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2]), self:_GetResultRowData(uuid, self._select[3])
		elseif numSelectFields == 4 then
			return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2]), self:_GetResultRowData(uuid, self._select[3]), self:_GetResultRowData(uuid, self._select[4])
		elseif numSelectFields == 5 then
			return index, self:_GetResultRowData(uuid, self._select[1]), self:_GetResultRowData(uuid, self._select[2]), self:_GetResultRowData(uuid, self._select[3]), self:_GetResultRowData(uuid, self._select[4]), self:_GetResultRowData(uuid, self._select[5])
		else
			error("Invalid numSelectFields: "..tostring(numSelectFields))
		end
	else
		local row = self._resultRowLookup[uuid]
		if not row then
			row = self:_CreateResultRow(uuid)
		end
		return index, row:GetFields(unpack(self._select))
	end
end