You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1831 lines
62 KiB

3 years ago
-- ------------------------------------------------------------------------------ --
-- TradeSkillMaster --
-- https://tradeskillmaster.com --
-- All Rights Reserved - Detailed license information included with addon. --
-- ------------------------------------------------------------------------------ --
local TSM = select(2, ...) ---@type TSM
local Query = TSM.Init("Util.DatabaseClasses.Query") ---@class Util.DatabaseClasses.Query
3 years ago
local Constants = TSM.Include("Util.DatabaseClasses.Constants")
local Util = TSM.Include("Util.DatabaseClasses.Util")
local QueryResultRow = TSM.Include("Util.DatabaseClasses.QueryResultRow")
local QueryClause = TSM.Include("Util.DatabaseClasses.QueryClause")
local ObjectPool = TSM.Include("Util.ObjectPool")
local Reactive = TSM.Include("Util.Reactive")
3 years ago
local TempTable = TSM.Include("Util.TempTable")
local Table = TSM.Include("Util.Table")
local Math = TSM.Include("Util.Math")
local LibTSMClass = TSM.Include("LibTSMClass")
local DatabaseQuery = LibTSMClass.DefineClass("DatabaseQuery") ---@class DatabaseQuery
3 years ago
-- Module-private state.
-- `objectPool` is created in Query:OnModuleLoad(). `uuidDiffContext` is a single shared scratch area used by
-- DatabaseQuery:UUIDDiffPrepare() / DatabaseQuery:UUIDDiffIterator(); `inUse` guards it so only one diff is prepared
-- at a time.
local private = {
	objectPool = nil,
	smartMapReaderContext = {},
	uuidDiffContext = {
		inUse = false,
		insert = {},
		remove = {},
		result = {},
		uuids = {},
	},
}
-- ============================================================================
-- Module Loading
-- ============================================================================
-- Once the module is loaded, create the pool which hands out and recycles DatabaseQuery objects.
Query:OnModuleLoad(function()
	local queryPool = ObjectPool.New("DATABASE_QUERIES", DatabaseQuery, 1)
	private.objectPool = queryPool
end)
-- ============================================================================
-- Module Functions
-- ============================================================================
---Gets a query object.
---
---The object comes from an object pool; call `:Release()` on it when done so it can be recycled.
---@param db DatabaseTable The database table to query
---@return DatabaseQuery @The new database query object
function Query.Get(db)
	local query = private.objectPool:Get()
	query:_Acquire(db)
	return query
end
-- ============================================================================
-- Class Meta Methods
-- ============================================================================
---Initializes the fields of a newly-constructed query object.
---
---Per-use state is set up in `_Acquire()` and reset in `_Release()` so the object can be recycled by the pool.
function DatabaseQuery:__init()
	self._db = nil
	self._rootClause = nil
	self._currentClause = nil
	self._orderBy = {}
	self._orderByAscending = {}
	self._distinct = nil
	self._updateCallback = nil
	self._updateCallbackContext = nil
	self._updatesPaused = 0
	self._queuedUpdate = false
	self._select = {}
	self._iteratorState = "IDLE"
	self._result = {}
	self._resultRowLookup = {}
	self._iterDistinctUsed = {}
	self._tempResultRow = nil
	self._tempVirtualResultRow = nil
	self._autoRelease = false
	self._resultIsStale = false
	self._joinTypes = {}
	self._joinDBs = {}
	self._joinFields = {}
	self._joinForeignFields = {}
	self._aggregateJoinQueries = {}
	self._virtualFieldFunc = {}
	self._virtualFieldArgField = {}
	self._virtualFieldType = {}
	self._virtualFieldDefault = {}
	-- Comparator closures are built once per object so they can capture `self`. The single/secondary variants read the
	-- ascending flag of the first/second OrderBy() at call time.
	self._genericSortWrapper = function(a, b)
		return private.DatabaseQuerySortGeneric(self, a, b)
	end
	self._singleSortWrapper = function(a, b)
		return private.DatabaseQuerySortSingle(self, a, b, self._orderByAscending[1])
	end
	self._secondarySortWrapper = function(a, b)
		return private.DatabaseQuerySortSingle(self, a, b, self._orderByAscending[2])
	end
	self._sortValueCache = {}
	self._resultDependencies = {}
	self._stream = Reactive.CreateStream()
end
---Prepares a pooled query object for use against a database table.
---@param db DatabaseTable The database table to query
function DatabaseQuery:_Acquire(db)
	self._db = db
	self._db:_RegisterQuery(self)
	-- implicit root AND clause
	self._rootClause = QueryClause.Get()
		:And()
	self._currentClause = self._rootClause
	self._tempResultRow = QueryResultRow.Get()
	self._tempResultRow:_Acquire(self._db, self)
	-- nothing has been executed yet
	self._resultIsStale = true
end
---Resets all per-use state so the object can be returned to the object pool.
---
---Must only be called when no iterator is in progress and no stream publishers are attached.
function DatabaseQuery:_Release()
	assert(self._iteratorState == "IDLE")
	assert(self._stream:GetNumPublishers() == 0)
	-- remove from the database
	self._db:_RemoveQuery(self)
	self._db = nil
	self._rootClause:_Release()
	self._rootClause = nil
	self._currentClause = nil
	self._updateCallback = nil
	self._updateCallbackContext = nil
	self._updatesPaused = 0
	self._queuedUpdate = false
	wipe(self._iterDistinctUsed)
	self._tempResultRow:Release()
	self._tempResultRow = nil
	if self._tempVirtualResultRow then
		self._tempVirtualResultRow:Release()
		self._tempVirtualResultRow = nil
	end
	self._autoRelease = false
	self:_WipeResults()
	self:ResetOrderBy()
	self:ResetDistinct()
	self:ResetSelect()
	self:ResetJoins()
	self:ResetVirtualFields()
	self._resultIsStale = false
	wipe(self._resultDependencies)
end
-- ============================================================================
-- Public Class Methods
-- ============================================================================
---Releases the database query.
---
---The database query object will be recycled and must not be accessed after calling this method.
---@param abortIterator? boolean Abort any in-progress iterator (resets the iterator state to idle so the release
---checks pass)
function DatabaseQuery:Release(abortIterator)
	if abortIterator then
		self._iteratorState = "IDLE"
	end
	self:_Release()
	private.objectPool:Recycle(self)
end
---Adds a virtual field to the query.
---@param field string The name of the new virtual field (must not already exist)
---@param fieldType string The type of the virtual field ("number", "string", or "boolean")
---@param func function A function which takes a row and returns the value of the virtual field
---@param argField? string The field to pass into the function (otherwise passes the entire row)
---@param defaultValue? any The default value to use if the function returns nil (must be of type `fieldType`)
---@return DatabaseQuery @The database query object
function DatabaseQuery:VirtualField(field, fieldType, func, argField, defaultValue)
	if self:_GetFieldType(field) or self._virtualFieldFunc[field] then
		error("Field already exists: "..tostring(field))
	elseif type(func) ~= "function" then
		error("Invalid func: "..tostring(func))
	elseif fieldType ~= "number" and fieldType ~= "string" and fieldType ~= "boolean" then
		error("Field type must be string, number, or boolean")
	elseif argField and not self:_GetFieldType(argField) then
		error("Arg field doesn't exist: "..tostring(argField))
	elseif defaultValue ~= nil and type(defaultValue) ~= fieldType then
		error("Invalid defaultValue type: "..tostring(defaultValue))
	end
	self:_NewVirtualField(field, func, argField, fieldType, defaultValue)
	return self
