backend: port Replit to GGUF
commit 17fc9e3e58
parent 7c67262a13
@@ -97,11 +97,6 @@ foreach(BUILD_VARIANT IN LISTS BUILD_VARIANTS)
            LLAMA_VERSIONS=>=3 LLAMA_DATE=999999)
    prepare_target(llamamodel-mainline llama-mainline)

    add_library(replit-mainline-${BUILD_VARIANT} SHARED
        replit.cpp utils.h utils.cpp llmodel_shared.cpp llmodel_shared.h)
    target_compile_definitions(replit-mainline-${BUILD_VARIANT} PRIVATE LLAMA_VERSIONS=>=3 LLAMA_DATE=999999)
    prepare_target(replit-mainline llama-mainline)

    if (NOT LLAMA_METAL)
        # FIXME: These need to be forward ported to latest ggml
        # add_library(gptj-${BUILD_VARIANT} SHARED
@@ -97,6 +97,116 @@ enum mpt_token_type {
    MPT_TOKEN_TYPE_CONTROL = 3,
};

using replit_piece_t = std::pair<std::size_t, float>;
using replit_piece_map_t = std::unordered_map<std::string, replit_piece_t>;

static const std::string replit_ws_symbol = "\342\226\201";

struct mpt_vocab {
    bool is_replit = false;
    gpt_vocab raw;
    replit_piece_map_t piece_map;
    std::vector<std::string> vocab;

    const char * end_of_text() const {
        return is_replit ? "<|endoftext|>" : "<|im_end|>";
    }
};

std::pair<std::vector<LLModel::Token>, float> encode_word(const std::string & word, const replit_piece_map_t & model) {
    std::vector<int> best_segmentations_starts(word.length() + 1, -1);
    best_segmentations_starts[0] = 0;

    std::vector<float> best_segmentations_scores(word.length() + 1, -std::numeric_limits<float>::infinity());
    best_segmentations_scores[0] = 1.0;

    for (size_t start_idx = 0; start_idx < word.length(); ++start_idx) {
        float best_score_at_start = best_segmentations_scores[start_idx];
        for (size_t end_idx = start_idx + 1; end_idx <= word.length(); ++end_idx) {
            std::string token = word.substr(start_idx, end_idx - start_idx);
            if (model.count(token) && best_score_at_start != -std::numeric_limits<float>::infinity()) {
                float token_score = model.at(token).second;
                float score = token_score + best_score_at_start;
                if (best_segmentations_scores[end_idx] == -std::numeric_limits<float>::infinity() ||
                    best_segmentations_scores[end_idx] > score) {
                    best_segmentations_starts[end_idx] = start_idx;
                    best_segmentations_scores[end_idx] = score;
                }
            }
        }
    }

    if (best_segmentations_scores.back() == -std::numeric_limits<float>::infinity()) {
        return std::make_pair(std::vector<LLModel::Token>{0}, 0.0f);
    }

    float score = best_segmentations_scores.back();
    int start = best_segmentations_starts.back();
    int end = word.length();
    std::vector<LLModel::Token> tokens;
    while (start != 0) {
        const auto token_id = model.at(word.substr(start, end - start)).first;
        tokens.insert(tokens.begin(), token_id);
        int next_start = best_segmentations_starts[start];
        end = start;
        start = next_start;
    }
    const auto token_id = model.at(word.substr(start, end - start)).first;
    tokens.insert(tokens.begin(), token_id);
    return std::make_pair(tokens, score);
}

bool replit_tokenizer_load(mpt_vocab & tokenizer, gguf_context * ggufctx, int tokens_keyidx, int max_vocab_size) {
    int scores_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.scores");
    if (scores_keyidx == -1) {
        fprintf(stderr, "%s: llama token scores not found!\n", __func__);
        return false;
    }
    const auto *scores = reinterpret_cast<const float *>(gguf_get_arr_data(ggufctx, scores_keyidx));

    for (LLModel::Token i = 0; i < max_vocab_size; i++) {
        std::string word = gguf_get_arr_str(ggufctx, tokens_keyidx, i);
        tokenizer.piece_map[word] = std::make_pair(i, -scores[i]);
        tokenizer.raw.id_to_token[i] = word;
        tokenizer.raw.token_to_id[word] = i;
    }

    return true;
}

std::string replace_all(const std::string & str,    // where to work
                        const std::string & find,   // substitute 'find'
                        const std::string & replace // by 'replace'
) {
    std::string result;
    size_t find_len = find.size();
    size_t pos, from = 0;
    while (std::string::npos != (pos = str.find(find, from))) {
        result.append(str, from, pos - from);
        result.append(replace);
        from = pos + find_len;
    }
    result.append(str, from, std::string::npos);
    return result;
}

std::vector<LLModel::Token> replit_tokenizer_tokenize(mpt_vocab & tokenizer, const std::string & text) {
    std::vector<LLModel::Token> tokens;
    auto normalized_text = replace_all(text, " ", replit_ws_symbol);
    auto tokenized = encode_word(normalized_text, tokenizer.piece_map);

    return tokenized.first;
}

std::string replit_tokenizer_detokenize(mpt_vocab & tokenizer, const std::vector<LLModel::Token> & tokens) {
    std::string text;
    for (auto token : tokens) {
        text += tokenizer.raw.id_to_token[token];
    }
    return replace_all(text, replit_ws_symbol, " ");
}


static bool kv_cache_init(
        const struct mpt_hparams & hparams,
             struct llm_kv_cache & cache,
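The encode_word routine added above is a dynamic-programming (Viterbi-style) unigram segmentation over the SentencePiece piece map: each string position keeps the lowest accumulated score and the start of the piece that achieved it, and the token ids are then recovered by walking backwards. A minimal Python sketch of the same idea (a hypothetical helper, not part of this commit) with piece_map mapping a piece string to a (token_id, score) pair, lower scores preferred, as in the C++ comparison:

def segment_word(word, piece_map):
    INF = float("inf")
    best_start = [-1] * (len(word) + 1)   # start index of the best piece ending at each position
    best_score = [INF] * (len(word) + 1)  # best accumulated score up to each position
    best_start[0], best_score[0] = 0, 0.0

    for start in range(len(word)):
        if best_score[start] == INF:
            continue  # this position cannot be reached by any segmentation
        for end in range(start + 1, len(word) + 1):
            piece = word[start:end]
            if piece in piece_map:
                score = best_score[start] + piece_map[piece][1]
                if score < best_score[end]:
                    best_start[end], best_score[end] = start, score

    if best_score[-1] == INF:
        return [0], 0.0  # no segmentation found; same fallback as the C++ code

    tokens, end = [], len(word)
    while end > 0:  # walk back through best_start to recover the token ids
        start = best_start[end]
        tokens.insert(0, piece_map[word[start:end]][0])
        end = start
    return tokens, best_score[-1]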
@@ -130,7 +240,7 @@ static bool kv_cache_init(

// load the model's weights from a file path. if mem_req ptr is passed the model is
// only partially parsed to estimate required memory
bool mpt_model_load(const std::string &fname, mpt_model & model, gpt_vocab & vocab, size_t * mem_req) {
bool mpt_model_load(const std::string &fname, mpt_model & model, mpt_vocab & vocab, size_t * mem_req) {
    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
    if (mem_req != nullptr) {
        *mem_req = 0;
@@ -245,53 +355,56 @@ bool mpt_model_load(const std::string &fname, mpt_model & model, gpt_vocab & voc
    {
        auto & hparams = model.hparams;

        int keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.model");
        int tokens_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.tokens");
        if (tokens_keyidx == -1) {
            fprintf(stderr, "%s: tokenizer vocab not found!\n", __func__);
            return false;
        }

        int keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.model");
        if (keyidx == -1) {
            fprintf(stderr, "%s: tokenizer model not found!\n", __func__);
            return false;
        }
        // TODO: Replit (llama tokenizer)
        if (strcmp(gguf_get_val_str(ggufctx, keyidx), "gpt2") != 0) {
            fprintf(stderr, "%s: tokenizer model not supported!\n", __func__);
            return false;
        }


        int tokens_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.tokens");
        if (tokens_keyidx == -1) {
            fprintf(stderr, "%s: gpt2 tokenizer vocab not found!\n", __func__);
            return false;
        }

        int toktypes_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.token_type");
        if (toktypes_keyidx == -1) {
            fprintf(stderr, "%s: gpt2 token types not found!\n", __func__);
            return false;
        }
        std::string tokenizer_model(gguf_get_val_str(ggufctx, keyidx));

        hparams.n_vocab = gguf_get_arr_n(ggufctx, tokens_keyidx);
        printf("%s: gpt2 tokenizer vocab = %d\n", __func__, int(hparams.n_vocab));
        printf("%s: %s tokenizer vocab = %d\n", __func__, tokenizer_model.c_str(), int(hparams.n_vocab));

        const auto *toktypes = reinterpret_cast<const uint32_t *>(gguf_get_arr_data(ggufctx, toktypes_keyidx));

        for (int i = 0; i < hparams.n_vocab; i++) {
            std::string word = gguf_get_arr_str(ggufctx, tokens_keyidx, i);

            bool special = false;
            if (toktypes[i] == MPT_TOKEN_TYPE_CONTROL) {
                special = true;
            } else if (toktypes[i] != MPT_TOKEN_TYPE_NORMAL) {
                fprintf(stderr, "%s: unknown token type: %d\n", __func__, int(toktypes[i]));
        if (tokenizer_model == "llama") { // Replit
            vocab.is_replit = true;
            if (!replit_tokenizer_load(vocab, ggufctx, tokens_keyidx, hparams.n_vocab)) {
                return false;
            }

            vocab.token_to_id[word] = i;
            vocab.id_to_token[i] = word;

            if (special) {
                vocab.add_special_token(word);
        } else if (tokenizer_model == "gpt2") {
            int toktypes_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.token_type");
            if (toktypes_keyidx == -1) {
                fprintf(stderr, "%s: gpt2 token types not found!\n", __func__);
                return false;
            }
            const auto *toktypes = reinterpret_cast<const uint32_t *>(gguf_get_arr_data(ggufctx, toktypes_keyidx));

            for (int i = 0; i < hparams.n_vocab; i++) {
                std::string word = gguf_get_arr_str(ggufctx, tokens_keyidx, i);

                bool special = false;
                if (toktypes[i] == MPT_TOKEN_TYPE_CONTROL) {
                    special = true;
                } else if (toktypes[i] != MPT_TOKEN_TYPE_NORMAL) {
                    fprintf(stderr, "%s: unknown token type: %d\n", __func__, int(toktypes[i]));
                    return false;
                }

                vocab.raw.token_to_id[word] = i;
                vocab.raw.id_to_token[i] = word;

                if (special) {
                    vocab.raw.add_special_token(word);
                }
            }
        } else {
            fprintf(stderr, "%s: tokenizer model not supported!\n", __func__);
            return false;
        }
    }

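The hunk above dispatches on the GGUF key tokenizer.ggml.model: "llama" selects the new Replit SentencePiece path, which also reads tokenizer.ggml.tokens and tokenizer.ggml.scores, while "gpt2" keeps the existing BPE path, which reads tokenizer.ggml.tokens and tokenizer.ggml.token_type. A minimal sketch of producing the Replit-side metadata with the gguf package, using only the writer calls that the conversion script later in this commit also uses (the file name, pieces, and scores below are illustrative only, and no tensors are written):

import gguf

# Toy piece list and scores; a real converter reads these from spiece.model.
pieces = ["<|endoftext|>", "\u2581def", "\u2581return"]
scores = [0.0, -2.5, -3.1]

gguf_writer = gguf.GGUFWriter("tiny-replit-vocab.gguf", gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.MPT])
gguf_writer.add_tokenizer_model("llama")                         # selects the Replit path in mpt_model_load
gguf_writer.add_token_list([p.encode("utf-8") for p in pieces])  # tokenizer.ggml.tokens
gguf_writer.add_token_scores(scores)                             # tokenizer.ggml.scores
gguf_writer.write_header_to_file()
gguf_writer.write_kv_data_to_file()
gguf_writer.write_tensors_to_file()
gguf_writer.close()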
@@ -675,7 +788,7 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
struct MPTPrivate {
    const std::string modelPath;
    bool modelLoaded;
    gpt_vocab vocab;
    mpt_vocab vocab;
    mpt_model *model = nullptr;
    int64_t n_threads = 0;
    size_t mem_per_token = 0;
@@ -692,7 +805,7 @@ MPT::MPT()

size_t MPT::requiredMem(const std::string &modelPath) {
    mpt_model dummy_model;
    gpt_vocab dummy_vocab;
    mpt_vocab dummy_vocab;
    size_t mem_req;
    mpt_model_load(modelPath, dummy_model, dummy_vocab, &mem_req);
    return mem_req;
@@ -710,7 +823,8 @@ bool MPT::loadModel(const std::string &modelPath) {

    d_ptr->n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
    d_ptr->modelLoaded = true;
    d_ptr->has_end_of_text = d_ptr->vocab.token_to_id.find("<|im_end|>") != d_ptr->vocab.token_to_id.end();
    const auto & vocab = d_ptr->vocab;
    d_ptr->has_end_of_text = vocab.raw.token_to_id.find(vocab.end_of_text()) != vocab.raw.token_to_id.end();
    fflush(stdout);
    return true;
}
@@ -751,12 +865,18 @@ size_t MPT::restoreState(const uint8_t *src)

std::vector<LLModel::Token> MPT::tokenize(PromptContext &, const std::string &str) const
{
    return ::gpt_tokenize(d_ptr->vocab, str);
    if (d_ptr->vocab.is_replit) {
        return replit_tokenizer_tokenize(d_ptr->vocab, str);
    }
    return ::gpt_tokenize(d_ptr->vocab.raw, str);
}

std::string MPT::tokenToString(Token id) const
{
    return d_ptr->vocab.id_to_token[id];
    if (d_ptr->vocab.is_replit) {
        return replit_tokenizer_detokenize(d_ptr->vocab, {id});
    }
    return d_ptr->vocab.raw.id_to_token[id];
}

LLModel::Token MPT::sampleToken(PromptContext &promptCtx) const
@@ -791,7 +911,10 @@ int32_t MPT::contextLength() const

const std::vector<LLModel::Token> &MPT::endTokens() const
{
    static const std::vector<LLModel::Token> fres = {0, d_ptr->vocab.token_to_id["<|im_end|>"]};
    static std::vector<LLModel::Token> fres;
    if (fres.empty()) {
        fres = {0, d_ptr->vocab.raw.token_to_id[d_ptr->vocab.end_of_text()]};
    }
    return fres;
}
File diff suppressed because it is too large
@@ -1,113 +0,0 @@
from pathlib import Path
import sys
import struct
import json
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import sentencepiece.sentencepiece_model_pb2 as model

if len(sys.argv) < 3:
    print("Usage: convert-h5-to-ggml.py dir-model [use-f32]\n")
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)


# output in the same directory as the model
dir_model = sys.argv[1]
fname_out = sys.argv[1] + "/ggml-replit-code-v1-3b.bin"


with open(dir_model + "/config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

sp_proto = model.ModelProto()
sp_proto.ParseFromString(open(Path(sys.argv[1]) / "spiece.model", "rb").read())


# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]

ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])
    if ftype < 0 or ftype > 1:
        print("Invalid ftype: " + str(ftype))
        sys.exit(1)
    fname_out = sys.argv[1] + "/ggml-replit-code-v1-3b-" + ftype_str[ftype] + ".bin"


tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    dir_model, low_cpu_mem_usage=True, trust_remote_code=True
)
# print (model)

# print(tokenizer.encode('I believe the meaning of life is'))

list_vars = model.state_dict()
for name in list_vars.keys():
    print(name, list_vars[name].shape, list_vars[name].dtype)

fout = open(fname_out, "wb")

print(hparams)

fout.write(struct.pack("i", 0x7265706c))  # magic: repl in hex
fout.write(struct.pack("i", hparams["vocab_size"]))
fout.write(struct.pack("i", hparams["max_seq_len"]))
fout.write(struct.pack("i", hparams["d_model"]))
fout.write(struct.pack("i", hparams["n_heads"]))
fout.write(struct.pack("i", hparams["n_layers"]))
fout.write(struct.pack("i", ftype))


# TODO: temporary hack to not deal with implementing the tokenizer
for piece in sp_proto.pieces:
    encoded_piece = piece.piece.encode("utf-8")
    fout.write(struct.pack("i", len(encoded_piece)))
    fout.write(encoded_piece)
    fout.write(struct.pack("f", piece.score))


for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    print("Processing variable: " + name + " with shape: ", data.shape)

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype_cur = 0
    if ftype != 0:
        if name[-7:] == ".weight" and n_dims == 2:
            print("  Converting to float16")
            data = data.astype(np.float16)
            ftype_cur = 1
        else:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0
    else:
        if data.dtype != np.float32:
            print("  Converting to float32")
            data = data.astype(np.float32)
            ftype_cur = 0

    # header
    str = name.encode("utf-8")
    fout.write(struct.pack("iii", n_dims, len(str), ftype_cur))
    for i in range(n_dims):
        fout.write(struct.pack("i", data.shape[n_dims - 1 - i]))
    fout.write(str)

    # data
    data.tofile(fout)

fout.close()

print("Done. Output file: " + fname_out)
print("")
gpt4all-backend/scripts/convert_replit_hf_to_gguf.py (new file, 142 lines)
@@ -0,0 +1,142 @@
from __future__ import annotations

import json
import os
import struct
import sys
from pathlib import Path

import gguf
import numpy as np
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM, AutoTokenizer


if not 2 <= len(sys.argv) < 4:
    print("Usage: {} dir-model [ftype]\n".format(os.path.basename(__file__)))
    print("  ftype == 0 -> float32")
    print("  ftype == 1 -> float16")
    sys.exit(1)

# output in the same directory as the model
dir_model = Path(sys.argv[1])

# possible data types
#   ftype == 0 -> float32
#   ftype == 1 -> float16
#
# map from ftype to string
ftype_str = ["f32", "f16"]
ftype = 1
if len(sys.argv) > 2:
    ftype = int(sys.argv[2])
    if ftype < 0 or ftype > 1:
        print("Invalid ftype: " + str(ftype))
        sys.exit(1)

fname_out = dir_model / ("ggml-replit-code-v1-3b-" + ftype_str[ftype] + ".gguf")


ARCH = gguf.MODEL_ARCH.MPT
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

model = AutoModelForCausalLM.from_pretrained(dir_model, low_cpu_mem_usage=True)
config = model.config
#print(model)

block_count = config.n_layers
gguf_writer.add_name("Replit")
gguf_writer.add_context_length(config.max_seq_len)
gguf_writer.add_embedding_length(config.d_model)
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(config.n_heads)
gguf_writer.add_max_alibi_bias(config.attn_config.alibi_bias_max)
gguf_writer.add_layer_norm_eps(config.layer_norm_epsilon)
gguf_writer.add_file_type(ftype)

clip_qkv = config.attn_config.clip_qkv
if clip_qkv is not None:
    gguf_writer.add_clamp_kqv(clip_qkv)

print("gguf: get sentencepiece tokenizer vocab")

tokenizer = SentencePieceProcessor(str(dir_model / "spiece.model"))
#print(tokenizer.encode('I believe the meaning of life is'))

tokens: list[bytearray] = []
scores: list[float] = []
toktypes: list[int] = []

for i in range(tokenizer.vocab_size()):
    tokens.append(tokenizer.id_to_piece(i).encode('utf-8'))
    scores.append(tokenizer.get_score(i))

    toktype = gguf.TokenType.NORMAL
    if tokenizer.is_unknown(i):
        toktype = gguf.TokenType.UNKNOWN
    elif tokenizer.is_control(i):
        toktype = gguf.TokenType.CONTROL
    elif tokenizer.is_unused(i):
        toktype = gguf.TokenType.UNUSED
    elif tokenizer.is_byte(i):
        toktype = gguf.TokenType.BYTE

    toktypes.append(toktype)

gguf_writer.add_tokenizer_model("llama")  # sentencepiece
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)

special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(gguf_writer)

print("gguf: get tensor metadata")

tensor_map = gguf.get_tensor_name_map(ARCH, block_count)

list_vars = model.state_dict()
for name in list_vars.keys():
    print(name, list_vars[name].shape, list_vars[name].dtype)

print(config)

for name in list_vars.keys():
    data = list_vars[name].squeeze().numpy()
    print("Processing variable:", name, "with shape:", data.shape)

    n_dims = len(data.shape)

    # ftype == 0 -> float32, ftype == 1 -> float16
    ftype_cur = 0
    if ftype == 1 and name[-7:] == ".weight" and n_dims == 2:
        print("  Converting to float16")
        data = data.astype(np.float16)
        ftype_cur = 1
    elif ftype == 1 or data.dtype != np.float32:
        print("  Converting to float32")
        data = data.astype(np.float32)
        ftype_cur = 0

    # map tensor names
    new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
    if new_name is None:
        print("Can not map tensor '" + name + "'")
        sys.exit()

    gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
print("gguf: write tensors")
gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print()
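Per the argument handling at the top of the script, it would typically be run as python convert_replit_hf_to_gguf.py <model-dir> [ftype], writing ggml-replit-code-v1-3b-f16.gguf (for the default ftype of 1) into the model directory. A quick, illustrative sanity check on the output, with a hypothetical path: every GGUF file begins with the four-byte magic b"GGUF".

from pathlib import Path

# Hypothetical model directory; the converter writes its output next to the checkpoint.
out = Path("/path/to/replit-code-v1-3b") / "ggml-replit-code-v1-3b-f16.gguf"
with open(out, "rb") as f:
    assert f.read(4) == b"GGUF", "not a GGUF file"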