You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

699 lines
22 KiB

-- ------------------------------------------------------------------------------ --
-- TradeSkillMaster --
-- https://tradeskillmaster.com --
-- All Rights Reserved - Detailed license information included with addon. --
-- ------------------------------------------------------------------------------ --
local TSM = select(2, ...) ---@type TSM
local CodeGen = TSM.Init("Util.CustomStringClasses.CodeGen") ---@class Util.CustomStringClasses.CodeGen
local CustomStringCodeGen = TSM.Include("LibTSMClass").DefineClass("CustomStringCodeGen") ---@class CustomStringCodeGen
local Types = TSM.Include("Util.CustomStringClasses.Types")
local Math = TSM.Include("Util.Math")
local ItemString = TSM.Include("Util.ItemString")
local StringBuilder = TSM.Include("Util.StringBuilder")
-- Module-private state and helper functions
local private = {
instance = nil, -- the CustomStringCodeGen created on module load and used by CodeGen.Execute()
}
-- ============================================================================
-- Code Templates
-- ============================================================================
-- Each %(name)s placeholder below is filled in via StringBuilder:SetParam("name", ...).
-- The complete generated chunk: it returns a function of (itemString, helpers) which evaluates the custom string.
-- res starts out as INVALID, so any "break" in the generated code exits the "repeat ... until true" and returns INVALID.
local BASE_TEMPLATE =
[[return function(itemString, helpers)
local INVALID = helpers.INVALID
-- Locals
%(locals)s
-- Code
local res = INVALID
repeat
%(code)s
res = %(expression)s
until true
return res
end]]
-- Comparison functions (iflt, iflte, ifgt, ifgte, ifeq). The result local is declared as nil and only computed when
-- it is still unset, so identical subtrees which share the same result local are only evaluated once. A "break" in
-- any of the embedded code leaves the result as INVALID.
local COMPARISON_FUNCTION_TEMPLATE =
[[if not %(res)s then repeat
%(res)s = INVALID
%(childrenCode)s
if %(leftRes)s %(operator)s %(rightRes)s then
%(trueCode)s
%(res)s = %(trueRes)s
else
%(falseCode)s
%(res)s = %(falseRes)s
end
until true end]]
-- Variant for a collapsed chain of comparisons against the same left operand; %(childCode)s holds the extra
-- "elseif" branches built from COMPARISON_COLLAPSED_CHILD_TEMPLATE.
local COMPARISON_COLLAPSED_FUNCTION_TEMPLATE =
[[if not %(res)s then repeat
%(res)s = INVALID
%(childrenCode)s
if %(leftRes)s %(operator)s %(rightRes)s then
%(trueCode)s
%(res)s = %(trueRes)s
%(childCode)s
else
%(falseCode)s
%(res)s = %(falseRes)s
end
until true end]]
-- A single extra "elseif" branch of a collapsed comparison chain
local COMPARISON_COLLAPSED_CHILD_TEMPLATE =
[[ elseif %(leftRes)s %(operator)s %(rightRes)s then
%(trueCode)s
%(res)s = %(trueRes)s]]
-- first(): the per-argument code is placed inside this "repeat ... until true" so it can "break" out once a value is found
local FIRST_FUNCTION_TEMPLATE =
[[if not %(res)s then repeat
%(res)s = INVALID
%(code)s
until true end]]
-- Per-argument code for first(): take the argument and stop evaluating further arguments
local FIRST_FUNCTION_ARG_TEMPLATE =
[[%(res)s = %(arg)s
break]]
-- Appended after a first() argument which had to be wrapped in its own "repeat ... until true", since the "break"
-- above then only exits that wrapper
local FIRST_FUNCTION_ARG_SUFFIX_TEMPLATE =
[[if %(res)s ~= INVALID then break end]]
-- Per-argument code for min() / max(): %(operator)s is "<" for min and ">" for max; INVALID arguments never replace a
-- valid result
local MIN_MAX_FUNCTION_ARG_TEMPLATE =
[[if %(res)s == INVALID or (%(arg)s ~= INVALID and %(arg)s %(operator)s %(res)s) then
%(res)s = %(arg)s
end]]
-- min() / max(): the result stays INVALID if every argument is INVALID
local MIN_MAX_FUNCTION_TEMPLATE =
[[if not %(res)s then
%(res)s = INVALID
%(code)s
end]]
-- Per-argument code for avg(): accumulates the total and the count of arguments it is applied to
local AVG_FUNCTION_ARG_TEMPLATE =
[[%(total)s = %(total)s + %(arg)s
%(num)s = %(num)s + 1]]
-- avg(): the mean of the accumulated arguments, or INVALID if none were accumulated
local AVG_FUNCTION_TEMPLATE =
[[if not %(res)s then
local %(total)s, %(num)s = 0, 0
%(code)s
%(res)s = %(num)s > 0 and %(total)s / %(num)s or INVALID
end]]
-- round() / roundup() / rounddown(): %(func)s is floor or ceil and %(extraAdd)s adjusts the value before it is rounded
-- to a multiple of %(sig)s
local ROUND_FUNCTION_TEMPLATE =
[[if not %(res)s then repeat
%(res)s = INVALID
%(childrenCode)s
%(res)s = %(func)s(%(value)s / %(sig)s%(extraAdd)s) * %(sig)s
until true end]]
-- convert(): looks up the value via helpers.GetPrice() with the function name and the source name
local CONVERT_FUNCTION_TEMPLATE =
[[%(res)s = %(res)s or helpers.GetPrice(%(item)s, "%(value)s", "%(source)s")]]
-- A price source called with an explicit item argument
local SOURCE_WITH_ITEM_ARG_TEMPLATE =
[[%(res)s = %(res)s or helpers.GetPrice(%(item)s, "%(value)s")]]
-- ============================================================================
-- Module Loading
-- ============================================================================
---Creates the shared code generator instance which CodeGen.Execute() uses.
function private.OnModuleLoad()
	private.instance = CustomStringCodeGen()
