You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

219 lines
7.3 KiB

import sqlite3
import threading
import json
import os
import networkx as nx
class DBController:
    """Keep Pokémon form records in SQLite and evolutions in a networkx DiGraph.

    Records live in the ``pokemon_forms`` table: a ``PFIC`` text primary key
    plus a JSON ``data`` column holding the remaining attributes.  On
    construction the on-disk database is copied into the working connection
    (``db_path``, in-memory by default) and the evolution graph is loaded from
    its JSON file if that file exists.  ``save_changes`` / ``close`` write both
    back to disk.

    The connection is opened with ``check_same_thread=False``; every method
    that touches the shared cursor or the graph holds ``self.lock`` while
    doing so.
    """

    def __init__(self, db_path=':memory:', max_connections=10,
                 disk_db_path='pokemon_forms.db',
                 graph_path='pokemon_evolution_graph.json'):
        """Open the working database and load persisted data.

        Args:
            db_path: SQLite path of the working connection (default in-memory).
            max_connections: accepted for backward compatibility; unused.
            disk_db_path: file the working database is loaded from / saved to.
            graph_path: JSON file holding the node-link evolution graph.
        """
        self.db_path = db_path
        self.disk_db_path = disk_db_path
        self.graph_path = graph_path
        self.lock = threading.Lock()
        self.conn = sqlite3.connect(db_path, check_same_thread=False)
        self.conn.row_factory = sqlite3.Row
        self.cursor = self.conn.cursor()
        self.graph = nx.DiGraph()
        self.init_database()

    def init_database(self):
        """Ensure the schema exists on disk, then load the disk DB and graph."""
        disk_conn = sqlite3.connect(self.disk_db_path)
        try:
            # Create tables in the file-based database and persist them.
            self.create_pokemon_forms_table(disk_conn.cursor())
            disk_conn.commit()
            # Copy the file-based database into the working connection.
            disk_conn.backup(self.conn)
        finally:
            # Close even if schema creation or the backup fails.
            disk_conn.close()

        if os.path.exists(self.graph_path):
            with open(self.graph_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            self.graph = nx.node_link_graph(data)

    def save_changes(self):
        """Back up the working database to disk and write the graph as JSON."""
        with self.lock:
            # Count the records before the backup so the log can be verified.
            self.cursor.execute('SELECT COUNT(*) FROM pokemon_forms')
            count = self.cursor.fetchone()[0]
            print(f"Records in memory before backup: {count}")

            disk_conn = sqlite3.connect(self.disk_db_path)
            try:
                with disk_conn:
                    self.conn.backup(disk_conn)
            finally:
                disk_conn.close()

            data = nx.node_link_data(self.graph)
            with open(self.graph_path, "w", encoding="utf-8") as f:
                json.dump(data, f)

    def close(self):
        """Persist all changes, then close the working connection.

        If saving raises, the connection is deliberately left open so the
        (possibly in-memory) data is not lost and the save can be retried.
        """
        self.save_changes()
        self.conn.close()

    def create_pokemon_forms_table(self, cursor):
        """Create the ``pokemon_forms`` table through *cursor* if missing."""
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS pokemon_forms (
                PFIC TEXT PRIMARY KEY,
                data JSON NOT NULL
            )
        ''')

    def add_pokemon_form(self, pfic, name, form_name, national_dex, generation,
                         sprite_url, gender_relevant):
        """Insert (or replace) a form record.

        ``is_baby_form`` and ``storable_in_home`` always start as False.
        """
        data = {
            "name": name,
            "form_name": form_name,
            "national_dex": national_dex,
            "generation": generation,
            "sprite_url": sprite_url,
            "is_baby_form": False,
            "storable_in_home": False,
            "gender_relevant": gender_relevant
        }
        with self.lock:
            self.cursor.execute('''
                INSERT OR REPLACE INTO pokemon_forms (PFIC, data) VALUES (?, ?)
            ''', (pfic, json.dumps(data)))
            self.conn.commit()
        print(f"Added: {pfic}, {name}")

    def craft_pokemon_json_query(self, fields_to_include, pfic=None):
        """Build a SELECT that extracts *fields_to_include* from the JSON data.

        ``"pfic"`` selects the primary key column; any other field is read
        with ``JSON_EXTRACT(data, '$.<field>')`` and aliased to its own name.
        When *pfic* is given, a ``WHERE PFIC = '...'`` filter is appended.

        Raises:
            ValueError: if a field name is not a Python identifier. Field
                names are spliced into the SQL text, so they must be
                restricted to prevent injection.
        """
        extracts = []
        for field in fields_to_include:
            if field == "pfic":
                extracts.append("PFIC as pfic")
            elif isinstance(field, str) and field.isidentifier():
                extracts.append(f"JSON_EXTRACT(data, '$.{field}') AS {field}")
            else:
                raise ValueError(f"Invalid field name: {field!r}")
        query = "SELECT " + ", ".join(extracts) + " FROM pokemon_forms"
        if pfic is not None:
            # Double embedded quotes so the value stays one SQL string literal.
            escaped = str(pfic).replace("'", "''")
            query += f" WHERE PFIC = '{escaped}'"
        return query

    def get_pokemon_details(self, pfic):
        """Return the details dict for *pfic*, or None if no such record.

        Values come from ``JSON_EXTRACT``, so JSON booleans are returned by
        SQLite as the integers 0/1.
        """
        fields = [
            "name",
            "form_name",
            "national_dex",
            "generation",
            "is_baby_form",
            "storable_in_home",
            "gender_relevant"
        ]
        query = self.craft_pokemon_json_query(fields, pfic)
        with self.lock:
            self.cursor.execute(query)
            row = self.cursor.fetchone()
        return dict(row) if row is not None else None

    def get_pokemon_details_by_name(self, name, fields):
        """Return a list of dicts (selected *fields*) for records named *name*."""
        query = self.craft_pokemon_json_query(fields)
        # Bind the name as a parameter instead of escaping it by hand.
        query += " WHERE JSON_EXTRACT(data, '$.name') = ?"
        with self.lock:
            self.cursor.execute(query, (name,))
            rows = self.cursor.fetchall()
        return [dict(row) for row in rows]

    def get_list_of_pokemon_forms(self):
        """Return every stored form as a list of dicts, including ``pfic``."""
        fields = [
            "pfic",
            "name",
            "form_name",
            "national_dex",
            "generation",
            "is_baby_form",
            "storable_in_home",
            "gender_relevant"
        ]
        query = self.craft_pokemon_json_query(fields)
        with self.lock:
            self.cursor.execute(query)
            rows = self.cursor.fetchall()
        return [dict(row) for row in rows]

    def update_home_status(self, pfic, status):
        """Set the ``storable_in_home`` flag of *pfic* to *status*."""
        self.update_pokemon_field(pfic, "storable_in_home", status)

    def update_pokemon_field(self, pfic, field_name, new_value):
        """Set ``data[field_name] = new_value`` for *pfic*; no-op if missing."""
        # The whole read-modify-write runs under the lock so that concurrent
        # updates to the same record cannot overwrite each other.
        with self.lock:
            self.cursor.execute(
                'SELECT data FROM pokemon_forms WHERE PFIC = ?', (pfic,))
            result = self.cursor.fetchone()
            if result is None:
                return
            data = json.loads(result[0])
            data[field_name] = new_value
            self.cursor.execute('''
                UPDATE pokemon_forms
                SET data = ?
                WHERE PFIC = ?
            ''', (json.dumps(data), pfic))
            self.conn.commit()

    def update_evolution_graph(self, evolutions):
        """Add one edge per entry of the *evolutions* mapping.

        Each value must provide ``from_pfic``, ``to_pfic`` and ``method``; the
        method is stored as the edge's ``method`` attribute. ``add_edge``
        creates any missing nodes (source first, then target).
        """
        with self.lock:
            for value in evolutions.values():
                self.graph.add_edge(value["from_pfic"], value["to_pfic"],
                                    method=value["method"])

    def get_evolution_graph(self, pfic):
        """Return the direct evolution targets of *pfic* ([] if unknown)."""
        if pfic not in self.graph:
            return []
        return list(self.graph.successors(pfic))

    def get_evolution_paths(self, start_node):
        """Return every evolution path that starts at *start_node*.

        Each path is a list of ``(pfic, method)`` tuples. The first tuple is
        ``(start_node, None)``, and each later tuple pairs a node with the
        ``method`` of the edge leading into it. A path ends at a node that
        has no successor not already on the path, which also stops the
        traversal on cycles. A start node absent from the graph yields
        ``[[(start_node, None)]]``.
        """
        paths = []
        path = [(start_node, None)]
        on_path = {start_node}

        def walk(node):
            successors = []
            if node in self.graph:
                successors = [s for s in self.graph.successors(node)
                              if s not in on_path]
            if not successors:
                paths.append(list(path))
                return
            for successor in successors:
                method = self.graph[node][successor]["method"]
                # Append the successor exactly once, paired with its method.
                path.append((successor, method))
                on_path.add(successor)
                walk(successor)
                # Backtrack before trying the next branch.
                on_path.discard(successor)
                path.pop()

        walk(start_node)
        return paths