Skip to content
Snippets Groups Projects
import_sql.py 8.45 KiB
#!/usr/bin/env python3
"""Imports CSV files into an Oddo database using SQL.
"""
import csv
import logging
import sys
from typing import Dict, List, Tuple

from .importing import (
    add_importing_file_parsing,
    extract_model_table_from_parsed,
)
from .postgres import postgres_apply, postgres_connect_parser

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# Version number shown by main() in the program version message.
__version__ = "1.0.0"
# Creation date, inserted by main() in the "Created by ... on" notice.
__date__ = "2020-06-08"
# Last update date, shown by main() next to the version number.
__updated__ = "2020-06-08"


def _execute(cursor, query, parameters):
    """Log a query and its parameters at debug level, then execute it.

    :param cursor: object whose ``execute(query, parameters)`` method runs
        the statement; in this module it is the cursor obtained from
        ``connection.cursor()``
    :param query: SQL text, passed unchanged to ``cursor.execute``
    :param parameters: values bound to the placeholders of ``query``
    """
    # Log first so the statement is traced even when execute() raises.
    _logger.debug("Query %s, parameters %s", query, parameters)
    cursor.execute(query, parameters)


def _check_headers(fieldnames) -> None:
    """Refuse CSV headers that reference other records by external id.

    :param fieldnames: column names read from the CSV header line
    :raises ValueError: when a column name ends with ``/id`` or ``:id``
    """
    many2x_columns = [
        column_name
        for column_name in fieldnames
        if column_name.endswith(("/id", ":id"))
    ]
    if many2x_columns:
        msg = (
            "Importing many2many or many2one is not "
            "supported: %s" % (", ".join(many2x_columns))
        )
        _logger.fatal(msg)
        raise ValueError(msg)


def _insert_row(cursor, table, row, returning_id=False):
    """Insert one CSV row, stamping create_date and write_date with now.

    :param cursor: cursor used to run the statement
    :param table: name of the table receiving the row
    :param row: mapping of column name to value, as read from the CSV
    :param returning_id: when true, add ``RETURNING id`` and return that id
    :return: the id of the new line when ``returning_id`` is true, else
        ``None``
    """
    # NOTE(review): table and column names are interpolated unquoted, so the
    # CSV header and the table name must come from a trusted source.
    query = (
        "INSERT INTO {table} ({columns}) VALUES ({data_placeholders})"
    ).format(
        table=table,
        columns=",".join(["create_date", "write_date"] + list(row)),
        data_placeholders=",".join(
            ["(now() at time zone 'UTC')"] * 2 + ["%s"] * len(row)
        ),
    )
    if returning_id:
        query += " RETURNING id"
    _execute(cursor, query, tuple(row.values()))
    if returning_id:
        return cursor.fetchone()[0]
    return None


def _update_row(cursor, table, row, res_id) -> None:
    """Overwrite the columns of an existing line and refresh write_date.

    :param cursor: cursor used to run the statement
    :param table: name of the table holding the line
    :param row: mapping of column name to new value, as read from the CSV
    :param res_id: value of the ``id`` column of the line to update
    """
    # NOTE(review): same unquoted identifier interpolation as _insert_row.
    query = "UPDATE {table} SET {data_placeholders} WHERE id=%s".format(
        table=table,
        data_placeholders=",".join(
            ["write_date=(now() at time zone 'UTC')"]
            + ["{}=%s".format(col) for col in row]
        ),
    )
    _execute(cursor, query, list(row.values()) + [res_id])