end
CodeGen:OnModuleLoad(private.OnModuleLoad)
-- ============================================================================
-- Module Functions
-- ============================================================================
---Generates code for a given AST using the module's shared code generator.
---@param tree Tree The AST
---@return string|nil # The generated code or nil if there was an error
---@return EnumTypeValue|nil # The error type
---@return number|nil # The error token index
function CodeGen.Execute(tree)
	local generator = private.instance
	return generator:Generate(tree)
end
-- ============================================================================
-- CustomStringCodeGen Class Methods
-- ============================================================================
---Initializes a code generator with empty state.
function CustomStringCodeGen:__init()
	-- Per-node results, keyed by node and wiped at the end of every generation
	self._hash = {}
	self._statement = {}
	self._expression = {}
	self._localVar = {}
	-- Local variable declarations for the generated function (used as both a list and a set while building)
	self._locals = {}
	-- State of the generation in progress (only set while Generate() is running)
	self._tree = nil
	self._errType = nil
	self._errTokenIndex = nil
end
---Generates code for the given AST and then resets all per-generation state so the instance can be reused.
---@param tree Tree The AST
---@return string|nil # The generated code or nil if there was an error
---@return EnumTypeValue|nil # The error type
---@return number|nil # The error token index
function CustomStringCodeGen:Generate(tree)
	-- The previous generation must have fully cleaned up
	assert(not self._tree and not self._errType and not self._errTokenIndex)
	self._tree = tree
	-- NOTE(review): a Lua error raised during generation skips the cleanup below, leaving this instance unusable
	-- (the assert above then fails on every later call)
	local code = self:GenerateHelper()
	for _, cache in ipairs({ self._hash, self._statement, self._expression, self._localVar, self._locals }) do
		wipe(cache)
	end
	self._tree = nil
	local errType, errTokenIndex = self._errType, self._errTokenIndex
	self._errType, self._errTokenIndex = nil, nil
	-- Exactly one of the code or the error information must be present
	if not code then
		assert(errType and errTokenIndex)
	else
		assert(not errType and not errTokenIndex)
	end
	return code, errType, errTokenIndex
