Client side fixes for int8 no shift on ARM [python]

This commit is contained in:
Jerin Philip 2022-06-10 11:29:01 +00:00
parent 6f82203e43
commit 020af05a8b
3 changed files with 24 additions and 8 deletions

View File

@ -38,8 +38,9 @@ class ServicePyAdapter {
marian::setThrowExceptionOnAbort(true);
}
std::shared_ptr<_Model> modelFromConfig(const std::string &config) {
auto parsedConfig = marian::bergamot::parseOptionsFromString(config);
std::shared_ptr<_Model> modelFromConfig(const std::string &config, bool validate = true,
const std::string &base_dir = "") {
auto parsedConfig = marian::bergamot::parseOptionsFromString(config, validate, base_dir);
return service_.createCompatibleModel(parsedConfig);
}

View File

@ -3,6 +3,7 @@ import sys
from collections import Counter, defaultdict
from . import REPOSITORY, ResponseOptions, Service, ServiceConfig, VectorString
from .utils import patched_platform_from_config_path
CMDS = {}
@ -74,12 +75,11 @@ class Translate:
config = ServiceConfig(numWorkers=args.num_workers, logLevel=args.log_level)
service = Service(config)
models = [
service.modelFromConfigPath(
REPOSITORY.modelConfigPath(args.repository, model)
)
for model in args.model
]
models = []
for model in args.model:
configPath = REPOSITORY.modelConfigPath(args.repository, model)
config = patched_platform_from_config_path(configPath)
models.append(service.modelFromConfig(config, True, configPath))
# Configure a few options which require how a Response is constructed
options = ResponseOptions(

View File

@ -1,4 +1,5 @@
import os
import platform
import requests
import yaml
@ -50,3 +51,17 @@ def patch_marian_for_bergamot(
# Write-out.
with open(bergamot_config_path, "w") as output_file:
print(yaml.dump(data, sort_keys=False), file=output_file)
def patched_platform_from_config_path(bergamot_config_path: PathLike) -> str:
data = None
with open(bergamot_config_path) as bergamot_config_file:
data = yaml.load(bergamot_config_file, Loader=yaml.FullLoader)
if "int8" in data["gemm-precision"]:
processor = platform.processor()
if processor in ["arm64", "aarch64"]:
# Remove shift, because the path available only on intel.
data["gemm-precision"] = data["gemm-precision"].replace("shift", "")
data["gemm-precision"] = data["gemm-precision"].replace("All", "")
return yaml.dump(data, sort_keys=False)