diff --git a/gpt4all-bindings/golang/binding.cpp b/gpt4all-bindings/golang/binding.cpp index 739d2524..dee592b5 100644 --- a/gpt4all-bindings/golang/binding.cpp +++ b/gpt4all-bindings/golang/binding.cpp @@ -2,14 +2,11 @@ #include "../../gpt4all-backend/llmodel.h" #include "../../gpt4all-backend/llama.cpp/llama.h" #include "../../gpt4all-backend/llmodel_c.cpp" -#include "../../gpt4all-backend/mpt.h" -#include "../../gpt4all-backend/mpt.cpp" -#include "../../gpt4all-backend/llamamodel.h" -#include "../../gpt4all-backend/gptj.h" #include "binding.h" #include #include +#include #include #include #include @@ -19,46 +16,24 @@ #include #include -void* load_mpt_model(const char *fname, int n_threads) { +void* load_gpt4all_model(const char *fname, int n_threads) { // load the model - auto gptj = llmodel_mpt_create(); - - llmodel_setThreadCount(gptj, n_threads); - if (!llmodel_loadModel(gptj, fname)) { + auto gptj4all = llmodel_model_create(fname); + if (gptj4all == NULL ){ + return nullptr; + } + llmodel_setThreadCount(gptj4all, n_threads); + if (!llmodel_loadModel(gptj4all, fname)) { return nullptr; } - return gptj; -} - -void* load_llama_model(const char *fname, int n_threads) { - // load the model - auto gptj = llmodel_llama_create(); - - llmodel_setThreadCount(gptj, n_threads); - if (!llmodel_loadModel(gptj, fname)) { - return nullptr; - } - - return gptj; -} - -void* load_gptj_model(const char *fname, int n_threads) { - // load the model - auto gptj = llmodel_gptj_create(); - - llmodel_setThreadCount(gptj, n_threads); - if (!llmodel_loadModel(gptj, fname)) { - return nullptr; - } - - return gptj; + return gptj4all; } std::string res = ""; void * mm; -void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, +void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, float top_p, float temp, int n_batch,float ctx_erase) { llmodel_model* model = (llmodel_model*) m; @@ -120,8 +95,8 @@ void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_la free(prompt_context); } -void gptj_free_model(void *state_ptr) { +void gpt4all_free_model(void *state_ptr) { llmodel_model* ctx = (llmodel_model*) state_ptr; - llmodel_llama_destroy(ctx); + llmodel_model_destroy(*ctx); } diff --git a/gpt4all-bindings/golang/binding.h b/gpt4all-bindings/golang/binding.h index 6b49a03e..2680e5a0 100644 --- a/gpt4all-bindings/golang/binding.h +++ b/gpt4all-bindings/golang/binding.h @@ -4,16 +4,12 @@ extern "C" { #include -void* load_mpt_model(const char *fname, int n_threads); +void* load_gpt4all_model(const char *fname, int n_threads); -void* load_llama_model(const char *fname, int n_threads); - -void* load_gptj_model(const char *fname, int n_threads); - -void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, +void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, float top_p, float temp, int n_batch,float ctx_erase); -void gptj_free_model(void *state_ptr); +void gpt4all_free_model(void *state_ptr); extern unsigned char getTokenCallback(void *, char *); diff --git a/gpt4all-bindings/golang/example/main.go b/gpt4all-bindings/golang/example/main.go index f3a103a7..2e692927 100644 --- a/gpt4all-bindings/golang/example/main.go +++ b/gpt4all-bindings/golang/example/main.go @@ -30,7 +30,7 @@ func main() { fmt.Printf("Parsing program arguments failed: %s", err) os.Exit(1) } - l, err := gpt4all.New(model, gpt4all.SetModelType(gpt4all.GPTJType), gpt4all.SetThreads(threads)) + l, err := gpt4all.New(model, gpt4all.SetThreads(threads)) if err != nil { fmt.Println("Loading the model failed:", err.Error()) os.Exit(1) diff --git a/gpt4all-bindings/golang/gpt4all.go b/gpt4all-bindings/golang/gpt4all.go index b0df6107..fa1efe22 100644 --- a/gpt4all-bindings/golang/gpt4all.go +++ b/gpt4all-bindings/golang/gpt4all.go @@ -5,12 +5,10 @@ package gpt4all // #cgo darwin LDFLAGS: -framework Accelerate // #cgo darwin CXXFLAGS: -std=c++17 // #cgo LDFLAGS: -lgpt4all -lm -lstdc++ -// void* load_mpt_model(const char *fname, int n_threads); -// void* load_llama_model(const char *fname, int n_threads); -// void* load_gptj_model(const char *fname, int n_threads); -// void gptj_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, +// void* load_gpt4all_model(const char *fname, int n_threads); +// void gpt4all_model_prompt( const char *prompt, void *m, char* result, int repeat_last_n, float repeat_penalty, int n_ctx, int tokens, int top_k, // float top_p, float temp, int n_batch,float ctx_erase); -// void gptj_free_model(void *state_ptr); +// void gpt4all_free_model(void *state_ptr); // extern unsigned char getTokenCallback(void *, char *); import "C" import ( @@ -28,16 +26,8 @@ type Model struct { func New(model string, opts ...ModelOption) (*Model, error) { ops := NewModelOptions(opts...) - var state unsafe.Pointer - switch ops.ModelType { - case LLaMAType: - state = C.load_llama_model(C.CString(model), C.int(ops.Threads)) - case GPTJType: - state = C.load_gptj_model(C.CString(model), C.int(ops.Threads)) - case MPTType: - state = C.load_mpt_model(C.CString(model), C.int(ops.Threads)) - } + state := C.load_gpt4all_model(C.CString(model), C.int(ops.Threads)) if state == nil { return nil, fmt.Errorf("failed loading model") @@ -62,7 +52,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) { } out := make([]byte, po.Tokens) - C.gptj_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize), + C.gpt4all_model_prompt(input, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.int(po.RepeatLastN), C.float(po.RepeatPenalty), C.int(po.ContextSize), C.int(po.Tokens), C.int(po.TopK), C.float(po.TopP), C.float(po.Temperature), C.int(po.Batch), C.float(po.ContextErase)) res := C.GoString((*C.char)(unsafe.Pointer(&out[0]))) @@ -75,7 +65,7 @@ func (l *Model) Predict(text string, opts ...PredictOption) (string, error) { } func (l *Model) Free() { - C.gptj_free_model(l.state) + C.gpt4all_free_model(l.state) } func (l *Model) SetTokenCallback(callback func(token string) bool) { diff --git a/gpt4all-bindings/golang/gpt4all_test.go b/gpt4all-bindings/golang/gpt4all_test.go index 1d99dd66..fd96584c 100644 --- a/gpt4all-bindings/golang/gpt4all_test.go +++ b/gpt4all-bindings/golang/gpt4all_test.go @@ -13,15 +13,5 @@ var _ = Describe("LLama binding", func() { Expect(err).To(HaveOccurred()) Expect(model).To(BeNil()) }) - It("fails with no model", func() { - model, err := New("not-existing", SetModelType(MPTType)) - Expect(err).To(HaveOccurred()) - Expect(model).To(BeNil()) - }) - It("fails with no model", func() { - model, err := New("not-existing", SetModelType(LLaMAType)) - Expect(err).To(HaveOccurred()) - Expect(model).To(BeNil()) - }) }) }) diff --git a/gpt4all-bindings/golang/options.go b/gpt4all-bindings/golang/options.go index 573f9abc..973d88e1 100644 --- a/gpt4all-bindings/golang/options.go +++ b/gpt4all-bindings/golang/options.go @@ -20,24 +20,14 @@ var DefaultOptions PredictOptions = PredictOptions{ } var DefaultModelOptions ModelOptions = ModelOptions{ - Threads: 4, - ModelType: GPTJType, + Threads: 4, } type ModelOptions struct { - Threads int - ModelType ModelType + Threads int } type ModelOption func(p *ModelOptions) -type ModelType int - -const ( - LLaMAType ModelType = 0 - GPTJType ModelType = iota - MPTType ModelType = iota -) - // SetTokens sets the number of tokens to generate. func SetTokens(tokens int) PredictOption { return func(p *PredictOptions) { @@ -110,13 +100,6 @@ func SetThreads(c int) ModelOption { } } -// SetModelType sets the model type. -func SetModelType(c ModelType) ModelOption { - return func(p *ModelOptions) { - p.ModelType = c - } -} - // Create a new PredictOptions object with the given options. func NewModelOptions(opts ...ModelOption) ModelOptions { p := DefaultModelOptions