end
-- Final formatting applied, in order, to the generated statement to make the code cleaner: collapses
-- whitespace-only lines and removes blank lines after "then" / "else" / "repeat", before "elseif" and at the start.
-- Each entry is a gsub() pattern and its replacement.
local STATEMENT_CLEANUP_GSUBS = {
	{ "\n%s+\n", "\n\n" },
	{ " then\n\n", " then\n" },
	{ "\telse\n\n", "\telse\n" },
	{ "\trepeat\n\n", "\trepeat\n" },
	{ "\n(\n\t+elseif)", "%1" },
	{ "^\n+", "" },
}
---Generates the complete source code of the custom string function for the current tree.
---@return string|nil # The generated code, or nil if there was an error (recorded in _errType / _errTokenIndex)
function CustomStringCodeGen:GenerateHelper()
	local root = self._tree:GetRoot()
	if not self:GenerateCode(root) then
		return nil
	end
	-- Collect the unique result locals (the table doubles as a set to de-duplicate them)
	local locals = self._locals
	for _, varName in pairs(self._localVar) do
		if not locals[varName] then
			tinsert(locals, varName)
			locals[varName] = true
		end
	end
	if #locals > 190 then
		-- Presumably headroom under Lua's limit of 200 active locals per function, since the template and nested
		-- blocks declare additional locals
		self._errType = Types.ERROR.TOO_MANY_VARS
		self._errTokenIndex = -1
		return nil
	end
	-- Declare the locals in a deterministic (sorted) order
	sort(locals)
	for i, varName in ipairs(locals) do
		locals[i] = format("local %s = nil", varName)
	end
	local statement = self._statement[root] or ""
	for _, cleanup in ipairs(STATEMENT_CLEANUP_GSUBS) do
		statement = gsub(statement, cleanup[1], cleanup[2])
	end
	local localsCode = private.IndentCode(table.concat(locals, "\n"))
	return StringBuilder.Get(BASE_TEMPLATE)
		:SetParam("locals", localsCode)
		:SetParam("code", private.IndentCode(statement, 2))
		:SetParam("expression", self._expression[root])
		:Commit()
end
---Returns a hash of the subtree rooted at a node, computed from the node's type and value and (recursively) the
---hashes of its children. Identical subtrees therefore hash identically. The result is cached per node.
---@param node any The node
---@return any # The hash as returned by Math.CalculateHash()
function CustomStringCodeGen:GetHash(node)
	local cached = self._hash[node]
	if cached then
		return cached
	end
	local tree = self._tree
	local nodeType = tree:GetData(node, "type")
	local nodeValue = tree:GetData(node, "value")
	local typeHash = Math.CalculateHash(tostring(nodeType))
	local hash = Math.CalculateHash(tostring(nodeValue), typeHash)
	for child in tree:ChildrenIterator(node) do
		-- Only function nodes are expected to have children
		assert(nodeType == Types.NODE.FUNCTION)
		hash = Math.CalculateHash(self:GetHash(child), hash)
	end
	self._hash[node] = hash
	return hash
