Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"python-envs.pythonProjects": [],
"python-envs.defaultEnvManager": "ms-python.python:venv"
}
8 changes: 8 additions & 0 deletions config.default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,14 @@ db_url: "sqlite:///tmp_montage.db"
cookie_secret: ReplaceThisWithSomethingSomewhatSecret
superuser: Slaporte

database:
type: mysql
host: "localhost"
name: "enwiki_p"
username: "db_username"
password: "db_password"
read_default_file: "~/replica.my.cnf"

dev_local_cookie_value: "contact maintainers for details"
dev_remote_cookie_value: "contact maintainers for details"
oauth_secret_token: "see note below"
Expand Down
6 changes: 3 additions & 3 deletions montage/admin_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,11 +465,11 @@ def _prepare_round_params(coord_dao, request_dict):
if not val and column in req_columns:
raise InvalidAction('%s is required to create a round (got %r)'
% (column, val))
if column is 'vote_method' and val not in valid_vote_methods:
if column == 'vote_method' and val not in valid_vote_methods:
raise InvalidAction('%s is an invalid vote method' % val)
if column is 'deadline_date':
if column == 'deadline_date':
val = js_isoparse(val)
if column is 'jurors':
if column == 'jurors':
juror_names = val
rnd_dict[column] = val

Expand Down
7 changes: 3 additions & 4 deletions montage/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
MessageMiddleware,
SQLProfilerMiddleware)
from .rdb import Base, bootstrap_maintainers, ensure_series
from .utils import get_env_name, load_env_config
from .utils import DEFAULT_DB_URL, get_env_name, load_env_config
from .check_rdb import get_schema_errors, ping_connection

from .meta_endpoints import META_API_ROUTES, META_UI_ROUTES
Expand All @@ -39,7 +39,7 @@
from .cors import CORSMiddleware


DEFAULT_DB_URL = 'sqlite:///tmp_montage.db'

CUR_PATH = os.path.dirname(os.path.abspath(__file__))
PROJ_PATH = os.path.dirname(CUR_PATH)
STATIC_PATH = os.path.join(CUR_PATH, 'static')
Expand Down Expand Up @@ -119,8 +119,7 @@ def get_engine():
engine.echo = config.get('db_echo', False)
if not config.get('db_disable_ping'):
event.listen(engine, 'engine_connect', ping_connection)

if 'mysql' in db_url:
if config.get('database', {}).get('type') == 'mysql':
event.listen(engine, 'engine_connect', set_mysql_session_charset_and_collation)

return engine
Expand Down
73 changes: 72 additions & 1 deletion montage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from __future__ import absolute_import
from __future__ import print_function
import configparser
import sys
import bisect
import random
import getpass
import os.path
import datetime
from six.moves.urllib.parse import urlencode
from urllib.parse import urlencode


import yaml
Expand All @@ -21,6 +22,8 @@
from .check_rdb import get_schema_errors
import six

DEFAULT_DB_URL = 'sqlite:///tmp_montage.db'

try:
unicode = unicode
basestring = basestring
Expand Down Expand Up @@ -155,6 +158,74 @@ def load_env_config(env_name=None):

config['__env__'] = env_name
config['__file__'] = config_file_path


db_config = config.get('database')
if db_config and db_config.get('type') == 'mysql':
mysql_host = db_config.get('host', 'localhost')
db_name = db_config.get('name')

if not db_name:
raise ValueError("MySQL database name is required in config.")

username_to_use = None
password_to_use = None


default_file_path_from_config = db_config.get('read_default_file')


use_config_creds_as_fallback = True

if default_file_path_from_config:
expanded_path = os.path.expanduser(default_file_path_from_config)

if os.path.exists(expanded_path):
parser = configparser.ConfigParser()
try:
parser.read(expanded_path)

if 'client' in parser:
username_to_use = parser['client'].get('user')
password_to_use = parser['client'].get('password')


if username_to_use or password_to_use:
use_config_creds_as_fallback = False
print(f"++ Using MySQL credentials from default file: {expanded_path}", file=sys.stderr)
else:
print(f"!! Warning: No user/password found in {expanded_path}. Falling back to config.dev.yaml credentials.", file=sys.stderr)
except configparser.Error as e:
print(f"!! Error parsing MySQL default file {expanded_path}: {e}. Falling back to config.dev.yaml credentials.", file=sys.stderr)
else:
print(f"!! Warning: MySQL default file not found at {expanded_path}. Falling back to config.dev.yaml credentials.", file=sys.stderr)
else:
print("++ No MySQL default file specified. Using username/password from config.dev.yaml.", file=sys.stderr)


if use_config_creds_as_fallback:
username_to_use = db_config.get('username')
password_to_use = db_config.get('password')

config['mysql_connect_params'] = {
'host': mysql_host,
'database': db_name,

}

# Construct the db_url
credentials_part = ""
if username_to_use:
credentials_part = username_to_use
if password_to_use:
credentials_part += f":{password_to_use}"
credentials_part += "@"

config['db_url'] = f"mysql+pymysql://{credentials_part}{mysql_host}/{db_name}"

else:
config['db_url'] = config.get('db_url', DEFAULT_DB_URL)

return config


Expand Down