def _import_csv_file(cursor, model, table, csv_file, delimiter):
    """Import every row of one CSV file into ``table``.

    :param cursor: cursor used to run all statements
    :param model: model name, stored in / matched against ir_model_data
    :param table: name of the table receiving the rows
    :param csv_file: path of the CSV file to read
    :param delimiter: CSV field delimiter
    :return: ``(created, written)`` row counters
    :raises ValueError: on unsupported headers or on a row that has more
        fields than the header
    """
    # newline="" is what the csv module requires so that line breaks
    # embedded in quoted fields are handed to the parser untouched.
    with open(csv_file, "r", newline="") as file:
        reader = csv.DictReader(file, delimiter=delimiter, quotechar='"')
        # Check the header once, before touching the database, rather than
        # on the first data row; fieldnames is None for an empty file.
        _logger.debug("Checking headers")
        _check_headers(reader.fieldnames or [])
        created = written = 0
        for row in reader:
            if None in row:
                # DictReader stores surplus fields under the None key, which
                # would otherwise end up as a column name in the SQL.
                raise ValueError(
                    "%s line %d: more fields than header columns"
                    % (csv_file, reader.line_num)
                )
            # The id column is never written as is: it holds the xml id.
            xmlid = row.pop("id", None)
            if not xmlid:
                # No xml id (no id column or empty value): plain insert.
                # Treating an empty id as a real xml id would make every
                # such row overwrite the record created by the first one.
                _insert_row(cursor, table, row)
                created += 1
            else:
                if "." in xmlid:
                    module, name = xmlid.split(".", 1)
                else:
                    module = ""
                    name = xmlid
                # look up if the given id exists
                _execute(
                    cursor,
                    "SELECT id, res_id "
                    "FROM ir_model_data where model=%s "
                    "AND module=%s AND name=%s",
                    (model, module, name),
                )
                imd_data = cursor.fetchone()
                if imd_data:
                    # the xml id exists: update the line it points to
                    imd_id, res_id = imd_data
                    _logger.debug("Found id %d for xmlid %s", res_id, xmlid)
                    _update_row(cursor, table, row, res_id)
                    # update date_update in ir_model_data
                    _execute(
                        cursor,
                        "UPDATE ir_model_data "
                        "SET date_update=(now() at time zone 'UTC') "
                        "WHERE id=%s",
                        (imd_id,),
                    )
                    written += 1
                else:
                    # unknown xml id: insert the line, then register the id
                    db_id = _insert_row(cursor, table, row, returning_id=True)
                    _execute(
                        cursor,
                        "INSERT INTO ir_model_data (module, name, "
                        "model, res_id, date_init, date_update) "
                        "VALUES (%s, %s, %s, %s, "
                        "(now() at time zone 'UTC'), "
                        "(now() at time zone 'UTC'))",
                        (module, name, model, db_id),
                    )
                    created += 1
            if (created + written) % 100 == 0:
                _logger.info(
                    "%s progress: created %d, wrote %d",
                    csv_file,
                    created,
                    written,
                )
    return created, written


def sql_import(
    connection, table_filenames: List[Dict[str, str]], delimiter: str
) -> None:
    """Import CSV files into database tables with plain SQL statements.

    For each row, the ``id`` column (if any) is treated as an xml id of the
    form ``module.name`` (``module`` is empty when there is no dot). When a
    matching ir_model_data line exists the target line is updated, otherwise
    a new line is inserted and registered in ir_model_data. Rows without an
    xml id are simply inserted.

    All files are imported through one cursor and committed once at the
    end; no commit is issued if an exception is raised before that.

    :param connection: database connection providing ``cursor()`` (usable
        as a context manager) and ``commit()``
    :param table_filenames: one mapping per file with the keys ``model`` and
        ``filename``, plus an optional ``table``; when ``table`` is missing
        or empty the table name is the model name with dots replaced by
        underscores
    :param delimiter: CSV field delimiter
    :raises ValueError: when a file has ``/id`` or ``:id`` columns, or a row
        with more fields than its header
    """
    with connection.cursor() as cur:
        for element in table_filenames:
            model = element["model"]
            csv_file = element["filename"]
            # do Odoo model to table conversion
            # (same as _build_model_attributes)
            table = element.get("table") or model.replace(".", "_")
            _logger.info(
                "Importing - %s in %s (table %s)", csv_file, model, table
            )
            created, written = _import_csv_file(
                cur, model, table, csv_file, delimiter
            )
            _logger.info(
                "%s: created %d, wrote %d", csv_file, created, written
            )

        connection.commit()
        _logger.info("Commited.")


def main(argv=None):  # IGNORE:C0111
    """Parse the command line arguments and launch the SQL import.

    :param argv: list of command line arguments, handed to the parser's
        ``parse_args``
    :return: 0 once the import has completed
    """
    program_version = __version__
    program_build_date = str(__updated__)
    program_version_message = "%%(prog)s %s (%s)" % (
        program_version,
        program_build_date,
    )
    # __doc__ is None when docstrings are stripped (python -OO).
    program_shortdesc = (__doc__ or "").split(".")[0]
    program_license = """%s

  Created by Vincent Hatakeyama on %s.
  Copyright 2020 XCG Consulting. All rights reserved.

  Licensed under the MIT License

  Distributed on an "AS IS" basis without warranties
  or conditions of any kind, either express or implied.

USAGE
""" % (
        program_shortdesc,
        str(__date__),
    )
    parser = postgres_connect_parser(program_license, program_version_message)
    add_importing_file_parsing(parser)
    parser.add_argument(
        "--delimiter", help="CSV delimiter [default: %(default)s]", default=","
    )

    nmspc = parser.parse_args(argv)
    conn = postgres_apply(nmspc)

    # Close the connection even when the import raises.
    try:
        sql_import(
            connection=conn,
            table_filenames=extract_model_table_from_parsed(nmspc),
            delimiter=nmspc.delimiter,
        )
    finally:
        conn.close()
    return 0


if __name__ == "__main__":
    # sys.exit is always available, unlike the exit() helper that the site
    # module only injects for interactive use.
    sys.exit(main(sys.argv[1:]))