end
-- Function node values which are emitted inline as arithmetic expressions via GenerateMathOperation()
local MATH_OPERATOR_NODE_VALUES = {
	["+"] = true,
	["-"] = true,
	["*"] = true,
	["/"] = true,
	["^"] = true,
	["%"] = true,
}
-- Methods which generate the statement for function nodes; any function not listed here is assumed to be a source
-- with an item argument
local FUNCTION_STATEMENT_METHODS = {
	iflte = "GenerateComparisonStatement",
	iflt = "GenerateComparisonStatement",
	ifgte = "GenerateComparisonStatement",
	ifgt = "GenerateComparisonStatement",
	ifeq = "GenerateComparisonStatement",
	first = "GenerateVaragFunctionStatement",
	min = "GenerateVaragFunctionStatement",
	max = "GenerateVaragFunctionStatement",
	avg = "GenerateVaragFunctionStatement",
	round = "GenerateRoundFunctionStatement",
	roundup = "GenerateRoundFunctionStatement",
	rounddown = "GenerateRoundFunctionStatement",
	convert = "GenerateConvertFunctionStatement",
}
---Generates the code for a node (after first generating all of its children). This stores:
---  - the statement which must run before the node's value is available (nil if none),
---  - the expression which refers to the node's value,
---  - the result local which must be declared for the node (nil if none).
---@param node any The node
---@return boolean # True on success, false if there was an error (recorded via HandleError)
function CustomStringCodeGen:GenerateCode(node)
	local tree = self._tree
	local nodeType = tree:GetData(node, "type")
	local nodeValue = tree:GetData(node, "value")
	local statement, expression, localVar = nil, nil, nil
	if nodeType == Types.NODE.CONSTANT then
		expression = nodeValue
	elseif nodeType == Types.NODE.INVALID then
		expression = "INVALID"
	elseif nodeType == Types.NODE.VARIABLE then
		if nodeValue == "baseitem" then
			localVar = "baseItemString"
			statement = StringBuilder.Get("%(res)s = %(res)s or helpers.GetBaseItemString(itemString)")
				:SetParam("res", localVar)
				:Commit()
			expression = localVar
		elseif Types.IsItemStringParam(nodeValue) then
			local itemString = ItemString.Get(nodeValue)
			if not itemString then
				self:HandleError(Types.ERROR.INVALID_ITEM_STRING, node)
				return false
			end
			-- Item strings are embedded as string literals
			expression = "\""..itemString.."\""
		else
			localVar = "var_"..nodeValue
			statement = StringBuilder.Get("%(res)s = %(res)s or helpers.GetPrice(itemString, \"%(value)s\")")
				:SetParam("res", localVar)
				:SetParam("value", nodeValue)
				:Commit()
			expression = localVar
		end
	elseif nodeType == Types.NODE.FUNCTION then
		-- The children must be generated first since the parent's code is built from theirs
		for child in tree:ChildrenIterator(node) do
			if not self:GenerateCode(child) then
				return false
			end
		end
		if MATH_OPERATOR_NODE_VALUES[nodeValue] then
			statement, expression = self:GenerateMathOperation(node, nodeValue)
		else
			-- The result local is named after the subtree hash so identical subtrees share (and only compute) one local
			local resLocal = "res_"..nodeValue.."_"..self:GetHash(node)
			self._expression[node] = resLocal
			self._localVar[node] = resLocal
			local methodName = FUNCTION_STATEMENT_METHODS[nodeValue] or "GenerateSourceWithItemArgStatement"
			statement = self[methodName](self, node, nodeValue, resLocal)
			if not statement then
				return false
			end
			expression = resLocal
			localVar = resLocal
		end
	else
		error("Invalid node type: "..tostring(nodeType))
	end
	self._statement[node] = statement
	self._expression[node] = expression
	self._localVar[node] = localVar
	return true
end
---Generates the code for a binary math operator node (+, -, *, /, ^ or %), which is emitted inline as an expression.
---`%` is treated as a percentage, so `left % right` becomes `(left / 100 * right)`.
---@param node any The operator node
---@param nodeValue string The operator
---@return string # Code which must run before the expression is evaluated (may be empty)
---@return string # The expression
function CustomStringCodeGen:GenerateMathOperation(node, nodeValue)
	local tree = self._tree
	assert(tree:GetNumChildren(node) == 2)
	local leftNode, rightNode = tree:GetChildren(node)
	assert(self._expression[leftNode] and self._expression[rightNode])
	local statement = ""
	for _, operandNode in ipairs({ leftNode, rightNode }) do
		local operandStatement = self._statement[operandNode]
		if operandStatement then
			if self._localVar[operandNode] then
				-- Operands with a result local also get their INVALID check (which breaks out)
				operandStatement = (self:StatementHelper(operandNode))
			end
			statement = statement.."\n\n"..operandStatement
		end
	end
	local operator = nodeValue
	if nodeValue == "%" then
		operator = "/ 100 *"
	end
	local expression = StringBuilder.Get("(%(left)s %(operator)s %(right)s)")
		:SetParam("operator", operator)
		:SetParam("left", self._expression[leftNode])
		:SetParam("right", self._expression[rightNode])
		:Commit()
	return statement, expression
