You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

365 lines
11 KiB

3 years ago
-- ------------------------------------------------------------------------------ --
-- TradeSkillMaster --
-- https://tradeskillmaster.com --
-- All Rights Reserved - Detailed license information included with addon. --
-- ------------------------------------------------------------------------------ --
-- The second value passed to this chunk (in WoW addons this is conventionally the addon's shared table - confirm).
local TSM = select(2, ...) ---@type TSM
-- Module object for this file; module-level functions (OnModuleLoad, Get) are attached to it below.
local QueryResultRow = TSM.Init("Util.DatabaseClasses.QueryResultRow") ---@class Util.DatabaseClasses.QueryResultRow
3 years ago
local Math = TSM.Include("Util.Math")
local TempTable = TSM.Include("Util.TempTable")
local ObjectPool = TSM.Include("Util.ObjectPool")
-- Module-private state.
local private = {
-- per-row bookkeeping tables, keyed by the row object itself (entries are added in private.CreateNew)
context = {},
-- pool which hands out row objects; assigned in the module-load callback
objectPool = nil,
}
-- ============================================================================
-- Module Loading
3 years ago
-- ============================================================================
-- Create the row object pool once the module loads; new rows are built by private.CreateNew.
QueryResultRow:OnModuleLoad(function()
	local pool = ObjectPool.New("DATABASE_QUERY_RESULT_ROWS", private.CreateNew, 2)
	private.objectPool = pool
end)
3 years ago
-- ============================================================================
-- Module Functions
-- ============================================================================
3 years ago
---Gets a database row object from the module's object pool.
---The row is handed out as-is; this function does not call `_Acquire()` on it, so the caller
---is expected to do so (as `Clone()` does) before using the row.
---@return DatabaseRow
function QueryResultRow.Get()
return private.objectPool:Get()
end
3 years ago
-- ============================================================================
-- DatabaseRow Meta Methods
-- ============================================================================
3 years ago
-- Method table for row objects. The row metatable's __index checks this table first, so any
-- name defined here takes precedence over a database field of the same name.
---@class DatabaseRow
local ROW_METHODS = {}
---Binds the row to a database (and optionally a query) before use.
---@param db table The database which owns the row
---@param query? table The query which produced the row (nil if the row is not tied to a query)
---@param newRowUUID? number If set, the row is treated as a new (not yet inserted) row with this UUID
function ROW_METHODS:_Acquire(db, query, newRowUUID)
	local ctx = private.context[self]
	local isNew = newRowUUID and true or false
	ctx.db, ctx.query, ctx.isNewRow = db, query, isNew
	if isNew then
		-- existing rows get their UUID later via _SetUUID()
		ctx.uuid = newRowUUID
	end
	-- let the database fill in which of its fields are list fields
	db:_GetListFields(ctx.listFields)
end
3 years ago
---Resets the row's context and cached values so the row can be reused.
---Asserts that the row has no pending (unsaved) changes.
function ROW_METHODS:_Release()
	local context = private.context[self]
	-- check the precondition before touching anything so a failed assertion leaves the row intact
	assert(not context.pendingChanges)
	context.db = nil
	context.query = nil
	context.isNewRow = nil
	context.uuid = nil
	-- listFields is (re)populated by _Acquire() via db:_GetListFields(); clear it here so a pooled
	-- row never carries the previous database's list-field set over to its next owner
	wipe(context.listFields)
	-- drop any field values which were cached on the row itself
	wipe(self)
end
3 years ago
-- ============================================================================
-- DatabaseRow Public Methods
-- ============================================================================
3 years ago
---Releases a database row.
---Resets the row's state (which asserts that there are no pending changes) and then hands the
---row back to the object pool. The row must not be used after this call.
function ROW_METHODS:Release()
-- reset first so the pool only ever receives a clean row
self:_Release()
private.objectPool:Recycle(self)
end
3 years ago
---Gets the row's UUID.
---Raises an assertion error if the row has no UUID assigned.
---@return number
function ROW_METHODS:GetUUID()
	return assert(private.context[self].uuid)
end
3 years ago
---Gets the query which owns the row.
---Raises an assertion error if the row is not tied to a query.
---@return DatabaseQuery
function ROW_METHODS:GetQuery()
	return assert(private.context[self].query)
end
3 years ago
---Gets a field from the row.
---Non-list fields are read by indexing the row (and are cached by the row's metatable). List
---fields are never cached: a pending (unsaved) value wins, otherwise the value is looked up
---through the owning query, or directly in the database if the row has no query.
---@param field string The name of the field
---@return any @The value of the field
function ROW_METHODS:GetField(field)
	local ctx = private.context[self]
	if not ctx.listFields[field] then
		return self[field]
	end
	local pending = ctx.pendingChanges
	if pending and pending[field] then
		-- a pending list value is stored as a table, so expand it into its entries
		return unpack(pending[field])
	end
	if ctx.query then
		-- use the query to lookup the result
		return ctx.query:_GetResultRowData(ctx.uuid, field)
	end
	-- we're not tied to a query so this should be a local DB field
	return ctx.db:GetRowFieldByUUID(ctx.uuid, field)
end
3 years ago
---Gets up to 10 fields at a time from the row.
---Each field is read by indexing the row, so list fields are not supported here (indexing a
---list field raises an error); use `GetField()` for those.
---@param ... string The name of the fields
---@return ... @The value of the fields
function ROW_METHODS:GetFields(...)
	local count = select("#", ...)
	if count > 10 then
		error("GetFields() only supports up to 10 fields")
	elseif count == 0 then
		return
	end
	-- deliberately unrolled so no intermediate table is needed to return the values
	local f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = ...
	if count == 1 then
		return self[f1]
	elseif count == 2 then
		return self[f1], self[f2]
	elseif count == 3 then
		return self[f1], self[f2], self[f3]
	elseif count == 4 then
		return self[f1], self[f2], self[f3], self[f4]
	elseif count == 5 then
		return self[f1], self[f2], self[f3], self[f4], self[f5]
	elseif count == 6 then
		return self[f1], self[f2], self[f3], self[f4], self[f5], self[f6]
	elseif count == 7 then
		return self[f1], self[f2], self[f3], self[f4], self[f5], self[f6], self[f7]
	elseif count == 8 then
		return self[f1], self[f2], self[f3], self[f4], self[f5], self[f6], self[f7], self[f8]
	elseif count == 9 then
		return self[f1], self[f2], self[f3], self[f4], self[f5], self[f6], self[f7], self[f8], self[f9]
	end
	return self[f1], self[f2], self[f3], self[f4], self[f5], self[f6], self[f7], self[f8], self[f9], self[f10]
end
3 years ago
---Calculates a hash of the specified fields from the row.
---The field values are folded into the hash in list order, each one seeded with the hash of
---the fields before it. Returns nil if `fields` is empty.
---@param fields string[] The fields to calculate the hash of
---@return number @The hash
function ROW_METHODS:CalculateHash(fields)
	local result = nil
	for _, name in ipairs(fields) do
		result = Math.CalculateHash(self[name], result)
	end
	return result
end
3 years ago
---Sets the value of a field in the row.
---The new value is only recorded as a pending change; nothing is written to the database until
---`Create()` or `Update()` is called. Setting a non-list field of an existing row back to its
---current value discards any pending change for that field.
---Raises an error if the field is a smart map field, doesn't exist, or the value has the wrong
---type (list fields require a gap-free array whose entries all have the list's entry type).
---@param field string The name of the field
---@param value any The value for the field
---@return DatabaseRow
function ROW_METHODS:SetField(field, value)
	local context = private.context[self]
	local listFieldType = context.db:_GetListFieldType(field)
	-- list fields and fields of new rows are never treated as unchanged
	local isSameValue = not listFieldType and not context.isNewRow and value == self[field]
	if isSameValue and not context.pendingChanges then
		-- setting to the same value, so ignore this call
		return self
	end
	if context.db:_IsSmartMapField(field) then
		error(format("Cannot set smart map field (%s)", tostring(field)), 3)
	end
	local fieldType = context.db:_GetFieldType(field)
	if not fieldType then
		error(format("Field %s doesn't exist", tostring(field)), 3)
	elseif listFieldType then
		-- validate the type before taking the length: `#` raises its own (unhelpful) error on
		-- numbers / booleans / nil and silently succeeds on strings
		if type(value) ~= "table" then
			error(format("Expected list value, got %s", type(value)), 2)
		end
		local len = #value
		for i, v in pairs(value) do
			-- every key must be an index within 1..len, which rules out hash keys and holes
			if type(i) ~= "number" or i < 1 or i > len then
				error("Invalid table index: "..tostring(i), 2)
			elseif type(v) ~= listFieldType then
				error(format("List (%s) entries should be of type %s, got %s", tostring(field), listFieldType, tostring(v)), 2)
			end
		end
	elseif fieldType ~= type(value) then
		error(format("Field %s should be a %s, got %s", tostring(field), tostring(fieldType), type(value)), 2)
	end
	if isSameValue then
		-- setting the field to its original value, so clear any pending change
		context.pendingChanges[field] = nil
		if not next(context.pendingChanges) then
			-- that was the last pending change, so the row is clean again
			TempTable.Release(context.pendingChanges)
			context.pendingChanges = nil
		end
	else
		context.pendingChanges = context.pendingChanges or TempTable.Acquire()
		context.pendingChanges[field] = value
	end
	return self
end
3 years ago
---Creates the row within the database.
---Requires a new row on which every database field has been set via `SetField()`. The row is
---handed to the database's `_InsertRow()` and should not be used by the caller afterwards.
---NOTE(review): this method does not call `Release()` itself - presumably `_InsertRow()` takes
---care of releasing the row; confirm against the database implementation.
function ROW_METHODS:Create()
	local ctx = private.context[self]
	local pending = ctx.pendingChanges
	assert(ctx.isNewRow and pending)
	-- make sure all the fields are set
	for field in ctx.db:FieldIterator() do
		assert(pending[field] ~= nil)
	end
	-- cache the non-list values on the row itself (list values are never cached)
	for field, value in pairs(pending) do
		if not ctx.listFields[field] then
			rawset(self, field, value)
		end
	end
	-- the row is no longer new / dirty by the time the database sees it
	ctx.pendingChanges = nil
	ctx.isNewRow = nil
	ctx.db:_InsertRow(self, pending)
	TempTable.Release(pending)
end
---Updates the row within the database.
---Applies all pending changes made via `SetField()`. Does nothing if there are no pending
---changes. Must not be called on a new row (use `Create()` instead).
---@return DatabaseRow @The row itself (also when there was nothing to update), to allow chaining
function ROW_METHODS:Update()
	local context = private.context[self]
	assert(not context.isNewRow)
	if not context.pendingChanges then
		-- nothing to do, but still return the row so chained calls keep working
		return self
	end
	-- apply all the pending changes
	local changeContext = TempTable.Acquire()
	for field, value in pairs(context.pendingChanges) do
		if context.listFields[field] then
			-- list fields aren't cached on the row, so pass along the new value
			changeContext[field] = value
		else
			-- pass along the previous value of the field
			changeContext[field] = self[field]
			-- cache this new value
			rawset(self, field, value)
		end
	end
	-- the row must be clean before the database is notified
	TempTable.Release(context.pendingChanges)
	context.pendingChanges = nil
	context.db:_UpdateRow(self, changeContext)
	TempTable.Release(changeContext)
	return self
end
---Either creates or updates the row within the database and releases it.
---A new row is only passed to `Create()` (no explicit `Release()` call is made here for that
---case); an existing row is updated and then released.
function ROW_METHODS:CreateOrUpdateAndRelease()
	if private.context[self].isNewRow then
		self:Create()
		return
	end
	self:Update()
	self:Release()
end
---Returns a clone of the database row.
---The clone refers to the same database and UUID but is not tied to the original's query, and
---starts with no cached values. The row being cloned must not be new or have pending changes.
---@return DatabaseRow
function ROW_METHODS:Clone()
	local ctx = private.context[self]
	assert(not (ctx.isNewRow or ctx.pendingChanges))
	local copy = QueryResultRow.Get()
	copy:_Acquire(ctx.db)
	copy:_SetUUID(ctx.uuid)
	return copy
end
-- ============================================================================
-- DatabaseRow Private Methods
-- ============================================================================
---Points the row at a different UUID.
---@param uuid number The UUID of the database row this object should represent
function ROW_METHODS:_SetUUID(uuid)
	private.context[self].uuid = uuid
	-- values cached on the row belong to the previous UUID, so drop them
	wipe(self)
end
-- ============================================================================
-- DatabaseRow Metatable
-- ============================================================================
3 years ago
-- Metatable shared by all row objects. Rows are read-only from the outside: reads go through
-- __index (which resolves methods first and then lazily loads + caches field values), and any
-- write through normal assignment is rejected by __newindex.
local ROW_MT = {
	-- getter
	__index = function(self, key)
		if key == nil then
			error("Attempt to get nil key")
		end
		-- methods take precedence over fields
		if ROW_METHODS[key] then
			return ROW_METHODS[key]
		end
		local context = private.context[self]
		if context.isNewRow then
			-- a new row has no stored values yet
			error("Getting value on a new row: "..tostring(key))
		elseif context.listFields[key] then
			-- list fields are only available through GetField()
			error("Cannot get list fields by indexing row")
		end
		local result = nil
		if context.query then
			-- use the query to lookup the result
			result = context.query:_GetResultRowData(context.uuid, key)
		else
			-- we're not tied to a query so this should be a local DB field
			if not context.db:_GetFieldType(key) then
				error("Invalid field: "..tostring(key), 2)
			end
			result = context.db:GetRowFieldByUUID(context.uuid, key)
		end
		-- cache the value (rawset bypasses __newindex) so later reads don't come back here
		if result ~= nil then
			rawset(self, key, result)
		end
		return result
	end,
	-- setter
	__newindex = function(self, key, value)
		error("Table is read-only", 2)
	end,
	-- two rows are equal if both have a UUID and the UUIDs match
	__eq = function(self, other)
		local uuid = private.context[self].uuid
		local uuidOther = private.context[other].uuid
		return uuid and uuidOther and uuid == uuidOther
	end,
	-- formats as "QueryResultRow:<address of the row's context table>:<uuid>"
	__tostring = function(self)
		local context = private.context[self]
		return "QueryResultRow:"..strmatch(tostring(context), "table:[^0-9a-fA-F]*([0-9a-fA-F]+)")..":"..self:GetUUID()
	end,
	-- hide the metatable from getmetatable()
	__metatable = false,
}
-- ============================================================================
-- Private Helper Functions
-- ============================================================================
---Creates a brand new row object (used as the object pool's factory function).
---The row starts out unbound; its bookkeeping lives in `private.context` rather than on the row
---itself, since the row table is reserved for cached field values.
---@return DatabaseRow
function private.CreateNew()
	local row = setmetatable({}, ROW_MT)
	private.context[row] = {
		db = nil,
		query = nil,
		isNewRow = nil,
		uuid = nil,
		-- filled in by the database when the row is acquired
		listFields = {},
	}
	return row
end