From 020af05a8b1f4b4ef46373e6e61dcd32869fc1b1 Mon Sep 17 00:00:00 2001
From: Jerin Philip
Date: Fri, 10 Jun 2022 11:29:01 +0000
Subject: [PATCH] Client side fixes for int8 no shift on ARM [python]

---
Review notes (placed after the three-dash line, so they are not part of the
commit message):

* bergamot.cpp: modelFromConfig() gains two defaulted parameters, `validate`
  (default true) and `base_dir` (default ""), and forwards both to
  marian::bergamot::parseOptionsFromString().
* cmds.py (class Translate): the model list is no longer built with
  service.modelFromConfigPath(). For each entry of args.model the config is
  read through patched_platform_from_config_path() and handed to
  service.modelFromConfig(config, True, configPath). Inside that loop the
  name `config` is rebound from the ServiceConfig to the YAML string; the
  Service has already been constructed from it at that point.
  NOTE(review): configPath (the value returned by REPOSITORY.modelConfigPath)
  is what gets passed as base_dir; confirm that parseOptionsFromString
  expects the config file path there rather than its parent directory.
* utils.py: patched_platform_from_config_path() loads the YAML config and,
  when the "gemm-precision" value contains "int8" and platform.processor()
  is "arm64" or "aarch64", deletes the substrings "shift" and "All" from
  that value. It returns the config re-serialised by
  yaml.dump(data, sort_keys=False).
  NOTE(review): a config without a "gemm-precision" key raises KeyError.
  NOTE(review): platform.processor() is documented to return an empty string
  when the value cannot be determined, in which case nothing is patched;
  platform.machine() may be the more reliable probe. Confirm on the ARM
  targets this is meant for.
* NOTE(review): this copy of the patch had lost its line breaks. Line breaks
  and leading whitespace below were reconstructed; the hunk line counts and
  the diffstat agree with the reconstruction, but confirm the indentation
  against the tree before applying.

 bindings/python/bergamot.cpp |  5 +++--
 bindings/python/cmds.py      | 12 ++++++------
 bindings/python/utils.py     | 15 +++++++++++++++
 3 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/bindings/python/bergamot.cpp b/bindings/python/bergamot.cpp
index 5e9e830..30ccd83 100644
--- a/bindings/python/bergamot.cpp
+++ b/bindings/python/bergamot.cpp
@@ -38,8 +38,9 @@ class ServicePyAdapter {
     marian::setThrowExceptionOnAbort(true);
   }
 
-  std::shared_ptr<_Model> modelFromConfig(const std::string &config) {
-    auto parsedConfig = marian::bergamot::parseOptionsFromString(config);
+  std::shared_ptr<_Model> modelFromConfig(const std::string &config, bool validate = true,
+                                          const std::string &base_dir = "") {
+    auto parsedConfig = marian::bergamot::parseOptionsFromString(config, validate, base_dir);
     return service_.createCompatibleModel(parsedConfig);
   }
 
diff --git a/bindings/python/cmds.py b/bindings/python/cmds.py
index 5949ada..6260266 100644
--- a/bindings/python/cmds.py
+++ b/bindings/python/cmds.py
@@ -3,6 +3,7 @@ import sys
 from collections import Counter, defaultdict
 
 from . import REPOSITORY, ResponseOptions, Service, ServiceConfig, VectorString
+from .utils import patched_platform_from_config_path
 
 CMDS = {}
 
@@ -74,12 +75,11 @@ class Translate:
         config = ServiceConfig(numWorkers=args.num_workers, logLevel=args.log_level)
         service = Service(config)
 
-        models = [
-            service.modelFromConfigPath(
-                REPOSITORY.modelConfigPath(args.repository, model)
-            )
-            for model in args.model
-        ]
+        models = []
+        for model in args.model:
+            configPath = REPOSITORY.modelConfigPath(args.repository, model)
+            config = patched_platform_from_config_path(configPath)
+            models.append(service.modelFromConfig(config, True, configPath))
 
         # Configure a few options which require how a Response is constructed
         options = ResponseOptions(
diff --git a/bindings/python/utils.py b/bindings/python/utils.py
index 3164c17..5e4f98e 100644
--- a/bindings/python/utils.py
+++ b/bindings/python/utils.py
@@ -1,4 +1,5 @@
 import os
+import platform
 
 import requests
 import yaml
@@ -50,3 +51,17 @@ def patch_marian_for_bergamot(
     # Write-out.
     with open(bergamot_config_path, "w") as output_file:
         print(yaml.dump(data, sort_keys=False), file=output_file)
+
+
+def patched_platform_from_config_path(bergamot_config_path: PathLike) -> str:
+    data = None
+    with open(bergamot_config_path) as bergamot_config_file:
+        data = yaml.load(bergamot_config_file, Loader=yaml.FullLoader)
+        if "int8" in data["gemm-precision"]:
+            processor = platform.processor()
+            if processor in ["arm64", "aarch64"]:
+                # Remove shift, because the path available only on intel.
+                data["gemm-precision"] = data["gemm-precision"].replace("shift", "")
+                data["gemm-precision"] = data["gemm-precision"].replace("All", "")
+
+    return yaml.dump(data, sort_keys=False)