end
---Generates the statement for a comparison function node (iflt, iflte, ifgt, ifgte or ifeq).
---The children are `(left, right, trueValue[, falseValue])`. A longer child list is a collapsed chain of comparisons
---against the same left operand: `(left, right, trueValue, right2, trueValue2[, right3, trueValue3 ...][, falseValue])`,
---which is emitted as a single `if ... elseif ... else ... end`. The final false value is optional in both forms.
---@param node any The comparison function node
---@param nodeValue string The function name
---@param resLocal string The local variable which will hold the result
---@return string|nil # The generated statement, or nil if there was an error (recorded via HandleError)
function CustomStringCodeGen:GenerateComparisonStatement(node, nodeValue, resLocal)
	local numChildren = self._tree:GetNumChildren(node)
	if numChildren < 3 then
		self:HandleError(Types.ERROR.INVALID_NUM_ARGS, node)
		return nil
	end
	local leftNode, rightNode, trueNode, falseNode = self._tree:GetChildren(node)
	local leftCode, leftRes = self:StatementHelper(leftNode, resLocal.."_leftArg")
	local rightCode, rightRes = self:StatementHelper(rightNode, resLocal.."_rightArg")
	-- Optimize cases where a result value is equal to one of the comparison values by reusing that operand
	local leftHash, rightHash = self:GetHash(leftNode), self:GetHash(rightNode)
	local function ResultHelper(resultNode, argLocal)
		local resultHash = self:GetHash(resultNode)
		if resultHash == leftHash then
			return "-- Optimized to leftArg", leftRes
		elseif resultHash == rightHash then
			return "-- Optimized to rightArg", rightRes
		end
		return self:StatementHelper(resultNode, argLocal)
	end
	local trueCode, trueRes = ResultHelper(trueNode, resLocal.."_trueArg")
	local operator = nil
	if nodeValue == "iflt" then
		operator = "<"
	elseif nodeValue == "iflte" then
		operator = "<="
	elseif nodeValue == "ifgt" then
		operator = ">"
	elseif nodeValue == "ifgte" then
		operator = ">="
	elseif nodeValue == "ifeq" then
		operator = "=="
	else
		error("Invalid node value: "..tostring(nodeValue))
	end
	local comparisonChildrenCode = private.JoinLines("", leftCode, "", rightCode, "")
	local template, childCode = nil, nil
	-- Defaults for when no final false value is specified
	local falseCode, falseRes = "-- Not specified", "INVALID"
	if numChildren <= 4 then
		template = COMPARISON_FUNCTION_TEMPLATE
		if falseNode then
			falseCode, falseRes = ResultHelper(falseNode, resLocal.."_falseArg")
		end
	else
		-- Collapsed chain: starting at child 4, each (right, trueValue) pair becomes an "elseif" branch
		template = COMPARISON_COLLAPSED_FUNCTION_TEMPLATE
		childCode = ""
		local i = 4
		while true do
			if numChildren < i + 1 then
				-- Every collapsed branch needs both a comparison value and a result value
				self:HandleError(Types.ERROR.INVALID_NUM_ARGS, node)
				return nil
			end
			local childRightNode, childTrueNode, childFalseNode = select(i, self._tree:GetChildren(node))
			local childRightCode, childRightRes = self:StatementHelper(childRightNode, resLocal.."_childRightArg")
			-- Collapsed comparison values must not need any setup code (i.e. they are constants)
			assert(childRightCode == "")
			local childTrueCode, childTrueRes = self:StatementHelper(childTrueNode, resLocal.."_trueArg")
			childCode = childCode.."\n"..StringBuilder.Get(COMPARISON_COLLAPSED_CHILD_TEMPLATE)
				:SetParam("res", resLocal)
				:SetParam("operator", operator)
				:SetParam("leftRes", leftRes)
				:SetParam("rightRes", childRightRes)
				:SetParam("trueCode", private.IndentCode(childTrueCode))
				:SetParam("trueRes", childTrueRes)
				:Commit()
			i = i + 2
			if not childFalseNode then
				-- The chain ends without a final false value (keep the "Not specified" defaults)
				break
			elseif i >= numChildren then
				-- The last remaining child is the final false value
				falseCode, falseRes = self:StatementHelper(childFalseNode, resLocal.."_falseArg")
				break
			end
			-- Otherwise childFalseNode is the comparison value of the next collapsed branch
		end
	end
	local builder = StringBuilder.Get(template)
		:SetParam("res", resLocal)
		:SetParam("operator", operator)
		:SetParam("childrenCode", private.IndentCode(comparisonChildrenCode))
		:SetParam("leftRes", leftRes)
		:SetParam("rightRes", rightRes)
		:SetParam("trueCode", private.IndentCode(trueCode, 2))
		:SetParam("trueRes", trueRes)
		:SetParam("falseCode", private.IndentCode(falseCode, 2))
		:SetParam("falseRes", falseRes)
	if childCode then
		builder:SetParam("childCode", childCode)
	end
	return builder:Commit()
