Client side fixes for int8 no shift on ARM [python]

This commit is contained in:
Jerin Philip 2022-06-10 11:29:01 +00:00
parent 6f82203e43
commit 020af05a8b
3 changed files with 24 additions and 8 deletions

View File

@ -38,8 +38,9 @@ class ServicePyAdapter {
marian::setThrowExceptionOnAbort(true); marian::setThrowExceptionOnAbort(true);
} }
std::shared_ptr<_Model> modelFromConfig(const std::string &config) { std::shared_ptr<_Model> modelFromConfig(const std::string &config, bool validate = true,
auto parsedConfig = marian::bergamot::parseOptionsFromString(config); const std::string &base_dir = "") {
auto parsedConfig = marian::bergamot::parseOptionsFromString(config, validate, base_dir);
return service_.createCompatibleModel(parsedConfig); return service_.createCompatibleModel(parsedConfig);
} }

View File

@ -3,6 +3,7 @@ import sys
from collections import Counter, defaultdict from collections import Counter, defaultdict
from . import REPOSITORY, ResponseOptions, Service, ServiceConfig, VectorString from . import REPOSITORY, ResponseOptions, Service, ServiceConfig, VectorString
from .utils import patched_platform_from_config_path
CMDS = {} CMDS = {}
@ -74,12 +75,11 @@ class Translate:
config = ServiceConfig(numWorkers=args.num_workers, logLevel=args.log_level) config = ServiceConfig(numWorkers=args.num_workers, logLevel=args.log_level)
service = Service(config) service = Service(config)
models = [ models = []
service.modelFromConfigPath( for model in args.model:
REPOSITORY.modelConfigPath(args.repository, model) configPath = REPOSITORY.modelConfigPath(args.repository, model)
) config = patched_platform_from_config_path(configPath)
for model in args.model models.append(service.modelFromConfig(config, True, configPath))
]
# Configure a few options which require how a Response is constructed # Configure a few options which require how a Response is constructed
options = ResponseOptions( options = ResponseOptions(

View File

@ -1,4 +1,5 @@
import os import os
import platform
import requests import requests
import yaml import yaml
@ -50,3 +51,17 @@ def patch_marian_for_bergamot(
# Write-out. # Write-out.
with open(bergamot_config_path, "w") as output_file: with open(bergamot_config_path, "w") as output_file:
print(yaml.dump(data, sort_keys=False), file=output_file) print(yaml.dump(data, sort_keys=False), file=output_file)
def patched_platform_from_config_path(bergamot_config_path: PathLike) -> str:
    """Return the bergamot config at the given path as YAML, patched for ARM.

    The file is parsed as YAML. When its ``gemm-precision`` value mentions
    ``int8`` and the host is an ARM machine, the substrings ``"shift"`` and
    ``"All"`` are removed from that value. The (possibly modified) mapping
    is then serialised back to a YAML string; the file on disk is never
    rewritten.

    Args:
        bergamot_config_path: Path of the bergamot YAML configuration file.

    Returns:
        The configuration dumped as a YAML string, keys in original order.
    """
    # NOTE(review): FullLoader can build more than plain data types; if
    # configs may come from an untrusted repository, consider safe_load.
    with open(bergamot_config_path) as bergamot_config_file:
        data = yaml.load(bergamot_config_file, Loader=yaml.FullLoader)

    # A config without an explicit gemm-precision has nothing to patch, so
    # read the key leniently instead of failing with KeyError.
    precision = data.get("gemm-precision")
    if isinstance(precision, str) and "int8" in precision:
        # platform.processor() alone is unreliable: it reports "arm" on
        # Apple Silicon and is frequently empty on Linux, so the machine
        # name is consulted as well. Lower-casing covers "ARM64" spellings.
        host_names = {platform.machine().lower(), platform.processor().lower()}
        if host_names & {"arm64", "aarch64"}:
            # Remove shift, because the path is available only on intel.
            # "All" is stripped along with it -- presumably that variant is
            # likewise unsupported on ARM; TODO confirm.
            data["gemm-precision"] = precision.replace("shift", "").replace("All", "")

    return yaml.dump(data, sort_keys=False)