end
---Adds a virtual field whose value is read from a smart map.
---@param field string The name of the new virtual field (must not already exist)
---@param map SmartMapObject The smart map
---@param inputFieldName string The existing field whose value is used as the smart map key (its type must match the
---map's key type)
---@return DatabaseQuery @The database query object
function DatabaseQuery:VirtualSmartMapField(field, map, inputFieldName)
	local alreadyExists = self:_GetFieldType(field) or self._virtualFieldFunc[field]
	if alreadyExists then
		error("Field already exists: "..tostring(field))
	end
	local inputType = self:_GetFieldType(inputFieldName)
	if inputType ~= map:GetKeyType() then
		error("Invalid input field type or input field doesn't exists: "..tostring(inputFieldName))
	end
	local reader = self:_GetSmartMapReader(map)
	self:_NewVirtualField(field, reader, inputFieldName, map:GetValueType(), nil)
	return self
end
---Where a field equals a value.
---
---Pass `Constants.OTHER_FIELD_QUERY_PARAM` as `value` to compare against `otherField` (both fields must have the same
---type), or `Constants.BOUND_QUERY_PARAM` to supply the value later via `BindParams()`.
---@param field string The name of the field
---@param value any The value to compare to
---@param otherField? string The name of the other field to compare with
---@return DatabaseQuery @The database query object
function DatabaseQuery:Equal(field, value, otherField)
	if value == Constants.OTHER_FIELD_QUERY_PARAM then
		local fieldType = self:_GetFieldType(field)
		assert(fieldType and fieldType == self:_GetFieldType(otherField))
	elseif value ~= Constants.BOUND_QUERY_PARAM then
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause()
		:Equal(field, value, otherField)
	return self
end
---Where a field does not equal a value.
---
---Pass `Constants.OTHER_FIELD_QUERY_PARAM` as `value` to compare against `otherField` (both fields must have the same
---type), or `Constants.BOUND_QUERY_PARAM` to supply the value later via `BindParams()`.
---@param field string The name of the field
---@param value any The value to compare to
---@param otherField? string The name of the other field to compare with
---@return DatabaseQuery @The database query object
function DatabaseQuery:NotEqual(field, value, otherField)
	if value == Constants.OTHER_FIELD_QUERY_PARAM then
		local fieldType = self:_GetFieldType(field)
		assert(fieldType and fieldType == self:_GetFieldType(otherField))
	elseif value ~= Constants.BOUND_QUERY_PARAM then
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause()
		:NotEqual(field, value, otherField)
	return self
end
---Where a field is less than a value.
---
---Pass `Constants.OTHER_FIELD_QUERY_PARAM` as `value` to compare against `otherField` (both fields must have the same
---type), or `Constants.BOUND_QUERY_PARAM` to supply the value later via `BindParams()`.
---@param field string The name of the field
---@param value any The value to compare to
---@param otherField? string The name of the other field to compare with
---@return DatabaseQuery @The database query object
function DatabaseQuery:LessThan(field, value, otherField)
	if value == Constants.OTHER_FIELD_QUERY_PARAM then
		local fieldType = self:_GetFieldType(field)
		assert(fieldType and fieldType == self:_GetFieldType(otherField))
	elseif value ~= Constants.BOUND_QUERY_PARAM then
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause()
		:LessThan(field, value, otherField)
	return self
end
---Where a field is less than or equal to a value.
---
---Pass `Constants.OTHER_FIELD_QUERY_PARAM` as `value` to compare against `otherField` (both fields must have the same
---type), or `Constants.BOUND_QUERY_PARAM` to supply the value later via `BindParams()`.
---@param field string The name of the field
---@param value any The value to compare to
---@param otherField? string The name of the other field to compare with
---@return DatabaseQuery @The database query object
function DatabaseQuery:LessThanOrEqual(field, value, otherField)
	if value == Constants.OTHER_FIELD_QUERY_PARAM then
		local fieldType = self:_GetFieldType(field)
		assert(fieldType and fieldType == self:_GetFieldType(otherField))
	elseif value ~= Constants.BOUND_QUERY_PARAM then
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause()
		:LessThanOrEqual(field, value, otherField)
	return self
end
---Where a field is greater than a value.
---
---Pass `Constants.OTHER_FIELD_QUERY_PARAM` as `value` to compare against `otherField` (both fields must have the same
---type), or `Constants.BOUND_QUERY_PARAM` to supply the value later via `BindParams()`.
---@param field string The name of the field
---@param value any The value to compare to
---@param otherField? string The name of the other field to compare with
---@return DatabaseQuery @The database query object
function DatabaseQuery:GreaterThan(field, value, otherField)
	if value == Constants.OTHER_FIELD_QUERY_PARAM then
		local fieldType = self:_GetFieldType(field)
		assert(fieldType and fieldType == self:_GetFieldType(otherField))
	elseif value ~= Constants.BOUND_QUERY_PARAM then
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause()
		:GreaterThan(field, value, otherField)
	return self
end
---Where a field is greater than or equal to a value.
---
---Pass `Constants.OTHER_FIELD_QUERY_PARAM` as `value` to compare against `otherField` (both fields must have the same
---type), or `Constants.BOUND_QUERY_PARAM` to supply the value later via `BindParams()`.
---@param field string The name of the field
---@param value any The value to compare to
---@param otherField? string The name of the other field to compare with
---@return DatabaseQuery @The database query object
function DatabaseQuery:GreaterThanOrEqual(field, value, otherField)
	if value == Constants.OTHER_FIELD_QUERY_PARAM then
		local fieldType = self:_GetFieldType(field)
		assert(fieldType and fieldType == self:_GetFieldType(otherField))
	elseif value ~= Constants.BOUND_QUERY_PARAM then
		assert(self:_GetFieldType(field) == type(value))
	end
	self:_NewClause()
		:GreaterThanOrEqual(field, value, otherField)
	return self
end
---Where a string field matches a pattern.
---
---The pattern is lowercased before being stored in the clause. Bound parameters are not supported.
---@param field string The name of the field
---@param value string The pattern to match
---@return DatabaseQuery @The database query object
function DatabaseQuery:Matches(field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(self:_GetFieldType(field) == "string" and type(value) == "string")
	self:_NewClause()
		:Matches(field, strlower(value))
	return self
end
---Where a string field contains a substring.
---
---The substring is lowercased before being stored in the clause. Bound parameters are not supported.
---@param field string The name of the field
---@param value string The substring to match
---@return DatabaseQuery @The database query object
function DatabaseQuery:Contains(field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(self:_GetFieldType(field) == "string" and type(value) == "string")
	self:_NewClause()
		:Contains(field, strlower(value))
	return self
end
---Where a string field starts with a substring.
---
---The substring is lowercased before being stored in the clause. Bound parameters are not supported.
---@param field string The name of the field
---@param value string The substring to match
---@return DatabaseQuery @The database query object
function DatabaseQuery:StartsWith(field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(self:_GetFieldType(field) == "string" and type(value) == "string")
	self:_NewClause()
		:StartsWith(field, strlower(value))
	return self
end
---Where a foreign field (obtained via a left join) is nil.
---@param field string The name of the field (must come from a left join)
---@return DatabaseQuery @The database query object
function DatabaseQuery:IsNil(field)
	assert(self:_GetJoinType(field) == "LEFT", "Must be a left join")
	self:_NewClause()
		:IsNil(field)
	return self
end
---Where a foreign field (obtained via a left join) is not nil.
---@param field string The name of the field (must come from a left join)
---@return DatabaseQuery @The database query object
function DatabaseQuery:IsNotNil(field)
	assert(self:_GetJoinType(field) == "LEFT", "Must be a left join")
	self:_NewClause()
		:IsNotNil(field)
	return self
end
---A custom query clause.
---@param func fun(row: DatabaseQueryResultRow, arg: any): boolean The function which gets passed the row being evaluated
---and returns whether or not the query results should include it
---@param arg any An argument to pass to the function
---@return DatabaseQuery @The database query object
function DatabaseQuery:Custom(func, arg)
	assert(type(func) == "function")
	self:_NewClause()
		:Custom(func, arg)
	return self
end
---Where the hash of a row equals a value.
---
---Every field must exist and be of type number or string. Bound parameters are not supported.
---@param fields string[] An ordered list of fields to hash
---@param value number The hash value to compare to
---@return DatabaseQuery @The database query object
function DatabaseQuery:HashEqual(fields, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(fields) == "table")
	for _, field in ipairs(fields) do
		local fieldType = self:_GetFieldType(field)
		if not fieldType then
			error(format("Field %s doesn't exist", tostring(field)))
		elseif fieldType ~= "number" and fieldType ~= "string" then
			error(format("Cannot hash field of type %s", fieldType))
		end
	end
	self:_NewClause()
		:HashEqual(fields, value)
	return self
end
---Where a field's value exists as a key within a table.
---
---Bound parameters are not supported.
---@param field string The name of the field
---@param value table The table to check against
---@return DatabaseQuery @The database query object
function DatabaseQuery:InTable(field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(value) == "table")
	self:_NewClause()
		:InTable(field, value)
	return self
end
---Where a field's value does not exist as a key within a table.
---
---Bound parameters are not supported.
---@param field string The name of the field
---@param value table The table to check against
---@return DatabaseQuery @The database query object
function DatabaseQuery:NotInTable(field, value)
	assert(value ~= Constants.BOUND_QUERY_PARAM, "This method does not support bound values")
	assert(type(value) == "table")
	self:_NewClause()
		:NotInTable(field, value)
	return self
end
---Starts a nested AND clause.
---
---All of the clauses following this (until the matching `:End()`) must be true for the AND clause to be true.
---@return DatabaseQuery @The database query object
function DatabaseQuery:And()
	-- subsequent clauses are added to this new nested clause until End() is called
	self._currentClause = self:_NewClause()
		:And()
	return self
end
---Starts a nested OR clause.
---
---At least one of the clauses following this (until the matching `:End()`) must be true for the OR clause to be true.
---@return DatabaseQuery @The database query object
function DatabaseQuery:Or()
	-- subsequent clauses are added to this new nested clause until End() is called
	self._currentClause = self:_NewClause()
		:Or()
	return self
end
---Ends a nested AND/OR clause, making its parent the current clause again.
---@return DatabaseQuery @The database query object
function DatabaseQuery:End()
	assert(self._currentClause ~= self._rootClause, "No current clause to end")
	self._currentClause = self._currentClause:_GetParent()
	assert(self._currentClause)
	return self
end
---Performs a left join with another table.
---@param db DatabaseTable The database table to join with
---@param field string The field to join on
---@param foreignField? string The foreign field to join on (defaults to `field`)
---@return DatabaseQuery @The database query object
function DatabaseQuery:LeftJoin(db, field, foreignField)
	self:_JoinHelper("LEFT", db, field, foreignField or field)
	return self
end
---Performs an inner join with another table.
---@param db DatabaseTable The database table to join with
---@param field string The field to join on
---@param foreignField? string The foreign field to join on (defaults to `field`)
---@return DatabaseQuery @The database query object
function DatabaseQuery:InnerJoin(db, field, foreignField)
	self:_JoinHelper("INNER", db, field, foreignField or field)
	return self
end
---Performs an aggregate join with another table with a summed field.
---@param db DatabaseTable The database to join with
---@param field string The name of the field in the other table to join on
---@param sumField string The name of the field in the other table to sum
---@return DatabaseQuery @The database query object
function DatabaseQuery:AggregateJoinSummed(db, field, sumField)
	-- sub-query on the other table filtered by a bound value of `field` (the binding is done elsewhere)
	local query = db:NewQuery()
		:Equal(field, Constants.BOUND_QUERY_PARAM)
	self:_JoinHelper("AGGREGATE_SUM", db, field, sumField, query)
	return self
end
---Order the results by a field.
---
---This may be called multiple times to provide additional ordering constraints. The priority of the ordering will be
---descending as this method is called additional times (meaning the first OrderBy will have highest priority).
---@param field string The name of the field to order by (must be a number, string, or boolean field)
---@param ascending boolean Whether to order in ascending order (descending otherwise)
---@return DatabaseQuery @The database query object
function DatabaseQuery:OrderBy(field, ascending)
	assert(ascending == true or ascending == false)
	local fieldType = self:_GetFieldType(field)
	if not fieldType then
		error(format("Field %s doesn't exist", tostring(field)))
	elseif fieldType ~= "number" and fieldType ~= "string" and fieldType ~= "boolean" then
		error(format("Cannot order by field of type %s", tostring(fieldType)))
	end
	tinsert(self._orderBy, field)
	tinsert(self._orderByAscending, ascending)
	self._resultIsStale = true
	return self
end
---Only return distinct results based on a field.
---
---This method can be used to ensure that only the first row for each distinct value of the field is returned.
---Calling it again replaces the previously-set distinct field.
---@param field string The field to ensure is distinct in the results
---@return DatabaseQuery @The database query object
function DatabaseQuery:Distinct(field)
	assert(self:_GetFieldType(field), format("Field %s doesn't exist within local DB", tostring(field)))
	self._distinct = field
	self._resultIsStale = true
	return self
end
---Select specific fields in the result.
---
---May only be called once per query and accepts between 1 and 10 fields.
---@param ... string The fields to select
---@return DatabaseQuery @The database query object
function DatabaseQuery:Select(...)
	assert(#self._select == 0)
	local numFields = select("#", ...)
	assert(numFields > 0, "Must select at least 1 field")
	-- DatabaseRow.GetFields() only supports 10 fields, so we can only support 10 here as well
	assert(numFields <= 10, "Select() only supports up to 10 fields")
	for i = 1, numFields do
		local field = select(i, ...)
		tinsert(self._select, field)
	end
	self._resultIsStale = true
	return self
end
---Binds parameters to a prepared query.
---
---The number of arguments must match the number of Constants.BOUND_QUERY_PARAM values in the query's clauses.
---@param ... any The bound parameter values
---@return DatabaseQuery @The database query object
function DatabaseQuery:BindParams(...)
	local numFields = select("#", ...)
	assert(self._rootClause:_BindParams(...) == numFields, "Invalid number of bound parameters")
	self._resultIsStale = true
	return self
end
---Set an update callback.
---
---This callback gets called whenever any rows in the underlying database change.
---@param func fun(db: DatabaseQuery, changedUUID: number|nil, context: any) The callback function
---@param context? any A context argument which is passed as the third argument to the callback function
---@return DatabaseQuery @The database query object
function DatabaseQuery:SetUpdateCallback(func, context)
	self._updateCallback = func
	self._updateCallbackContext = context
	return self
end
---Pauses or unpauses callbacks for query updates.
---
---Calls are counted, so each pause must be balanced by an unpause. When the count returns to zero, any update queued
---while paused is delivered.
---@param paused boolean Whether or not updates should be paused
---@return DatabaseQuery @The database query object
function DatabaseQuery:SetUpdatesPaused(paused)
	self._updatesPaused = self._updatesPaused + (paused and 1 or -1)
	assert(self._updatesPaused >= 0)
	if self._updatesPaused == 0 and self._queuedUpdate then
		self:_DoUpdateCallback()
	end
	return self
end
---Results iterator.
---
---Note that the iterator must run to completion (don't use `break` or `return` to escape it early).
---@param canAbort boolean Allow the iterator to be aborted if the underlying data is updated which must
---be handled by the caller by calling `IsIteratorAborted()` at the end of each iteration loop
---@return fun(): number, DatabaseRow @An iterator with fields: `index`, row
function DatabaseQuery:Iterator(canAbort)
	-- validate preconditions before executing so a misuse fails without rebuilding the results
	assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	assert(not canAbort or (not self._updateCallback and self._stream:GetNumPublishers() == 0))
	self:_Execute()
	self._iteratorState = canAbort and "IN_PROGRESS_CAN_ABORT" or "IN_PROGRESS"
	self._autoRelease = false
	return private.QueryResultIterator, self, 0
end
---Iterates through the results as uuids.
---
---Note that the iterator must run to completion (don't use `break` or `return` to escape it early).
---@return fun(): number, number @An iterator with fields: `index`, `uuid`
function DatabaseQuery:UUIDIterator()
	-- validate preconditions before executing so a misuse fails without rebuilding the results
	assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	self:_Execute()
	self._iteratorState = "IN_PROGRESS"
	self._autoRelease = false
	return private.QueryResultAsUUIDIterator, self, 0
end
---Results iterator which releases upon completion.
---
---Note that the iterator must run to completion (don't use `break` or `return` to escape it early).
---@return fun(): number, ... @An iterator with fields: `index`, ...
function DatabaseQuery:IteratorAndRelease()
	-- validate preconditions before executing so a misuse fails without rebuilding the results
	assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	self:_Execute()
	self._iteratorState = "IN_PROGRESS"
	self._autoRelease = true
	return private.QueryResultIterator, self, 0
end
---Check if the abortable iterator has been aborted.
---
---Only valid while an iterator created with `Iterator(true)` is running. A pending abort is consumed (the state moves
---to "ABORTED") the first time this returns true.
---@return boolean
function DatabaseQuery:IsIteratorAborted()
	if self._iteratorState == "IN_PROGRESS_CAN_ABORT" then
		return false
	elseif self._iteratorState == "PENDING_ABORT" then
		self._iteratorState = "ABORTED"
		return true
	else
		error("Invalid iterator state: "..tostring(self._iteratorState))
	end
end
---Prepares a UUID diff against a previous list of UUIDs.
---
---If this function returns true, `DatabaseQuery:UUIDDiffIterator()` must be called and run to completion.
---The diff is stored in a single shared context, so only one diff may be prepared at a time.
---@param oldUUIDs number[] The list of old UUIDs
---@return boolean @Whether or not a diff was prepared
function DatabaseQuery:UUIDDiffPrepare(oldUUIDs)
	self:_Execute()
	local context = private.uuidDiffContext
	assert(not context.inUse)
	context.inUse = true
	local inserted = context.insert
	local removed = context.remove
	local result = context.result
	if not Table.GetDiffOrdered(oldUUIDs, self._result, inserted, removed) then
		-- make sure nothing the diff helper may have written is left behind for the next diff
		wipe(inserted)
		wipe(removed)
		context.inUse = false
		return false
	end
	-- Add the remove actions in reverse order (last index first), merging runs of contiguous indexes
	while #removed > 0 do
		local endIndex = tremove(removed)
		local startIndex = endIndex
		while #removed > 0 and removed[#removed] == startIndex - 1 do
			startIndex = tremove(removed)
		end
		tinsert(result, "REMOVE")
		tinsert(result, startIndex)
		tinsert(result, endIndex - startIndex + 1)
		for i = startIndex, endIndex do
			tinsert(result, oldUUIDs[i])
		end
	end
	-- Add the insert actions in order, merging runs of contiguous indexes
	local i = 1
	while i <= #inserted do
		local startIndex = inserted[i]
		local endIndex = startIndex
		for j = i + 1, #inserted do
			if inserted[j] == endIndex + 1 then
				endIndex = endIndex + 1
			else
				break
			end
		end
		tinsert(result, "INSERT")
		tinsert(result, startIndex)
		tinsert(result, endIndex - startIndex + 1)
		for j = startIndex, endIndex do
			tinsert(result, self._result[j])
		end
		i = i + endIndex - startIndex + 1
	end
	wipe(inserted)
	return true
end
---Iterates over the diff which was prepared by a successful `DatabaseQuery:UUIDDiffPrepare()` call.
---@return fun(): number, "REMOVE"|"INSERT", number, number[] @An iterator with fields: `index`, `action`, `startIndex`, `uuids`
function DatabaseQuery:UUIDDiffIterator()
	local diffContext = private.uuidDiffContext
	assert(diffContext.inUse)
	local iterFunc = private.UUIDDiffIterator
	return iterFunc, diffContext, 1
end
---Populates a table with the results.
---
---The query must have a select clause with exactly one or two fields. In the former case, the table will be populated
---as a list, and in the latter case, the first field must be unique in the results, and will be used as the key for the
---table with the second field being the value.
---@param tbl table The table to store the result in
---@return DatabaseQuery @The database query object
function DatabaseQuery:AsTable(tbl)
	self:_Execute()
	if #self._select == 1 then
		local field = self._select[1]
		for _, uuid in ipairs(self._result) do
			tinsert(tbl, self:_GetResultRowData(uuid, field))
		end
	elseif #self._select == 2 then
		local field1, field2 = unpack(self._select)
		for _, uuid in ipairs(self._result) do
			local key = self:_GetResultRowData(uuid, field1)
			-- keys must be non-nil and unique (also rejects keys already present in the passed-in table)
			if key == nil or tbl[key] then
				error("Key is nil or not distinct")
			end
			tbl[key] = self:_GetResultRowData(uuid, field2)
		end
	else
		error("Invalid select clause")
	end
	return self
end
---Get the number of resulting rows (executes the query if needed).
---@return number
function DatabaseQuery:Count()
	self:_Execute()
	return #self._result
end
---Get the number of resulting rows and release the query.
---@return number
function DatabaseQuery:CountAndRelease()
	self:_Execute()
	local count = #self._result
	self:Release()
	return count
end
---Get a single result.
---
---This method will assert that there is exactly one result from the query and return it.
---@return any @The result row or the selected fields
function DatabaseQuery:GetSingleResult()
	self:_Execute()
	assert(self:Count() == 1)
	return self:GetFirstResult()
end
---Get a single result and release the query.
---
---This method will assert that there is exactly one result from the query and return it. The query must have a
---select clause (presumably because a bare result row would belong to the query being released).
---@return any @The selected fields
function DatabaseQuery:GetSingleResultAndRelease()
	assert(#self._select > 0)
	local result = self:GetSingleResult()
	self:Release()
	return result
end
---Get the first result.
---
---Note that this executes the query (computing the full result set) but only returns the first row.
---@return any @The result row or the selected fields (nothing if there are no results)
function DatabaseQuery:GetFirstResult()
	self:_Execute()
	assert(self._iteratorState == "IDLE")
	if self:Count() == 0 then
		return
	end
	local uuid = self._result[1]
	if not self._resultRowLookup[uuid] then
		self:_CreateResultRow(uuid)
	end
	local row = self._resultRowLookup[uuid]
	if #self._select > 0 then
		return row:GetFields(unpack(self._select))
	else
		return row
	end
end
---Get the first result and release the query.
---
---Note that this executes the query (computing the full result set) but only returns the first row. Without a
---select clause, a clone of the row is returned since the original row is released along with the query.
---@return any @The result row or the selected fields (nothing if there are no results)
function DatabaseQuery:GetFirstResultAndRelease()
	self:_Execute()
	assert(self._iteratorState == "IDLE")
	if self:Count() == 0 then
		self:Release()
		return
	end
	local uuid = self._result[1]
	if not self._resultRowLookup[uuid] then
		self:_CreateResultRow(uuid)
	end
	local row = self._resultRowLookup[uuid]
	if #self._select > 0 then
		return self:_PassThroughReleaseHelper(row:GetFields(unpack(self._select)))
	else
		row = row:Clone()
		self:Release()
		return row
	end
end
---Gets the minimum value of a specific field within the query results.
---@param field string The field within the results
---@return number|nil @The minimum value or nil if there are no results
function DatabaseQuery:Min(field)
	self:_Execute()
	local result = nil
	for _, uuid in ipairs(self._result) do
		local value = self:_GetResultRowData(uuid, field)
		result = min(result or math.huge, value)
	end
	return result
end
---Gets the maximum value of a specific field within the query results.
---@param field string The field within the results
---@return number|nil @The maximum value or nil if there are no results
function DatabaseQuery:Max(field)
	self:_Execute()
	local result = nil
	for _, uuid in ipairs(self._result) do
		local value = self:_GetResultRowData(uuid, field)
		result = max(result or -math.huge, value)
	end
	return result
end
---Gets the summed value of a specific field within the query results.
---@param field string The field within the results
---@return number @The summed value (0 if there are no results)
function DatabaseQuery:Sum(field)
	self:_Execute()
	local result = 0
	for _, uuid in ipairs(self._result) do
		result = result + self:_GetResultRowData(uuid, field)
	end
	return result
end
---Gets the summed value of a specific field for each group within the query results.
---
---Sums are accumulated into `result`, keyed by the group field's value (existing entries are added to).
---@param groupField string The field to group by
---@param sumField string The field to sum
---@param result table The results table
function DatabaseQuery:GroupedSum(groupField, sumField, result)
	self:_Execute()
	for _, uuid in ipairs(self._result) do
		local group = self:_GetResultRowData(uuid, groupField)
		local value = self:_GetResultRowData(uuid, sumField)
		result[group] = (result[group] or 0) + value
	end
end
---Gets the summed value of a specific field within the query results and releases the query.
---@param field string The field within the results
---@return number @The summed value (0 if there are no results)
function DatabaseQuery:SumAndRelease(field)
	local result = self:Sum(field)
	self:Release()
	return result
end
---Gets the average value of a specific field within the query results.
---@param field string The field within the results
---@return number|nil @The average value or nil if there are no results
function DatabaseQuery:Avg(field)
	local sum = self:Sum(field)
	local num = self:Count()
	return num > 0 and (sum / num) or nil
end
---Gets the sum of the products of two fields within the query results.
---@param field1 string The first field within the results
---@param field2 string The second field within the results
---@return number @The summed value (0 if there are no results)
function DatabaseQuery:SumOfProduct(field1, field2)
	self:_Execute()
	local result = 0
	for _, uuid in ipairs(self._result) do
		local value1 = self:_GetResultRowData(uuid, field1)
		local value2 = self:_GetResultRowData(uuid, field2)
		result = result + value1 * value2
	end
	return result
end
---Joins the string values of a field with a given separator.
---@param field string The field within the results
---@param sep string The separator (can be any number of characters, including an empty string)
---@return string @The joined string (empty if there are no results)
function DatabaseQuery:JoinedString(field, sep)
	self:_Execute()
	local parts = TempTable.Acquire()
	for _, uuid in ipairs(self._result) do
		tinsert(parts, self:_GetResultRowData(uuid, field))
	end
	local result = table.concat(parts, sep)
	TempTable.Release(parts)
	return result
end
---Calculates the hash of the query results.
---
---Note that either `fields` must be specified or the query must have a select clause with one or two fields. When the
---selected fields are used, the collected values are sorted before hashing so the hash doesn't depend on row order.
---@param fields? string[] The fields from each row to hash (otherwise uses the selected fields)
---@return number|nil @The hash value or nil if there are no results
function DatabaseQuery:Hash(fields)
	self:_Execute()
	local result = nil
	if fields then
		for _, uuid in ipairs(self._result) do
			for _, field in ipairs(fields) do
				result = Math.CalculateHash(self:_GetResultRowData(uuid, field), result)
			end
		end
	else
		local keyField, valueField, extra = unpack(self._select)
		assert(keyField and not extra)
		local hashContext = TempTable.Acquire()
		for _, uuid in ipairs(self._result) do
			tinsert(hashContext, self:_GetResultRowData(uuid, keyField))
			if valueField then
				tinsert(hashContext, self:_GetResultRowData(uuid, valueField))
			end
		end
		Table.Sort(hashContext)
		for _, value in ipairs(hashContext) do
			result = Math.CalculateHash(value, result)
		end
		TempTable.Release(hashContext)
	end
	return result
end
---Calculates per-group hashes of the query results.
---
---For each row, the given fields are hashed together. That row hash is then folded into `result[groupValue]`, where
---`groupValue` is the row's value of `groupField`.
---@param fields table The fields from each row to hash
---@param groupField string The field to group by
---@param result table The result table (group value -> hash), updated in place
function DatabaseQuery:GroupedHash(fields, groupField, result)
	self:_Execute()
	for _, uuid in ipairs(self._result) do
		local groupValue = self:_GetResultRowData(uuid, groupField)
		local rowHash = nil
		for _, field in ipairs(fields) do
			rowHash = Math.CalculateHash(self:_GetResultRowData(uuid, field), rowHash)
		end
		result[groupValue] = Math.CalculateHash(rowHash, result[groupValue])
	end
end
---Calculates the hash of the query results (see `:Hash()`) and then releases the query.
---
---Note that either `fields` must be specified or the query must have a select column with at most 2 fields.
---@param fields? table The fields from each row to hash (otherwise uses the selected fields)
---@return number|nil @The hash value or nil if there are no results
function DatabaseQuery:HashAndRelease(fields)
	return self:_PassThroughReleaseHelper(self:Hash(fields))
end
---Deletes every result row from the database and then releases the query.
---@return number @The number of rows deleted (equal to `:Count()`)
function DatabaseQuery:DeleteAndRelease()
	local numDeleted = self:Count()
	self._db:BulkDelete(self._result)
	return self:_PassThroughReleaseHelper(numDeleted)
end
---Resets the database query by clearing every clause, join, and virtual field, and wiping the current result.
---@return DatabaseQuery @The database query object
function DatabaseQuery:Reset()
	-- each Reset*() method returns self, so they can be chained
	self:ResetDistinct()
		:ResetSelect()
		:ResetOrderBy()
		:ResetJoins()
		:ResetFilters()
		:ResetVirtualFields()
	self:_WipeResults()
	self._resultIsStale = true
	return self
end
---Removes every virtual field from the database query.
---@return DatabaseQuery @The database query object
function DatabaseQuery:ResetVirtualFields()
	-- detach any smart map readers backing our virtual fields so other queries can reuse them
	local readerContexts = private.smartMapReaderContext
	for _, func in pairs(self._virtualFieldFunc) do
		local context = readerContexts[func]
		if context then
			context.query = nil
		end
	end
	wipe(self._virtualFieldFunc)
	wipe(self._virtualFieldArgField)
	wipe(self._virtualFieldType)
	wipe(self._virtualFieldDefault)
	self._resultIsStale = true
	return self
end
---Removes every filtering clause from the database query, leaving an empty root AND clause.
---@return DatabaseQuery @The database query object
function DatabaseQuery:ResetFilters()
	self._rootClause:_Release()
	local rootClause = QueryClause.Get():And()
	self._rootClause = rootClause
	self._currentClause = rootClause
	self._resultIsStale = true
	return self
end
---Removes every ordering clause from the database query.
---@return DatabaseQuery @The database query object
function DatabaseQuery:ResetOrderBy()
	-- the field list and the ascending flags are parallel arrays
	wipe(self._orderBy)
	wipe(self._orderByAscending)
	self._resultIsStale = true
	return self
end
---Removes every join from the database query, unregistering from the joined DBs and releasing any aggregate queries.
---@return DatabaseQuery @The database query object
function DatabaseQuery:ResetJoins()
	local joinDBs = self._joinDBs
	for i = 1, #joinDBs do
		joinDBs[i]:_RemoveQuery(self)
	end
	wipe(self._joinTypes)
	wipe(joinDBs)
	wipe(self._joinFields)
	wipe(self._joinForeignFields)

	-- non-aggregate joins store `false` in this list
	local aggregateQueries = self._aggregateJoinQueries
	for i = 1, #aggregateQueries do
		local aggregateQuery = aggregateQueries[i]
		if aggregateQuery then
			aggregateQuery:Release()
		end
	end
	wipe(aggregateQueries)
	self._resultIsStale = true
	return self
end
---Removes the distinct clause from the database query.
---@return DatabaseQuery @The database query object
function DatabaseQuery:ResetDistinct()
	self._resultIsStale = true
	self._distinct = nil
	return self
end
---Removes every selected field from the database query.
---@return DatabaseQuery @The database query object
function DatabaseQuery:ResetSelect()
	self._resultIsStale = true
	wipe(self._select)
	return self
end
---Gets the details of one order by clause.
---@param index number The index of the order by clause (must exist)
---@return string? @The field name
---@return boolean|nil @Whether or not the sort is ascending
function DatabaseQuery:GetOrderBy(index)
	local field = self._orderBy[index]
	assert(field)
	return field, self._orderByAscending[index]
end
---Gets the details of the last order by clause (both values are nil if there are no clauses).
---@return string? @The field name
---@return boolean|nil @Whether or not the sort is ascending
function DatabaseQuery:GetLastOrderBy()
	local fields, ascendingFlags = self._orderBy, self._orderByAscending
	return fields[#fields], ascendingFlags[#ascendingFlags]
end
---Replaces the last order by clause with a new one.
---@param field string The name of the field to order by
---@param ascending boolean Whether to order in ascending order (descending otherwise)
---@return DatabaseQuery @The database query object
function DatabaseQuery:UpdateLastOrderBy(field, ascending)
	local fields = self._orderBy
	assert(#fields > 0)
	-- drop the current last clause, then append the replacement via OrderBy()
	tremove(fields)
	tremove(self._orderByAscending)
	self:OrderBy(field, ascending)
	return self
end
---Gets the result row object for a UUID, creating it on first access.
---@param uuid number The UUID of the row to get
---@return DatabaseRow @The result row
function DatabaseQuery:GetResultRowByUUID(uuid)
	local row = self._resultRowLookup[uuid]
	if not row then
		row = self:_CreateResultRow(uuid)
	end
	return row
end
---Gets the values of the selected fields for a row, identified by its UUID.
---@param uuid number The UUID of the row to get
---@return ... @The selected field values, in select order
function DatabaseQuery:GetSelectedFieldsByUUID(uuid)
	assert(#self._select > 0)
	-- creating the row also asserts (in _CreateResultRow) that the UUID is part of the result
	if not self._resultRowLookup[uuid] then
		self:_CreateResultRow(uuid)
	end
	-- start from a copy of the selected field names and replace each with its value in place
	local values = TempTable.Acquire(unpack(self._select))
	for i = 1, #values do
		values[i] = self:_GetResultRowData(uuid, values[i])
	end
	return TempTable.UnpackAndRelease(values)
end
---Creates a publisher which emits whenever the query result changes (starting with an initial nil value).
---@return ReactivePublisher
function DatabaseQuery:Publisher()
	local stream = self._stream
	return stream:PublisherWithInitialValue(nil)
end
3 years ago
-- ============================================================================
-- Private Class Methods
-- ============================================================================
---Returns the join type of the first joined DB which has the given field (nothing if none does).
function DatabaseQuery:_GetJoinType(field)
	local joinDBs = self._joinDBs
	for i = 1, #joinDBs do
		if joinDBs[i]:_GetFieldType(field) then
			return self._joinTypes[i]
		end
	end
end
---Returns the type of a field, checking virtual fields, then the local DB, then each joined DB in order.
---The aggregate field of an AGGREGATE_SUM join is always reported as "number".
function DatabaseQuery:_GetFieldType(field)
	local localType = self._virtualFieldType[field] or self._db:_GetFieldType(field)
	if localType then
		return localType
	end
	for i, joinDB in ipairs(self._joinDBs) do
		local isAggregateSumField = self._joinTypes[i] == "AGGREGATE_SUM" and self._joinForeignFields[i] == field
		local joinFieldType = isAggregateSumField and "number" or joinDB:_GetFieldType(field)
		if joinFieldType then
			return joinFieldType
		end
	end
end
---Handles a change to the underlying data.
---
---If the change affects the result (no changed fields were given, the result depends on all fields, or one of the
---changed fields is a result dependency), the result is marked stale. Otherwise only the cached values of the changed
---fields are cleared from existing result rows. In both cases an abortable in-progress iterator is flagged to abort.
---@param changedFields? table A set whose keys are the names of the changed fields
function DatabaseQuery:_MarkResultStale(changedFields)
	local state = self._iteratorState
	assert(state == "IDLE" or state == "IN_PROGRESS_CAN_ABORT" or state == "PENDING_ABORT")
	if self._resultIsStale then
		-- already marked stale
		return
	end

	local dependencies = self._resultDependencies
	local affectsResult = not changedFields or dependencies._all
	if not affectsResult then
		for field in pairs(changedFields) do
			if dependencies[field] then
				affectsResult = true
				break
			end
		end
	end

	if affectsResult then
		self._resultIsStale = true
	else
		-- the result itself is still valid, so just drop cached values of the changed fields
		for _, row in pairs(self._resultRowLookup) do
			if row ~= false then
				for field in pairs(changedFields) do
					rawset(row, field, nil)
				end
			end
		end
	end

	if self._iteratorState == "IN_PROGRESS_CAN_ABORT" then
		self._iteratorState = "PENDING_ABORT"
	end
end
---Translates a UUID from one of the (non-aggregate) joined DBs into the UUID of the local row which joins with it.
---Returns nil if the UUID is found in more than one joined DB, or if the local join field's index doesn't map it to
---exactly one local row.
function private.TranslateJoinedUUID(self, uuid)
	local localDB = self._db
	local found = nil
	for i, joinDB in ipairs(self._joinDBs) do
		if not self._aggregateJoinQueries[i] and joinDB:_ContainsUUID(uuid) then
			if found then
				-- found more than once, so bail
				return nil
			end
			local joinField = self._joinFields[i]
			local joinValue = joinDB:GetRowFieldByUUID(uuid, self._joinForeignFields[i])
			if localDB:_IsUnique(joinField) then
				found = localDB:_GetUniqueRow(joinField, joinValue)
			elseif localDB:_IsIndex(joinField) then
				local lowIndex, highIndex = localDB:_GetIndexListMatchingIndexRange(joinField, Util.ToIndexValue(joinValue))
				if not lowIndex or not highIndex or lowIndex ~= highIndex then
					-- can't use this index to find a single local UUID
					return nil
				end
				found = localDB:_GetAllRowsByIndex(joinField)[lowIndex]
			end
		end
	end
	return found
end

---Notifies the publisher stream and the update callback (if any) that the query result changed.
---
---While updates are paused, the update is only queued. The given UUID is translated into a local UUID where possible.
---nil is sent instead if the result is stale, no UUID was given, or the UUID can't be translated.
function DatabaseQuery:_DoUpdateCallback(uuid)
	if not self._updateCallback and self._stream:GetNumPublishers() == 0 then
		-- nobody is listening
		assert(self._iteratorState == "IDLE" or self._iteratorState == "PENDING_ABORT")
		return
	end
	-- can't have an update callback on an abortable iterator
	assert(self._iteratorState == "IDLE")
	if self._updatesPaused > 0 then
		self._queuedUpdate = true
		return
	end

	self._queuedUpdate = false
	-- Pause query updates while processing this one so we don't end up recursing
	self:SetUpdatesPaused(true)
	local updatedUUID = nil
	if not self._resultIsStale and uuid then
		if self._db:_ContainsUUID(uuid) then
			updatedUUID = uuid
		else
			-- the UUID is from a joined DB, so see if we can easily translate it to a local UUID
			updatedUUID = private.TranslateJoinedUUID(self, uuid)
		end
	end
	self._stream:Send(updatedUUID)
	if self._updateCallback then
		self:_updateCallback(updatedUUID, self._updateCallbackContext)
	end
	self:SetUpdatesPaused(false)
end
---Creates a new clause as a child of the current clause and returns it.
function DatabaseQuery:_NewClause()
	assert(self._iteratorState == "IDLE")
	local parent = self._currentClause
	local clause = QueryClause.Get(parent)
	parent:_InsertSubClause(clause)
	self._resultIsStale = true
	return clause
end
---Registers a virtual field whose value is computed by `func`.
---`func` receives the value of `argField`, or a result row when `argField` is nil. `defaultValue` is used when
---`func` returns nil.
function DatabaseQuery:_NewVirtualField(field, func, argField, fieldType, defaultValue)
	assert(self._iteratorState == "IDLE")
	self._virtualFieldFunc[field], self._virtualFieldArgField[field] = func, argField
	self._virtualFieldType[field], self._virtualFieldDefault[field] = fieldType, defaultValue
	self._resultIsStale = true
end
---Clears the current result, releasing any result row objects which were created.
function DatabaseQuery:_WipeResults()
	-- lookup values are either `false` (no row object created yet) or a row object
	for _, row in pairs(self._resultRowLookup) do
		if row then
			row:Release()
		end
	end
	wipe(self._result)
	wipe(self._resultRowLookup)
	-- flag a pending update while updates are paused
	if self._updatesPaused > 0 then
		self._queuedUpdate = true
	end
end
---Executes the query if its result is stale.
---
---Repopulates `self._result` (ordered list of matching UUIDs) and `self._resultRowLookup` (UUID -> false until a
---result row object is created). It uses the most selective index available, sorts the result according to the
---order by clauses, and then recomputes which fields the result depends on.
function DatabaseQuery:_Execute()
	if not self._resultIsStale then
		return
	end
	assert(self._rootClause and self._currentClause == self._rootClause, "Did not end sub-clause")
	assert(self._iteratorState == "IDLE")
	assert(not next(self._iterDistinctUsed))
	-- clear the current result
	self:_WipeResults()
	-- get all the rows which we need to iterate over
	local firstOrderBy = self._orderBy[1]
	local skipFirstOrderBy = false
	local sortNeeded = firstOrderBy and true or false
	local indexType, indexField, indexArg1, indexArg2, indexArg3 = self:_GetQueryIndexInfo()
	self._result._queryOptimizationResult = indexType
	self._result._queryOptimizationField = indexField
	if indexType == "EMPTY" then
		sortNeeded = false
	elseif indexType == "UNIQUE" then
		-- we are looking for a unique row
		local indexValue = indexArg1
		local uuid = self._db:_GetUniqueRow(indexField, indexValue)
		if uuid and self:_ResultShouldIncludeRow(uuid, false, #self._joinDBs, self._distinct) then
			tinsert(self._result, uuid)
			self._resultRowLookup[uuid] = false
		end
		sortNeeded = false
	elseif indexType == "INDEX" then
		-- we're querying on an index, so use that index to populate the result
		local firstIndex, lastIndex, isStrict = indexArg1, indexArg2, indexArg3
		local isAscending = true
		if firstOrderBy and indexField == firstOrderBy then
			-- we're also ordering by this field so can skip the first OrderBy field
			self._result._queryOptimizationResult = "INDEX_AND_ORDER_BY"
			skipFirstOrderBy = true
			sortNeeded = #self._orderBy > 1
			isAscending = self._orderByAscending[1]
		end
		local indexList = self._db:_GetAllRowsByIndex(indexField)
		self:_AddResultRowsFromIndex(indexList, isStrict, firstIndex, lastIndex, isAscending, indexField)
	elseif indexType == "NONE" then
		if firstOrderBy and self._db:_IsIndex(firstOrderBy) then
			-- we're ordering on an index, so use that index to iterate through all the rows in order to skip the first OrderBy field
			self._result._queryOptimizationResult = "ORDER_BY"
			self._result._queryOptimizationField = firstOrderBy
			skipFirstOrderBy = true
			sortNeeded = #self._orderBy > 1
			local isAscending = self._orderByAscending[1]
			local indexList = self._db:_GetAllRowsByIndex(firstOrderBy)
			self:_AddResultRowsFromIndex(indexList, false, 1, #indexList, isAscending)
		else
			-- no optimizations
			self:_AddResultRowsCheckAll()
		end
	elseif indexType == "TRIGRAM" then
		local indexValue = indexArg1
		local uuids = TempTable.Acquire()
		self._db:_GetTrigramIndexMatchingRows(indexValue, uuids)
		self:_AddResultRowsFromIndex(uuids, false, 1, #uuids, true)
		TempTable.Release(uuids)
	else
		error("Invalid index type: "..tostring(indexType))
	end
	-- the distinct values seen are only needed while populating the result
	wipe(self._iterDistinctUsed)
	-- sort the results if necessary
	if sortNeeded then
		if #self._orderBy == 1 then
			assert(not skipFirstOrderBy)
			assert(not next(self._sortValueCache))
			for _, uuid in ipairs(self._result) do
				self._sortValueCache[uuid] = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[1]))
			end
			Table.Sort(self._result, self._singleSortWrapper)
			wipe(self._sortValueCache)
		elseif skipFirstOrderBy and #self._orderBy == 2 then
			-- the result is already ordered by the first orderBy field, so iterate through it
			-- and sort each group of results where the first orderBy field is the same
			assert(not next(self._sortValueCache))
			local group = TempTable.Acquire()
			local subsetLen = 0
			local currentSortValue = nil
			for i = 1, #self._result do
				local uuid = self._result[i]
				local sortValue = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[1]))
				self._sortValueCache[uuid] = Util.ToIndexValue(self:_GetResultRowData(uuid, self._orderBy[2]))
				if sortValue ~= currentSortValue then
					-- the first sort value changed, so we're now in a new group
					if subsetLen > 1 then
						-- sort the previous group (which occupies indexes i - subsetLen through i - 1)
						Table.Sort(group, self._secondarySortWrapper)
						-- update the corresponding results
						local offset = i - subsetLen - 1
						for j = 1, subsetLen do
							self._result[offset + j] = group[j]
						end
					end
					subsetLen = 0
					wipe(group)
					currentSortValue = sortValue
				end
				subsetLen = subsetLen + 1
				group[subsetLen] = uuid
			end
			if subsetLen > 1 then
				-- sort the last group
				Table.Sort(group, self._secondarySortWrapper)
				-- update the corresponding results
				local offset = #self._result - subsetLen
				for i = 1, subsetLen do
					self._result[offset + i] = group[i]
				end
			end
			TempTable.Release(group)
			wipe(self._sortValueCache)
		else
			Table.Sort(self._result, self._genericSortWrapper)
		end
	end
	-- update the dependencies
	self:_UpdateResultDependencies()
	self._resultIsStale = false
end

---Rebuilds `self._resultDependencies`, the set of fields whose changes invalidate the current result.
function DatabaseQuery:_UpdateResultDependencies()
	local dependencies = self._resultDependencies
	wipe(dependencies)
	if next(self._virtualFieldFunc) then
		-- virtual fields can depend on anything, so any change invalidates the result
		dependencies._all = true
		return
	end
	for i = 1, #self._joinFields do
		dependencies[self._joinFields[i]] = true
	end
	for i = 1, #self._orderBy do
		dependencies[self._orderBy[i]] = true
	end
	if self._distinct then
		dependencies[self._distinct] = true
	end
	for i = 1, #self._select do
		dependencies[self._select[i]] = true
	end
	for field in self._db:FieldIterator() do
		if self._rootClause:_UsesField(field) then
			dependencies[field] = true
		end
	end
	-- filters can also reference fields of joined DBs; changes to those must invalidate the result as well
	for _, joinDB in ipairs(self._joinDBs) do
		for field in joinDB:FieldIterator() do
			if self._rootClause:_UsesField(field) then
				dependencies[field] = true
			end
		end
	end
end
---Determines how to most efficiently find the rows matching the query's filters.
---@return string @"UNIQUE", "EMPTY", "INDEX", "TRIGRAM", or "NONE"
---@return string? @The index field (not returned for "NONE")
---@return any ... @"UNIQUE": the unique value; "INDEX": first index, last index, isStrict; "TRIGRAM": the search value
function DatabaseQuery:_GetQueryIndexInfo()
	-- try to find the index with the least result rows
	local indexField, indexFirstIndex, indexLastIndex, indexIsStrict = nil, nil, nil, false
	local bestIndexDiff = math.huge
	for _, field in ipairs(self._db:_GetIndexAndUniqueList()) do
		local valueMin, valueMax = self:_IndexValueHelper(field)
		if valueMin == nil and valueMax == nil then
			-- continue
		elseif self._db:_IsUnique(field) and valueMin == valueMax then
			-- unique indexes result in a single row, at which point the benefit of trying to find something better (EMPTY) is negligible
			return "UNIQUE", field, valueMin
		elseif self._db:_IsIndex(field) then
			local indexList = self._db:_GetAllRowsByIndex(field)
			if #indexList == 0 then
				-- there are no rows to match at all; previously a max-only bound could produce an "INDEX" plan over
				-- index 0 (which holds no UUID)
				return "EMPTY", field
			end
			-- check how many rows this index results in
			local firstIndex = valueMin and self._db:_IndexListBinarySearch(field, valueMin, true) or 1
			local lastIndex = valueMax and self._db:_IndexListBinarySearch(field, valueMax, false) or #indexList
			local indexDiff = lastIndex - firstIndex
			if indexDiff < 0 then
				-- there are no results within this index, so this is as good as it gets
				return "EMPTY", field
			else
				-- NOTE: string indexes can't be strict since they are case-insensitive
				local isStrict = type(valueMin) ~= "string" and type(valueMax) ~= "string" and self._rootClause:_IsStrictIndex(field, valueMin, valueMax)
				if isStrict then
					-- rough estimate that being able to skip the query makes each row cost 1/4 as much
					indexDiff = floor(indexDiff / 4)
				end
				if indexDiff < bestIndexDiff then
					-- this is our new best index
					indexField = field
					indexFirstIndex = firstIndex
					indexLastIndex = lastIndex
					indexIsStrict = isStrict
					bestIndexDiff = indexDiff
				end
			end
		end
	end
	if indexField then
		return "INDEX", indexField, indexFirstIndex, indexLastIndex, indexIsStrict
	end
	-- try the trigram index
	local trigramIndexField = self._db:_GetTrigramIndexField()
	if trigramIndexField then
		local trigramIndexValue = self._rootClause:_GetTrigramIndexValue(trigramIndexField)
		if trigramIndexValue then
			return "TRIGRAM", trigramIndexField, trigramIndexValue
		end
	end
	return "NONE"
end
---Appends the matching rows from `indexList[firstIndex..lastIndex]` to the result, walking the range forward when
---`isAscending` is truthy and backward otherwise.
---@param skipQuery boolean Whether the index range already satisfies the filters (no clause evaluation needed)
---@param indexField? string A field whose filter the index already satisfies (passed as the field to ignore)
function DatabaseQuery:_AddResultRowsFromIndex(indexList, skipQuery, firstIndex, lastIndex, isAscending, indexField)
	local numJoinDBs = #self._joinDBs
	local distinct = self._distinct
	local result = self._result
	local resultRowLookup = self._resultRowLookup
	local nextIndex = #result + 1
	-- fast path where there's no further filtering so we add all rows
	local includeAll = skipQuery and numJoinDBs == 0 and not distinct
	local fromIndex, toIndex, step = firstIndex, lastIndex, 1
	if not isAscending then
		fromIndex, toIndex, step = lastIndex, firstIndex, -1
	end
	for i = fromIndex, toIndex, step do
		local uuid = indexList[i]
		if includeAll or self:_ResultShouldIncludeRow(uuid, skipQuery, numJoinDBs, distinct, indexField) then
			result[nextIndex] = uuid
			nextIndex = nextIndex + 1
			resultRowLookup[uuid] = false
		end
	end
end
---Appends every row of the local DB which passes the joins, filters, and distinct clause to the result.
function DatabaseQuery:_AddResultRowsCheckAll()
	local numJoinDBs, distinct = #self._joinDBs, self._distinct
	local result, lookup = self._result, self._resultRowLookup
	local count = #result
	for _, uuid in self._db:_UUIDIterator() do
		if self:_ResultShouldIncludeRow(uuid, false, numJoinDBs, distinct) then
			count = count + 1
			result[count] = uuid
			lookup[uuid] = false
		end
	end
end
---Checks whether a row belongs in the result.
---
---The row must have a matching foreign row for every INNER join and must pass the filters (unless `skipQuery`).
---When `distinct` is set, its value must not have been seen yet; seen values are recorded as a side effect.
function DatabaseQuery:_ResultShouldIncludeRow(uuid, skipQuery, numJoinDBs, distinct, ignoreField)
	for i = 1, numJoinDBs do
		if self._joinTypes[i] == "INNER" then
			local localValue = self:_GetResultRowData(uuid, self._joinFields[i])
			if not self._joinDBs[i]:_GetUniqueRow(self._joinForeignFields[i], localValue) then
				return false
			end
		end
	end
	if not skipQuery then
		local tempRow = self._tempResultRow
		tempRow:_SetUUID(uuid)
		if not self._rootClause:_IsTrue(tempRow, ignoreField) then
			return false
		end
	end
	if not distinct then
		return true
	end
	local distinctValue = self:_GetResultRowData(uuid, distinct)
	local seen = self._iterDistinctUsed
	if seen[distinctValue] then
		return false
	end
	seen[distinctValue] = true
	return true
end
---Creates and caches the result row object for a UUID which is in the result but has no row object yet.
function DatabaseQuery:_CreateResultRow(uuid)
	local lookup = self._resultRowLookup
	assert(lookup[uuid] == false)
	local newRow = QueryResultRow.Get()
	newRow:_Acquire(self._db, self)
	newRow:_SetUUID(uuid)
	lookup[uuid] = newRow
	return newRow
end
---Gets the min/max index values implied by the filters for the given field part(s).
---
---For multiple parts, each part must have both a min and a max value, and the per-part values are joined with
---`Constants.DB_INDEX_VALUE_SEP`. Returns nothing if no usable bounds exist.
function DatabaseQuery:_IndexValueHelper(...)
	local numParts = select("#", ...)
	local minValue, maxValue = nil, nil
	for partIndex = 1, numParts do
		local part = select(partIndex, ...)
		local partMin, partMax = self._rootClause:_GetIndexValue(part)
		if partMin == nil and partMax == nil then
			return
		elseif numParts > 1 and (partMin == nil or partMax == nil) then
			-- only use multi-field indexes if there's both a min and max value
			return
		end
		if partIndex == 1 then
			minValue, maxValue = partMin, partMax
		else
			minValue = minValue .. Constants.DB_INDEX_VALUE_SEP .. partMin
			maxValue = maxValue .. Constants.DB_INDEX_VALUE_SEP .. partMax
		end
	end
	return minValue, maxValue
end
---Releases the query, then returns all of its arguments unchanged (lets callers release in a tail call).
function DatabaseQuery:_PassThroughReleaseHelper(...)
	self:Release()
	return ...
end
---Gets the value of a field for a result row.
---
---The field may be a virtual field, a field of the local DB, or a field of one of the joined DBs (including the
---computed field of an AGGREGATE_SUM join).
---@param uuid number The UUID of the local row
---@param field string The field name
---@return any @The field value (nil if there's no matching foreign row for a joined field)
function DatabaseQuery:_GetResultRowData(uuid, field)
	if self._virtualFieldFunc[field] then
		local argField = self._virtualFieldArgField[field]
		local argValue = nil
		if argField then
			argValue = self:_GetResultRowData(uuid, argField)
		else
			-- no arg field, so pass a (reused) result row for this UUID
			if not self._tempVirtualResultRow then
				self._tempVirtualResultRow = QueryResultRow.Get()
				self._tempVirtualResultRow:_Acquire(self._db, self)
			end
			self._tempVirtualResultRow:_SetUUID(uuid)
			argValue = self._tempVirtualResultRow
		end
		local value = self._virtualFieldFunc[field](argValue)
		if value == nil then
			value = self._virtualFieldDefault[field]
		end
		if type(value) ~= self._virtualFieldType[field] then
			error(format("Virtual field value not the correct type (%s, %s, %s)", tostring(argValue), tostring(value), field))
		end
		return value
	elseif #self._joinDBs == 0 or self._db:_GetFieldType(field) then
		-- this is a local field
		return self._db:GetRowFieldByUUID(uuid, field)
	else
		-- this is a foreign field
		local joinDB, joinField, joinForeignField, joinType, aggregateJoinField, aggregateJoinQuery = nil, nil, nil, nil, nil, nil
		for i = 1, #self._joinDBs do
			local testDB = self._joinDBs[i]
			local testAggregateJoinField = self._aggregateJoinQueries[i] and self._joinForeignFields[i] or nil
			if field == testAggregateJoinField or (not testAggregateJoinField and testDB:_GetFieldType(field)) then
				if joinDB then
					error("Multiple joined DBs have this field", 2)
				end
				joinDB = testDB
				joinField = self._joinFields[i]
				joinForeignField = self._joinForeignFields[i]
				joinType = self._joinTypes[i]
				aggregateJoinField = testAggregateJoinField
				aggregateJoinQuery = self._aggregateJoinQueries[i]
			end
		end
		if not joinDB then
			error("Invalid field: "..tostring(field), 2)
		end
		if joinType == "AGGREGATE_SUM" then
			if not aggregateJoinField or not aggregateJoinQuery then
				-- NOTE: Lua concatenates with `..`; `+` would attempt arithmetic on these strings and mask this message
				error("Missing aggregate join context: "..tostring(aggregateJoinField)..", "..tostring(aggregateJoinQuery))
			end
			aggregateJoinQuery:BindParams(self:_GetResultRowData(uuid, joinField))
			return aggregateJoinQuery:Sum(aggregateJoinField)
		elseif joinType == "INNER" or joinType == "LEFT" then
			if aggregateJoinField or aggregateJoinQuery then
				error("Unexpected aggregate join context: "..tostring(aggregateJoinField)..", "..tostring(aggregateJoinQuery))
			end
			local foreignUUID = joinDB:_GetUniqueRow(joinForeignField, self:_GetResultRowData(uuid, joinField))
			if foreignUUID then
				return joinDB:GetRowFieldByUUID(foreignUUID, field)
			end
		else
			error("Unknown join type: "..tostring(joinType))
		end
	end
end
---Validates and registers a join with another DB.
---@param joinType string The join type ("INNER", "LEFT", or an "AGGREGATE_" type)
---@param db DatabaseTable The DB to join with
---@param field string The local (or virtual) field to join on
---@param foreignField string The field within `db` to join on
---@param aggregateQuery? DatabaseQuery The query used for aggregate joins (aggregate join types only)
function DatabaseQuery:_JoinHelper(joinType, db, field, foreignField, aggregateQuery)
	assert(type(field) == "string" and type(foreignField) == "string")
	local localDB = self._db
	local localFieldType = self._virtualFieldType[field] or localDB:_GetFieldType(field)
	local foreignFieldType = db:_GetFieldType(foreignField)
	assert(localFieldType, "Local field doesn't exist: "..field)
	assert(foreignFieldType, "Foreign field doesn't exist: "..foreignField)
	assert(not Table.KeyByValue(self._joinDBs, db), "Already joining with this DB")
	assert(self._iteratorState == "IDLE")

	if not aggregateQuery then
		-- regular join: the foreign field must be unique and have the same type as the local field
		assert(localFieldType == foreignFieldType, format("Field types don't match (%s, %s, %s, %s)", field, tostring(localFieldType), foreignField, tostring(foreignFieldType)))
		assert(db:_IsUnique(foreignField), "Field must be unique in foreign DB")
		assert(not strmatch(joinType, "^AGGREGATE_"))
		assert(not aggregateQuery)
		-- apart from the join fields, foreign fields must not also exist in the local DB
		for dbField in db:FieldIterator() do
			if dbField ~= field and dbField ~= foreignField then
				assert(not localDB:_GetFieldType(dbField), "Foreign field conflicts with local DB: "..tostring(dbField))
			end
		end
		-- and virtual fields (other than the join field) must not exist in the foreign DB
		for virtualField in pairs(self._virtualFieldFunc) do
			if virtualField ~= field then
				assert(not db:_GetFieldType(virtualField), "Virtual field conflicts with foreign DB: "..tostring(virtualField))
			end
		end
	else
		assert(aggregateQuery.__class == DatabaseQuery)
		assert(strmatch(joinType, "^AGGREGATE_"))
		assert(not localDB:_GetFieldType(foreignField), "Local DB contains aggregate field: "..tostring(foreignField))
		assert(db:_GetFieldType(foreignField), "Foreign DB does not contains aggregate field: "..tostring(foreignField))
	end

	-- register with the foreign DB and record the join in the parallel join arrays
	db:_RegisterQuery(self)
	tinsert(self._joinTypes, joinType)
	tinsert(self._joinDBs, db)
	tinsert(self._joinFields, field)
	tinsert(self._joinForeignFields, foreignField)
	tinsert(self._aggregateJoinQueries, aggregateQuery or false)
	self._resultIsStale = true
end
---Gets a reader for a smart map that is owned by this query.
---An existing reader for the same map that no query currently owns is reused; otherwise a new one is created.
function DatabaseQuery:_GetSmartMapReader(map)
	local contexts = private.smartMapReaderContext
	for reader, context in pairs(contexts) do
		if context.query == nil and context.map == map then
			context.query = self
			return reader
		end
	end
	local newReader = map:CreateReader(private.HandleSmartMapUpdate)
	contexts[newReader] = { map = map, query = self }
	return newReader
end
3 years ago
-- ============================================================================
-- Private Helper Functions
-- ============================================================================
---Sort comparator for a single order by field, using the values precomputed in `self._sortValueCache`.
---nil values sort last, and ties are broken by UUID (descending) to keep the sort deterministic.
function private.DatabaseQuerySortSingle(self, aUUID, bUUID, isAscending)
	local cache = self._sortValueCache
	local a, b = cache[aUUID], cache[bUUID]
	if a == b then
		return aUUID > bUUID
	end
	if a == nil then
		return false
	end
	if b == nil then
		return true
	end
	if isAscending then
		return a < b
	end
	return a > b
end
---Sort comparator over all order by fields in turn. nil values sort last, and full ties are broken by UUID
---(descending) to keep the sort deterministic.
function private.DatabaseQuerySortGeneric(self, aUUID, bUUID)
	for i, orderByField in ipairs(self._orderBy) do
		local a = Util.ToIndexValue(self:_GetResultRowData(aUUID, orderByField))
		local b = Util.ToIndexValue(self:_GetResultRowData(bUUID, orderByField))
		if a ~= b then
			if a == nil then
				return false
			elseif b == nil then
				return true
			elseif self._orderByAscending[i] then
				return a < b
			end
			return a > b
		end
	end
	return aUUID > bUUID
end
---Iterator step which yields `index, uuid` for each result. At the end it resets the iterator state and releases
---the query if it was set to auto-release.
function private.QueryResultAsUUIDIterator(self, index)
	local nextIndex = index + 1
	local uuid = self._result[nextIndex]
	if uuid then
		return nextIndex, uuid
	end
	assert(self._iteratorState == "IN_PROGRESS")
	self._iteratorState = "IDLE"
	if self._autoRelease then
		self:Release()
	end
end
---Iterator step over the query results.
---
---Yields `index, row` when nothing is selected, otherwise `index` followed by the selected field values. It ends
---early if the iteration was aborted. At the end it resets the iterator state and releases the query if it was set
---to auto-release.
function private.QueryResultIterator(self, index)
	index = index + 1
	local iteratorState = self._iteratorState
	local uuid = nil
	if iteratorState == "ABORTED" then
		uuid = nil
	elseif iteratorState == "IN_PROGRESS" or iteratorState == "IN_PROGRESS_CAN_ABORT" then
		uuid = self._result[index]
	else
		error("Invalid iteratorState: "..tostring(iteratorState))
	end
	if not uuid then
		assert(iteratorState == "IN_PROGRESS" or iteratorState == "IN_PROGRESS_CAN_ABORT" or iteratorState == "ABORTED")
		self._iteratorState = "IDLE"
		if self._autoRelease then
			self:Release()
		end
		return
	end

	local selectFields = self._select
	local numSelectFields = #selectFields
	if numSelectFields == 0 or #self._joinDBs ~= 0 or numSelectFields > 5 then
		-- go through a (cached) result row object
		local row = self._resultRowLookup[uuid] or self:_CreateResultRow(uuid)
		if numSelectFields == 0 then
			return index, row
		end
		return index, row:GetFields(unpack(selectFields))
	end

	-- as an optimization, read up to 5 local fields directly without creating a result row
	if numSelectFields == 1 then
		return index, self:_GetResultRowData(uuid, selectFields[1])
	elseif numSelectFields == 2 then
		return index, self:_GetResultRowData(uuid, selectFields[1]), self:_GetResultRowData(uuid, selectFields[2])
	elseif numSelectFields == 3 then
		return index, self:_GetResultRowData(uuid, selectFields[1]), self:_GetResultRowData(uuid, selectFields[2]), self:_GetResultRowData(uuid, selectFields[3])
	elseif numSelectFields == 4 then
		return index, self:_GetResultRowData(uuid, selectFields[1]), self:_GetResultRowData(uuid, selectFields[2]), self:_GetResultRowData(uuid, selectFields[3]), self:_GetResultRowData(uuid, selectFields[4])
	elseif numSelectFields == 5 then
		return index, self:_GetResultRowData(uuid, selectFields[1]), self:_GetResultRowData(uuid, selectFields[2]), self:_GetResultRowData(uuid, selectFields[3]), self:_GetResultRowData(uuid, selectFields[4]), self:_GetResultRowData(uuid, selectFields[5])
	else
		error("Invalid numSelectFields: "..tostring(numSelectFields))
	end
end
---Handles a change in a smart map which backs one of a query's virtual fields.
---
---It marks the owning query's result stale for that virtual field and then fires the query's update callback.
---Readers which no query currently owns are ignored.
---@param reader table The smart map reader (used as the virtual field function)
---@param pendingChanges table The pending changes (unused here)
function private.HandleSmartMapUpdate(reader, pendingChanges)
	local self = private.smartMapReaderContext[reader].query
	if not self then
		return
	end
	-- find the virtual field which is backed by this reader
	local updatedField = nil
	for field, func in pairs(self._virtualFieldFunc) do
		if func == reader then
			updatedField = field
			break
		end
	end
	assert(updatedField ~= nil)
	-- _MarkResultStale() iterates the keys of changedFields as field names, so pass a set rather than an array
	local updateFields = TempTable.Acquire()
	updateFields[updatedField] = true
	self:_MarkResultStale(updateFields)
	TempTable.Release(updateFields)
	self:_DoUpdateCallback()
end
---Iterator step over an encoded UUID diff.
---
---`context.result` is a flat list of entries of the form `action, startIndex, num, uuid1 .. uuidN`. Each step yields
---`nextIndex, action, startIndex, uuids`, where `uuids` is a reused table. When exhausted, the result is wiped and
---the context is marked as no longer in use.
function private.UUIDDiffIterator(context, index)
	assert(context.inUse)
	local uuids = context.uuids
	wipe(uuids)
	local entries = context.result
	if index > #entries then
		wipe(entries)
		context.inUse = false
		return
	end
	local action, startIndex, num = entries[index], entries[index + 1], entries[index + 2]
	local dataStart = index + 3
	assert(action == "INSERT" or action == "REMOVE")
	assert(startIndex > 0 and num > 0 and num <= #entries - dataStart + 1)
	for offset = 1, num do
		uuids[offset] = entries[dataStart + offset - 1]
	end
	return dataStart + num, action, startIndex, uuids
end