end
---Generates the statement for a variadic function node (first, min, max or avg).
---@param node any The function node
---@param functionName string The function name
---@param resLocal string The local variable which will hold the result
---@return string # The generated statement
function CustomStringCodeGen:GenerateVaragFunctionStatement(node, functionName, resLocal)
	assert(self._tree:GetNumChildren(node) > 0)
	local isMinMax = functionName == "min" or functionName == "max"
	local isAvg = functionName == "avg"
	local baseArgTemplate, functionTemplate, checkArg = nil, nil, nil
	local operator, totalLocal, numLocal = nil, nil, nil
	if isMinMax then
		-- The min / max argument template performs its own INVALID check
		baseArgTemplate = MIN_MAX_FUNCTION_ARG_TEMPLATE
		functionTemplate = MIN_MAX_FUNCTION_TEMPLATE
		operator = functionName == "min" and "<" or ">"
		checkArg = false
	elseif isAvg then
		baseArgTemplate = AVG_FUNCTION_ARG_TEMPLATE
		functionTemplate = AVG_FUNCTION_TEMPLATE
		totalLocal = resLocal.."_total"
		numLocal = resLocal.."_num"
		checkArg = true
	elseif functionName == "first" then
		baseArgTemplate = FIRST_FUNCTION_ARG_TEMPLATE
		functionTemplate = FIRST_FUNCTION_TEMPLATE
		checkArg = true
	else
		error("Invalid function name: "..tostring(functionName))
	end
	local childrenCode = ""
	local argIndex = 0
	for child in self._tree:ChildrenIterator(node) do
		argIndex = argIndex + 1
		local argTemplate, argLocal, didWrap = self:HandleFunctionArgument(child, baseArgTemplate, resLocal.."_arg_"..argIndex, checkArg)
		if didWrap and functionName == "first" then
			-- The wrapping "repeat ... until true" absorbs the "break", so re-check the result after it
			argTemplate = private.JoinLines(argTemplate, FIRST_FUNCTION_ARG_SUFFIX_TEMPLATE)
		end
		local argBuilder = StringBuilder.Get(argTemplate)
			:SetParam("res", resLocal)
			:SetParam("arg", argLocal)
		if isMinMax then
			argBuilder:SetParam("operator", operator)
		elseif isAvg then
			argBuilder:SetParam("total", totalLocal)
			argBuilder:SetParam("num", numLocal)
		end
		-- Separate each argument's code with a blank line
		childrenCode = childrenCode.."\n\n"..argBuilder:Commit()
	end
	local functionBuilder = StringBuilder.Get(functionTemplate)
		:SetParam("res", resLocal)
		:SetParam("code", private.IndentCode(childrenCode))
	if isAvg then
		functionBuilder:SetParam("total", totalLocal)
		functionBuilder:SetParam("num", numLocal)
	end
	return functionBuilder:Commit()
end
---Generates the statement for a rounding function node (round, roundup or rounddown).
---The children are the value and an optional significance; the value is rounded to a multiple of the
---significance, which defaults to 1.
---@param node any The function node
---@param nodeValue string The function name
---@param resLocal string The local variable which will hold the result
---@return string|nil # The generated statement, or nil if there was an error (recorded via HandleError)
function CustomStringCodeGen:GenerateRoundFunctionStatement(node, nodeValue, resLocal)
	local numChildren = self._tree:GetNumChildren(node)
	if numChildren < 1 or numChildren > 2 then
		self:HandleError(Types.ERROR.INVALID_NUM_ARGS, node)
		return nil
	end
	local valueNode, sigNode = self._tree:GetChildren(node)
	local valueCode, valueRes = self:StatementHelper(valueNode, resLocal.."_value")
	local sigCode, sigRes = "", "1"
	if sigNode then
		sigCode, sigRes = self:StatementHelper(sigNode, resLocal.."_sig")
	end
	local func, extraAdd = nil, nil
	if nodeValue == "round" then
		func, extraAdd = "floor", " + 0.5"
	elseif nodeValue == "rounddown" then
		-- Add a bit in order to account for floating point inaccuracies
		func, extraAdd = "floor", " + 0.001 / "..sigRes
	elseif nodeValue == "roundup" then
		-- Subtract a bit in order to account for floating point inaccuracies
		func, extraAdd = "ceil", " - 0.001 / "..sigRes
	else
		error("Invalid node value: "..tostring(nodeValue))
	end
	local childrenCode = private.IndentCode(private.JoinLines("", valueCode, "", sigCode, ""))
	return StringBuilder.Get(ROUND_FUNCTION_TEMPLATE)
		:SetParam("res", resLocal)
		:SetParam("childrenCode", childrenCode)
		:SetParam("func", func)
		:SetParam("value", valueRes)
		:SetParam("sig", sigRes)
		:SetParam("extraAdd", extraAdd)
		:Commit()
