mirror of
https://github.com/veekun/pokedex.git
synced 2024-08-20 18:16:34 +00:00
load: Use COPY FROM STDIN on PostgreSQL.
COPY FROM FILE requires database superuser permissions, because of the obvious security implications. COPY FROM STDIN has no such restriction. Also do some cleanup while we're here.
This commit is contained in:
parent
33fab44d0d
commit
93988d966c
1 changed files with 29 additions and 31 deletions
|
@ -144,22 +144,23 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
|
||||||
|
|
||||||
table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
|
table_objs = sqlalchemy.sql.util.sort_tables(table_objs)
|
||||||
|
|
||||||
|
engine = session.get_bind()
|
||||||
|
|
||||||
# Limit table names to 30 characters for Oracle
|
# Limit table names to 30 characters for Oracle
|
||||||
oracle = (session.connection().dialect.name == 'oracle')
|
oracle = (engine.dialect.name == 'oracle')
|
||||||
if oracle:
|
if oracle:
|
||||||
rewrite_long_table_names()
|
rewrite_long_table_names()
|
||||||
|
|
||||||
# SQLite speed tweaks
|
# SQLite speed tweaks
|
||||||
if not safe and session.connection().dialect.name == 'sqlite':
|
if not safe and engine.dialect.name == 'sqlite':
|
||||||
session.connection().execute("PRAGMA synchronous=OFF")
|
session.execute("PRAGMA synchronous=OFF")
|
||||||
session.connection().execute("PRAGMA journal_mode=OFF")
|
session.execute("PRAGMA journal_mode=OFF")
|
||||||
|
|
||||||
# Drop all tables if requested
|
# Drop all tables if requested
|
||||||
if drop_tables:
|
if drop_tables:
|
||||||
bind = session.get_bind()
|
|
||||||
print_start('Dropping tables')
|
print_start('Dropping tables')
|
||||||
for n, table in enumerate(reversed(table_objs)):
|
for n, table in enumerate(reversed(table_objs)):
|
||||||
table.drop(checkfirst=True)
|
table.drop(bind=engine, checkfirst=True)
|
||||||
|
|
||||||
# Drop columns' types if appropriate; needed for enums in
|
# Drop columns' types if appropriate; needed for enums in
|
||||||
# postgresql
|
# postgresql
|
||||||
|
@ -169,7 +170,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
drop(bind=bind, checkfirst=True)
|
drop(bind=engine, checkfirst=True)
|
||||||
|
|
||||||
print_status('%s/%s' % (n, len(table_objs)))
|
print_status('%s/%s' % (n, len(table_objs)))
|
||||||
print_done()
|
print_done()
|
||||||
|
@ -179,7 +180,6 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
|
||||||
table.create()
|
table.create()
|
||||||
print_status('%s/%s' % (n, len(table_objs)))
|
print_status('%s/%s' % (n, len(table_objs)))
|
||||||
print_done()
|
print_done()
|
||||||
connection = session.connection()
|
|
||||||
|
|
||||||
# Okay, run through the tables and actually load the data now
|
# Okay, run through the tables and actually load the data now
|
||||||
for table_obj in table_objs:
|
for table_obj in table_objs:
|
||||||
|
@ -205,35 +205,34 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
|
||||||
reader = csv.reader(csvfile, lineterminator='\n')
|
reader = csv.reader(csvfile, lineterminator='\n')
|
||||||
column_names = [unicode(column) for column in reader.next()]
|
column_names = [unicode(column) for column in reader.next()]
|
||||||
|
|
||||||
if not safe and session.connection().dialect.name == 'postgresql':
|
if not safe and engine.dialect.name == 'postgresql':
|
||||||
"""
|
# Postgres' CSV dialect works with our data, if we mark the not-null
|
||||||
Postgres' CSV dialect works with our data, if we mark the not-null
|
# columns with FORCE NOT NULL.
|
||||||
columns with FORCE NOT NULL.
|
|
||||||
COPY is only allowed for DB superusers. If you're not one, use safe
|
|
||||||
loading (pokedex load -S).
|
|
||||||
"""
|
|
||||||
session.commit()
|
|
||||||
not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
|
not_null_cols = [c for c in column_names if not table_obj.c[c].nullable]
|
||||||
if not_null_cols:
|
if not_null_cols:
|
||||||
force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
|
force_not_null = 'FORCE NOT NULL ' + ','.join('"%s"' % c for c in not_null_cols)
|
||||||
else:
|
else:
|
||||||
force_not_null = ''
|
force_not_null = ''
|
||||||
command = "COPY %(table_name)s (%(columns)s) FROM '%(csvpath)s' CSV HEADER %(force_not_null)s"
|
|
||||||
session.connection().execute(
|
# Grab the underlying psycopg2 cursor so we can use COPY FROM STDIN
|
||||||
|
raw_conn = engine.raw_connection()
|
||||||
|
command = "COPY %(table_name)s (%(columns)s) FROM STDIN CSV HEADER %(force_not_null)s"
|
||||||
|
csvfile.seek(0)
|
||||||
|
raw_conn.cursor().copy_expert(
|
||||||
command % dict(
|
command % dict(
|
||||||
table_name=table_name,
|
table_name=table_name,
|
||||||
csvpath=csvpath,
|
|
||||||
columns=','.join('"%s"' % c for c in column_names),
|
columns=','.join('"%s"' % c for c in column_names),
|
||||||
force_not_null=force_not_null,
|
force_not_null=force_not_null,
|
||||||
)
|
),
|
||||||
|
csvfile,
|
||||||
)
|
)
|
||||||
session.commit()
|
raw_conn.commit()
|
||||||
print_done()
|
print_done()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Self-referential tables may contain rows with foreign keys of other
|
# Self-referential tables may contain rows with foreign keys of other
|
||||||
# rows in the same table that do not yet exist. Pull these out and add
|
# rows in the same table that do not yet exist. Pull these out and
|
||||||
# them to the session last
|
# insert them last
|
||||||
# ASSUMPTION: Self-referential tables have a single PK called "id"
|
# ASSUMPTION: Self-referential tables have a single PK called "id"
|
||||||
deferred_rows = [] # ( row referring to id, [foreign ids we need] )
|
deferred_rows = [] # ( row referring to id, [foreign ids we need] )
|
||||||
seen_ids = set() # primary keys we've seen
|
seen_ids = set() # primary keys we've seen
|
||||||
|
@ -248,7 +247,7 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
|
||||||
def insert_and_commit():
|
def insert_and_commit():
|
||||||
if not new_rows:
|
if not new_rows:
|
||||||
return
|
return
|
||||||
session.connection().execute(insert_stmt, new_rows)
|
session.execute(insert_stmt, new_rows)
|
||||||
session.commit()
|
session.commit()
|
||||||
new_rows[:] = []
|
new_rows[:] = []
|
||||||
|
|
||||||
|
@ -316,12 +315,12 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
|
||||||
raise ValueError("Too many levels of self-reference! "
|
raise ValueError("Too many levels of self-reference! "
|
||||||
"Row was: " + str(row))
|
"Row was: " + str(row))
|
||||||
|
|
||||||
session.connection().execute(
|
session.execute(
|
||||||
insert_stmt.values(**row_data)
|
insert_stmt.values(**row_data)
|
||||||
)
|
)
|
||||||
seen_ids.add(row_data['id'])
|
seen_ids.add(row_data['id'])
|
||||||
session.commit()
|
|
||||||
|
|
||||||
|
session.commit()
|
||||||
print_done()
|
print_done()
|
||||||
|
|
||||||
|
|
||||||
|
@ -333,18 +332,17 @@ def load(session, tables=[], directory=None, drop_tables=False, verbose=False, s
|
||||||
table_obj = translation_class.__table__
|
table_obj = translation_class.__table__
|
||||||
if table_obj in table_objs:
|
if table_obj in table_objs:
|
||||||
insert_stmt = table_obj.insert()
|
insert_stmt = table_obj.insert()
|
||||||
session.connection().execute(insert_stmt, rows)
|
session.execute(insert_stmt, rows)
|
||||||
session.commit()
|
session.commit()
|
||||||
# We don't have a total, but at least show some increasing number
|
# We don't have a total, but at least show some increasing number
|
||||||
new_row_count += len(rows)
|
new_row_count += len(rows)
|
||||||
print_status(str(new_row_count))
|
print_status(str(new_row_count))
|
||||||
|
|
||||||
print_done()
|
|
||||||
|
|
||||||
# SQLite check
|
# SQLite check
|
||||||
if session.connection().dialect.name == 'sqlite':
|
if engine.dialect.name == 'sqlite':
|
||||||
session.connection().execute("PRAGMA integrity_check")
|
session.execute("PRAGMA integrity_check")
|
||||||
|
|
||||||
|
print_done()
|
||||||
|
|
||||||
|
|
||||||
def dump(session, tables=[], directory=None, verbose=False, langs=None):
|
def dump(session, tables=[], directory=None, verbose=False, langs=None):
|
||||||
|
|
Loading…
Add table
Reference in a new issue