mirror of
https://github.com/browsermt/bergamot-translator.git
synced 2024-08-15 08:30:46 +03:00
Client side fixes for int8 no shift on ARM [python]
This commit is contained in:
parent
6f82203e43
commit
020af05a8b
@ -38,8 +38,9 @@ class ServicePyAdapter {
|
|||||||
marian::setThrowExceptionOnAbort(true);
|
marian::setThrowExceptionOnAbort(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<_Model> modelFromConfig(const std::string &config) {
|
std::shared_ptr<_Model> modelFromConfig(const std::string &config, bool validate = true,
|
||||||
auto parsedConfig = marian::bergamot::parseOptionsFromString(config);
|
const std::string &base_dir = "") {
|
||||||
|
auto parsedConfig = marian::bergamot::parseOptionsFromString(config, validate, base_dir);
|
||||||
return service_.createCompatibleModel(parsedConfig);
|
return service_.createCompatibleModel(parsedConfig);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3,6 +3,7 @@ import sys
|
|||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
|
|
||||||
from . import REPOSITORY, ResponseOptions, Service, ServiceConfig, VectorString
|
from . import REPOSITORY, ResponseOptions, Service, ServiceConfig, VectorString
|
||||||
|
from .utils import patched_platform_from_config_path
|
||||||
|
|
||||||
CMDS = {}
|
CMDS = {}
|
||||||
|
|
||||||
@ -74,12 +75,11 @@ class Translate:
|
|||||||
config = ServiceConfig(numWorkers=args.num_workers, logLevel=args.log_level)
|
config = ServiceConfig(numWorkers=args.num_workers, logLevel=args.log_level)
|
||||||
service = Service(config)
|
service = Service(config)
|
||||||
|
|
||||||
models = [
|
models = []
|
||||||
service.modelFromConfigPath(
|
for model in args.model:
|
||||||
REPOSITORY.modelConfigPath(args.repository, model)
|
configPath = REPOSITORY.modelConfigPath(args.repository, model)
|
||||||
)
|
config = patched_platform_from_config_path(configPath)
|
||||||
for model in args.model
|
models.append(service.modelFromConfig(config, True, configPath))
|
||||||
]
|
|
||||||
|
|
||||||
# Configure a few options which require how a Response is constructed
|
# Configure a few options which require how a Response is constructed
|
||||||
options = ResponseOptions(
|
options = ResponseOptions(
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
import yaml
|
import yaml
|
||||||
@ -50,3 +51,17 @@ def patch_marian_for_bergamot(
|
|||||||
# Write-out.
|
# Write-out.
|
||||||
with open(bergamot_config_path, "w") as output_file:
|
with open(bergamot_config_path, "w") as output_file:
|
||||||
print(yaml.dump(data, sort_keys=False), file=output_file)
|
print(yaml.dump(data, sort_keys=False), file=output_file)
|
||||||
|
|
||||||
|
|
||||||
|
def patched_platform_from_config_path(bergamot_config_path: PathLike) -> str:
    """Return the bergamot config as YAML text, adjusted for the host CPU.

    The config at ``bergamot_config_path`` is parsed and, when its
    ``gemm-precision`` value contains ``int8`` and the host is 64-bit ARM,
    the ``shift`` and ``All`` substrings are stripped from that value (for
    example ``int8shiftAll`` becomes ``int8``). In every other case the
    parsed config is dumped back without modification.

    Args:
        bergamot_config_path: Path to a bergamot YAML configuration file.

    Returns:
        The (possibly patched) configuration serialized as a YAML string,
        with keys kept in their original order.
    """
    # YAML streams are UTF-8; do not depend on the locale's default encoding.
    with open(bergamot_config_path, encoding="utf-8") as bergamot_config_file:
        # NOTE(review): FullLoader can construct more than plain scalars and
        # containers. If configs may come from an untrusted repository,
        # consider yaml.safe_load here.
        data = yaml.load(bergamot_config_file, Loader=yaml.FullLoader)

    # An empty file parses to None and a config may omit the key entirely;
    # in both cases there is nothing to patch.
    precision = data.get("gemm-precision") if isinstance(data, dict) else None
    if isinstance(precision, str) and "int8" in precision:
        # platform.processor() may be an empty string when the value cannot
        # be determined, so consult platform.machine() as well. Compare
        # case-insensitively because the reported spelling varies by OS.
        arm64_names = ("arm64", "aarch64")
        machine = platform.machine().lower()
        processor = platform.processor().lower()
        if machine in arm64_names or processor in arm64_names:
            # The shifted int8 path is only available on Intel hardware, so
            # drop the "shift" (and accompanying "All") marker on ARM.
            data["gemm-precision"] = precision.replace("shift", "").replace(
                "All", ""
            )

    return yaml.dump(data, sort_keys=False)
|
||||||
|
Loading…
Reference in New Issue
Block a user