end
---Generates the statement for a convert() function node.
---Convert must have a source argument and may optionally have a "baseitem" or itemString parameter.
---@param node any The convert function node
---@param nodeValue string The function name (passed through to helpers.GetPrice())
---@param resLocal string The local variable which will hold the result
---@return string|nil # The generated statement, or nil if there was an error (recorded via HandleError)
function CustomStringCodeGen:GenerateConvertFunctionStatement(node, nodeValue, resLocal)
	local tree = self._tree
	local numChildren = tree:GetNumChildren(node)
	if numChildren == 0 or numChildren > 2 then
		self:HandleError(Types.ERROR.INVALID_NUM_ARGS, node)
		return nil
	end
	local sourceNode, itemNode = tree:GetChildren(node)
	if tree:GetData(sourceNode, "type") ~= Types.NODE.VARIABLE then
		-- Report a non-source first argument as an error instead of raising, since a raised error would skip the
		-- state cleanup in Generate()
		self:HandleError(Types.ERROR.INVALID_TOKEN, sourceNode)
		return nil
	end
	local sourceValue = tree:GetData(sourceNode, "value")
	local code, item = "", "itemString"
	if itemNode then
		code = self:ItemArgStatementHelper(itemNode)
		if not code then
			-- The invalid item argument was already recorded by ItemArgStatementHelper()
			return nil
		end
		item = self._expression[itemNode]
	end
	return code..StringBuilder.Get(CONVERT_FUNCTION_TEMPLATE)
		:SetParam("res", resLocal)
		:SetParam("item", item)
		:SetParam("source", sourceValue)
		:SetParam("value", nodeValue)
		:Commit()
end
---Generates the statement for a price source called with a single item argument.
---@param node any The function node
---@param nodeValue string The price source name
---@param resLocal string The local variable which will hold the result
---@return string|nil # The generated statement, or nil if there was an error (recorded via HandleError)
function CustomStringCodeGen:GenerateSourceWithItemArgStatement(node, nodeValue, resLocal)
	if self._tree:GetNumChildren(node) ~= 1 then
		self:HandleError(Types.ERROR.INVALID_NUM_ARGS, node)
		return nil
	end
	local itemNode = self._tree:GetChildren(node)
	local itemCode = self:ItemArgStatementHelper(itemNode)
	if not itemCode then
		return nil
	end
	local priceCode = StringBuilder.Get(SOURCE_WITH_ITEM_ARG_TEMPLATE)
		:SetParam("res", resLocal)
		:SetParam("item", self._expression[itemNode])
		:SetParam("value", nodeValue)
		:Commit()
	return itemCode..priceCode
end
---Validates an item argument node and returns the code needed before its expression can be used.
---The node must be a variable for which Types.IsItemParam() is true.
---@param node any The item argument node
---@return string|nil # The node's statement followed by a newline (or "" if it has none), or nil on error
function CustomStringCodeGen:ItemArgStatementHelper(node)
	local tree = self._tree
	local isValid = tree:GetData(node, "type") == Types.NODE.VARIABLE and Types.IsItemParam(tree:GetData(node, "value"))
	if not isValid then
		self:HandleError(Types.ERROR.INVALID_TOKEN, node)
		return nil
	end
	local statement = self._statement[node]
	return statement and statement.."\n" or ""
end
---Returns the code needed to use a node as an operand, along with the expression which refers to its value.
---Constants need no code. Nodes with a result local emit their statement and, unless isOptional is set, a "break"
---when the result is INVALID. Any other node has its expression stored into a new local named argLocal.
---@param node any The node
---@param argLocal string|nil The local to declare for nodes which don't have a result local of their own
---@param isOptional boolean|nil If set, skip the INVALID check for nodes which have a result local
---@return string # The code to emit before the value is used (may be empty)
---@return any # The expression which refers to the node's value
function CustomStringCodeGen:StatementHelper(node, argLocal, isOptional)
	if self._tree:GetData(node, "type") == Types.NODE.CONSTANT then
		-- Constants are always valid
		return "", self._expression[node]
	end
	local nodeStatement = self._statement[node]
	local localVar = self._localVar[node]
	if localVar then
		local code = nodeStatement or ""
		if not isOptional then
			code = code.."\n"..StringBuilder.Get("if %(arg)s == INVALID then break end")
				:SetParam("arg", localVar)
				:Commit()
		end
		return code, localVar
	end
	local code = nodeStatement and nodeStatement.."\n" or ""
	code = code.."\n"..StringBuilder.Get("local %(arg)s = %(expression)s")
		:SetParam("arg", argLocal)
		:SetParam("expression", self._expression[node])
		:Commit()
	return code, argLocal
end
---Builds the code template for a single argument of a variadic function.
---@param node any The argument node
---@param template string The per-argument template
---@param argLocal string The local to use for the argument if it needs one
---@param checkArg boolean Whether the template should only apply when the argument is not INVALID
---@return string # The template (still containing its placeholders) including the argument's code
---@return any # The expression to substitute for %(arg)s
---@return boolean # Whether the template was wrapped in "repeat ... until true" because the argument's code can break
function CustomStringCodeGen:HandleFunctionArgument(node, template, argLocal, checkArg)
	if self._tree:GetData(node, "type") == Types.NODE.CONSTANT then
		-- Constants are always valid and are substituted directly
		return template, self._expression[node], false
	end
	local argTemplate = template
	if checkArg then
		argTemplate = private.JoinLines("if %(arg)s ~= INVALID then", private.IndentCode(argTemplate), "end")
	end
	local argCode, resolvedLocal = self:StatementHelper(node, argLocal, true)
	argTemplate = private.JoinLines(argCode, argTemplate)
	if not private.CodeCanBreak(argCode) then
		return argTemplate, resolvedLocal, false
	end
	-- Contain the argument code's "break" so it can't exit the enclosing function code
	return private.JoinLines("repeat", private.IndentCode(argTemplate), "until true"), resolvedLocal, true
end
---Records an error along with the token index of the node which caused it.
---@param errType EnumTypeValue The error type
---@param node any The offending node
function CustomStringCodeGen:HandleError(errType, node)
	local tokenIndex = self._tree:GetData(node, "tokenIndex")
	self._errType, self._errTokenIndex = errType, tokenIndex
end
-- ============================================================================
-- Private Helper Functions
-- ============================================================================
---Joins the given strings with newlines.
---@param ... string The lines
---@return string
function private.JoinLines(...)
	local joined = strjoin("\n", ...)
	return joined
end
---Indents every line of the code by the given number of tabs.
---@param code string The code
---@param num number|nil The number of tabs (defaults to 1)
---@return string
function private.IndentCode(code, num)
	local indent = strrep("\t", num or 1)
	local indented = gsub(code, "\n", "\n"..indent)
	return indent..indented
end
---Checks whether code contains a whitespace-delimited "break" which could escape it.
---Code which is entirely wrapped in its own "repeat ... until true" (optionally guarded as
---"if ... then repeat ... until true end") is considered unable to break.
---@param code string The code
---@return boolean
function private.CodeCanBreak(code)
	if not strfind(code, "%sbreak%s") then
		return false
	end
	local isSelfWrapped = strmatch(code, "^%s*repeat\n.+\n%s*until true$") or strmatch(code, "^%s*if [^\n]+ then repeat\n.+\n%s*until true end$")
	return not isSelfWrapped
end