From 85dfdcd90ea6e0298f8970361def033e8c47d9b3 Mon Sep 17 00:00:00 2001 From: Silas Marvin <19626586+SilasMarvin@users.noreply.github.com> Date: Sun, 14 Jan 2024 14:41:38 -0800 Subject: [PATCH] Added some vscode stuff and got it working nicely with python --- .gitignore | 2 + .vscode/launch.json | 22 ++ .vscode/task.json | 17 + .vscode/tasks.json | 17 + Cargo.lock | 618 +++++-------------------------- Cargo.toml | 13 +- build.rs | 51 +++ editors/vscode/package-lock.json | 146 ++++++++ editors/vscode/package.json | 23 ++ editors/vscode/src/index.ts | 49 +++ editors/vscode/tsconfig.json | 13 + run.sh | 3 - src/main.rs | 83 ++++- src/models/mod.rs | 32 -- src/models/starcoder.rs | 119 ------ src/python/transformers.py | 44 +++ 16 files changed, 539 insertions(+), 713 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 .vscode/task.json create mode 100644 .vscode/tasks.json create mode 100644 build.rs create mode 100644 editors/vscode/package-lock.json create mode 100644 editors/vscode/package.json create mode 100644 editors/vscode/src/index.ts create mode 100644 editors/vscode/tsconfig.json delete mode 100755 run.sh delete mode 100644 src/models/mod.rs delete mode 100644 src/models/starcoder.rs create mode 100644 src/python/transformers.py diff --git a/.gitignore b/.gitignore index 12ab8d8..a7196cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ /target /models +node_modules +out diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..e800ed4 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,22 @@ +{ + "version": "2.0.0", + "configurations": [ + { + "name": "Run Installed Extension", + "type": "extensionHost", + "request": "launch", + "runtimeExecutable": "${execPath}", + "args": [ + "--disable-extensions", + "--extensionDevelopmentPath=${workspaceFolder}/editors/vscode" + ], + "outFiles": [ + "${workspaceFolder}/editors/vscode/out/**/*.js" + ], + "preLaunchTask": "Build Extension", + "skipFiles": [ + "/**/*.js" + ] + } + ] +} diff --git a/.vscode/task.json b/.vscode/task.json new file mode 100644 index 0000000..9e5219c --- /dev/null +++ b/.vscode/task.json @@ -0,0 +1,17 @@ +{ + "version": "0.1.0", + "tasks": [ + { + "label": "Build Extension", + "group": "build", + "type": "npm", + "script": "build", + "path": "editors/vscode/", + "problemMatcher": { + "base": "$tsc-watch", + "fileLocation": ["relative", "${workspaceFolder}/editors/vscode/"] + }, + "isBackground": true + } + ] +} diff --git a/.vscode/tasks.json b/.vscode/tasks.json new file mode 100644 index 0000000..e810e1c --- /dev/null +++ b/.vscode/tasks.json @@ -0,0 +1,17 @@ +{ + "version": "2.0.0", + "tasks": [ + { + "label": "Build Extension", + "group": "build", + "type": "npm", + "script": "build", + "path": "editors/vscode/", + "problemMatcher": { + "base": "$tsc-watch", + "fileLocation": ["relative", "${workspaceFolder}/editors/vscode/"] + }, + "isBackground": true + } + ] +} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 5273441..0d6d8c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,15 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "addr2line" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" -dependencies = [ - "gimli", -] - [[package]] name = "adler" version = "1.0.2" @@ -79,9 +70,6 @@ name = "anyhow" version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" -dependencies = [ - "backtrace", -] [[package]] name = "autocfg" @@ -89,21 +77,6 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" -[[package]] -name = "backtrace" -version = "0.3.69" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" -dependencies = [ - "addr2line", - "cc", - "cfg-if", - "libc", - "miniz_oxide", - "object", - "rustc-demangle", -] - [[package]] name = "base64" version = "0.13.1" @@ -128,92 +101,12 @@ version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" -[[package]] -name = "bytemuck" -version = "1.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" -dependencies = [ - "bytemuck_derive", -] - -[[package]] -name = "bytemuck_derive" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - [[package]] name = "byteorder" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" -[[package]] -name = "candle-core" -version = "0.3.1" -dependencies = [ - "byteorder", - "candle-kernels", - "cudarc", - "gemm", - "half", - "memmap2", - "num-traits", - "num_cpus", - "rand", - "rand_distr", - "rayon", - "safetensors", - "thiserror", - "yoke", - "zip", -] - -[[package]] -name = "candle-kernels" -version = "0.3.1" -dependencies = [ - "anyhow", - "glob", - "rayon", -] - -[[package]] -name = "candle-nn" -version = "0.3.1" -dependencies = [ - "candle-core", - "half", - "num-traits", - "rayon", - "safetensors", - "serde", - "thiserror", -] - -[[package]] -name = "candle-transformers" -version = "0.3.1" -dependencies = [ - "byteorder", - "candle-core", - "candle-nn", - "num-traits", - "rand", - "rayon", - "serde", - "serde_json", - "serde_plain", - "tracing", - "wav", -] - [[package]] name = "cc" version = "1.0.83" @@ -356,21 +249,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crunchy" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" - -[[package]] -name = "cudarc" -version = "0.9.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1871a911a2b9a3f66a285896a719159985683bf9903aa2cf89e0c9f53e14552" -dependencies = [ - "half", -] - [[package]] name = "darling" version = "0.14.4" @@ -437,6 +315,15 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "directories" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" +dependencies = [ + "dirs-sys", +] + [[package]] name = "dirs" version = "5.0.1" @@ -458,16 +345,6 @@ dependencies = [ "windows-sys 0.48.0", ] -[[package]] -name = "dyn-stack" -version = "0.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" -dependencies = [ - "bytemuck", - "reborrow", -] - [[package]] name = "either" version = "1.9.0" @@ -545,123 +422,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "gemm" -version = "0.16.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b3afa707040531a7527477fd63a81ea4f6f3d26037a2f96776e57fb843b258e" -dependencies = [ - "dyn-stack", - "gemm-c32", - "gemm-c64", - "gemm-common", - "gemm-f16", - "gemm-f32", - "gemm-f64", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", -] - -[[package]] -name = "gemm-c32" -version = "0.16.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1cc3973a4c30c73f26a099113953d0c772bb17ee2e07976c0a06b8fe1f38a57d" -dependencies = [ - "dyn-stack", - "gemm-common", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", -] - -[[package]] -name = "gemm-c64" -version = "0.16.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30362894b93dada374442cb2edf4512ddf19513c9bec88e06a445bcb6b22e64f" -dependencies = [ - "dyn-stack", - "gemm-common", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", -] - -[[package]] -name = "gemm-common" -version = "0.16.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "988499faa80566b046b4fee2c5f15af55b5a20c1fe8486b112ebb34efa045ad6" -dependencies = [ - "bytemuck", - "dyn-stack", - "half", - "num-complex", - "num-traits", - "once_cell", - "paste", - "pulp", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "gemm-f16" -version = "0.16.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6cf2854a12371684c38d9a865063a27661812a3ff5803454c5742e8f5a388ce" -dependencies = [ - "dyn-stack", - "gemm-common", - "gemm-f32", - "half", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "gemm-f32" -version = "0.16.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bc84003cf6d950a7c7ca714ad6db281b6cef5c7d462f5cd9ad90ea2409c7227" -dependencies = [ - "dyn-stack", - "gemm-common", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", -] - -[[package]] -name = "gemm-f64" -version = "0.16.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35187ef101a71eed0ecd26fb4a6255b4192a12f1c5335f3a795698f2d9b6cf33" -dependencies = [ - "dyn-stack", - "gemm-common", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "seq-macro", -] - [[package]] name = "getrandom" version = "0.2.11" @@ -673,44 +433,12 @@ dependencies = [ "wasi", ] -[[package]] -name = "gimli" -version = "0.28.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" - -[[package]] -name = "glob" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" - -[[package]] -name = "half" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" -dependencies = [ - "bytemuck", - "cfg-if", - "crunchy", - "num-traits", - "rand", - "rand_distr", -] - [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "hermit-abi" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" - [[package]] name = "hf-hub" version = "0.3.2" @@ -756,6 +484,12 @@ dependencies = [ "unicode-width", ] +[[package]] +name = "indoc" +version = "2.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e186cfbae8084e513daff4240b4797e342f988cecda4fb6c939150f96315fd8" + [[package]] name = "instant" version = "0.1.12" @@ -792,12 +526,6 @@ version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" -[[package]] -name = "libm" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" - [[package]] name = "libredox" version = "0.0.1" @@ -836,13 +564,13 @@ name = "lsp-ai" version = "0.1.0" dependencies = [ "anyhow", - "candle-core", - "candle-nn", - "candle-transformers", + "directories", "hf-hub", "lsp-server", "lsp-types", + "once_cell", "parking_lot", + "pyo3", "rand", "ropey", "serde", @@ -897,16 +625,6 @@ version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" -[[package]] -name = "memmap2" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" -dependencies = [ - "libc", - "stable_deref_trait", -] - [[package]] name = "memoffset" version = "0.9.0" @@ -980,56 +698,17 @@ dependencies = [ "minimal-lexical", ] -[[package]] -name = "num-complex" -version = "0.4.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" -dependencies = [ - "bytemuck", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" -dependencies = [ - "autocfg", - "libm", -] - -[[package]] -name = "num_cpus" -version = "1.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" -dependencies = [ - "hermit-abi", - "libc", -] - [[package]] name = "number_prefix" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" -[[package]] -name = "object" -version = "0.32.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" -dependencies = [ - "memchr", -] - [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "onig" @@ -1138,12 +817,6 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" -[[package]] -name = "pin-project-lite" -version = "0.2.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" - [[package]] name = "pkg-config" version = "0.3.27" @@ -1172,15 +845,64 @@ dependencies = [ ] [[package]] -name = "pulp" -version = "0.18.6" +name = "pyo3" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16785ee69419641c75affff7c9fdbdb7c0ab26dc9a5fb5218c2a2e9e4ef2087d" +checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" dependencies = [ - "bytemuck", - "libm", - "num-complex", - "reborrow", + "cfg-if", + "indoc", + "libc", + "memoffset", + "parking_lot", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.38", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.38", ] [[package]] @@ -1222,25 +944,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "rand_distr" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" -dependencies = [ - "num-traits", - "rand", -] - -[[package]] -name = "raw-cpuid" -version = "10.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" -dependencies = [ - "bitflags 1.3.2", -] - [[package]] name = "rayon" version = "1.8.0" @@ -1272,12 +975,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "reborrow" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" - [[package]] name = "redox_syscall" version = "0.4.1" @@ -1333,12 +1030,6 @@ version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" -[[package]] -name = "riff" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1" - [[package]] name = "ring" version = "0.17.5" @@ -1363,12 +1054,6 @@ dependencies = [ "str_indices", ] -[[package]] -name = "rustc-demangle" -version = "0.1.23" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" - [[package]] name = "rustix" version = "0.38.25" @@ -1410,16 +1095,6 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" -[[package]] -name = "safetensors" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" -dependencies = [ - "serde", - "serde_json", -] - [[package]] name = "schannel" version = "0.1.22" @@ -1468,12 +1143,6 @@ dependencies = [ "libc", ] -[[package]] -name = "seq-macro" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" - [[package]] name = "serde" version = "1.0.190" @@ -1505,15 +1174,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_plain" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" -dependencies = [ - "serde", -] - [[package]] name = "serde_repr" version = "0.1.17" @@ -1560,12 +1220,6 @@ dependencies = [ "unicode-segmentation", ] -[[package]] -name = "stable_deref_trait" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" - [[package]] name = "str_indices" version = "0.4.3" @@ -1601,16 +1255,10 @@ dependencies = [ ] [[package]] -name = "synstructure" -version = "0.13.0" +name = "target-lexicon" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", - "unicode-xid", -] +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "tempfile" @@ -1693,37 +1341,6 @@ dependencies = [ "unicode_categories", ] -[[package]] -name = "tracing" -version = "0.1.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" -dependencies = [ - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", -] - -[[package]] -name = "tracing-core" -version = "0.1.32" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" -dependencies = [ - "once_cell", -] - [[package]] name = "unicode-bidi" version = "0.3.13" @@ -1766,18 +1383,18 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" -[[package]] -name = "unicode-xid" -version = "0.2.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" - [[package]] name = "unicode_categories" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + [[package]] name = "untrusted" version = "0.9.0" @@ -1834,15 +1451,6 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" -[[package]] -name = "wav" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a65e199c799848b4f997072aa4d673c034f80f40191f97fe2f0a23f410be1609" -dependencies = [ - "riff", -] - [[package]] name = "webpki-roots" version = "0.25.3" @@ -2002,59 +1610,3 @@ name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" - -[[package]] -name = "yoke" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "65e71b2e4f287f467794c671e2b8f8a5f3716b3c829079a1c44740148eff07e4" -dependencies = [ - "serde", - "stable_deref_trait", - "yoke-derive", - "zerofrom", -] - -[[package]] -name = "yoke-derive" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e6936f0cce458098a201c245a11bef556c6a0181129c7034d10d76d1ec3a2b8" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", - "synstructure", -] - -[[package]] -name = "zerofrom" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "655b0814c5c0b19ade497851070c640773304939a6c0fd5f5fb43da0696d05b7" -dependencies = [ - "zerofrom-derive", -] - -[[package]] -name = "zerofrom-derive" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.38", - "synstructure", -] - -[[package]] -name = "zip" -version = "0.6.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" -dependencies = [ - "byteorder", - "crc32fast", - "crossbeam-utils", -] diff --git a/Cargo.toml b/Cargo.toml index 377eb31..0635ef6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,17 +12,16 @@ lsp-types = "0.94.1" ropey = "1.6.1" serde = "1.0.190" serde_json = "1.0.108" -# candle-core = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] } -# candle-nn = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] } -# candle-transformers = { git = "https://github.com/huggingface/candle/", version = "0.3.1", features = ["accelerate"] } -candle-core = { path = "../candle/candle-core" } -candle-nn = { path = "../candle/candle-nn" } -candle-transformers = { path = "../candle/candle-transformers" } hf-hub = { git = "https://github.com/huggingface/hf-hub", version = "0.3.2" } rand = "0.8.5" tokenizers = "0.14.1" parking_lot = "0.12.1" +once_cell = "1.19.0" +pyo3 = { version = "0.20.2", features = ["auto-initialize"] } +directories = "5.0.1" + +[build-dependencies] +directories = "5.0.1" [features] default = [] -cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..d6daff4 --- /dev/null +++ b/build.rs @@ -0,0 +1,51 @@ +use directories::ProjectDirs; +use std::fs; +use std::process::Command; + +fn main() { + let version = env!("CARGO_PKG_VERSION"); + eprintln!("Building lsp-ai - {version}"); + + // Create the project_dir + let project_dir = ProjectDirs::from("", "", "lsp-ai").expect("getting project directory"); + let config_dir = project_dir.config_dir(); + if !config_dir.exists() { + fs::create_dir(&config_dir).unwrap_or_else(|e| panic!("creating {config_dir:?} - {e}")); + } + + // Construct the venv + let venv_path = config_dir.join("venv"); + let output = Command::new("virtualenv") + .args([venv_path.as_os_str()]) + .args(["--clear"]) + .output() + .expect("running virtualenv command"); + if !output.status.success() { + eprintln!( + "{}", + String::from_utf8(output.stdout).expect("converting stdout to string") + ); + eprintln!( + "{}", + String::from_utf8(output.stderr).expect("converting stdout to string") + ); + } + + // Install the python dependencies + let pip_path = venv_path.join("bin").join("pip"); + let output = Command::new(pip_path.as_os_str()) + .arg("install") + .arg("llama-cpp-python") + .output() + .expect("running pip install"); + if !output.status.success() { + eprintln!( + "{}", + String::from_utf8(output.stdout).expect("converting stdout to string") + ); + eprintln!( + "{}", + String::from_utf8(output.stderr).expect("converting stdout to string") + ); + } +} diff --git a/editors/vscode/package-lock.json b/editors/vscode/package-lock.json new file mode 100644 index 0000000..14d7e8a --- /dev/null +++ b/editors/vscode/package-lock.json @@ -0,0 +1,146 @@ +{ + "name": "lsp-ai", + "version": "0.1.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "lsp-ai", + "version": "0.1.0", + "license": "MIT", + "dependencies": { + "@types/vscode": "^1.85.0", + "vscode-languageclient": "^9.0.1" + }, + "devDependencies": { + "@types/node": "^20.11.0", + "typescript": "^5.3.3" + }, + "engines": { + "vscode": "^1.75.0" + } + }, + "node_modules/@types/node": { + "version": "20.11.0", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.11.0.tgz", + "integrity": "sha512-o9bjXmDNcF7GbM4CNQpmi+TutCgap/K3w1JyKgxAjqx41zp9qlIAVFi0IhCNsJcXolEqLWhbFbEeL0PvYm4pcQ==", + "dev": true, + "dependencies": { + "undici-types": "~5.26.4" + } + }, + "node_modules/@types/vscode": { + "version": "1.85.0", + "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.85.0.tgz", + "integrity": "sha512-CF/RBon/GXwdfmnjZj0WTUMZN5H6YITOfBCP4iEZlOtVQXuzw6t7Le7+cR+7JzdMrnlm7Mfp49Oj2TuSXIWo3g==" + }, + "node_modules/balanced-match": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/balanced-match/-/balanced-match-1.0.2.tgz", + "integrity": "sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==" + }, + "node_modules/brace-expansion": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.1.tgz", + "integrity": "sha512-XnAIvQ8eM+kC6aULx6wuQiwVsnzsi9d3WxzV3FpWTGA19F621kwdbsAcFKXgKUHZWsy+mY6iL1sHTxWEFCytDA==", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/lru-cache": { + "version": "6.0.0", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-6.0.0.tgz", + "integrity": "sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==", + "dependencies": { + "yallist": "^4.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/minimatch": { + "version": "5.1.6", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-5.1.6.tgz", + "integrity": "sha512-lKwV/1brpG6mBUFHtb7NUmtABCb2WZZmm2wNiOA5hAb8VdCS4B3dtMWyvcoViccwAW/COERjXLt0zP1zXUN26g==", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/semver": { + "version": "7.5.4", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.5.4.tgz", + "integrity": "sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==", + "dependencies": { + "lru-cache": "^6.0.0" + }, + "bin": { + "semver": "bin/semver.js" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/typescript": { + "version": "5.3.3", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.3.3.tgz", + "integrity": "sha512-pXWcraxM0uxAS+tN0AG/BF2TyqmHO014Z070UsJ+pFvYuRSq8KH8DmWpnbXe0pEPDHXZV3FcAbJkijJ5oNEnWw==", + "dev": true, + "bin": { + "tsc": "bin/tsc", + "tsserver": "bin/tsserver" + }, + "engines": { + "node": ">=14.17" + } + }, + "node_modules/undici-types": { + "version": "5.26.5", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", + "integrity": "sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==", + "dev": true + }, + "node_modules/vscode-jsonrpc": { + "version": "8.2.0", + "resolved": "https://registry.npmjs.org/vscode-jsonrpc/-/vscode-jsonrpc-8.2.0.tgz", + "integrity": "sha512-C+r0eKJUIfiDIfwJhria30+TYWPtuHJXHtI7J0YlOmKAo7ogxP20T0zxB7HZQIFhIyvoBPwWskjxrvAtfjyZfA==", + "engines": { + "node": ">=14.0.0" + } + }, + "node_modules/vscode-languageclient": { + "version": "9.0.1", + "resolved": "https://registry.npmjs.org/vscode-languageclient/-/vscode-languageclient-9.0.1.tgz", + "integrity": "sha512-JZiimVdvimEuHh5olxhxkht09m3JzUGwggb5eRUkzzJhZ2KjCN0nh55VfiED9oez9DyF8/fz1g1iBV3h+0Z2EA==", + "dependencies": { + "minimatch": "^5.1.0", + "semver": "^7.3.7", + "vscode-languageserver-protocol": "3.17.5" + }, + "engines": { + "vscode": "^1.82.0" + } + }, + "node_modules/vscode-languageserver-protocol": { + "version": "3.17.5", + "resolved": "https://registry.npmjs.org/vscode-languageserver-protocol/-/vscode-languageserver-protocol-3.17.5.tgz", + "integrity": "sha512-mb1bvRJN8SVznADSGWM9u/b07H7Ecg0I3OgXDuLdn307rl/J3A9YD6/eYOssqhecL27hK1IPZAsaqh00i/Jljg==", + "dependencies": { + "vscode-jsonrpc": "8.2.0", + "vscode-languageserver-types": "3.17.5" + } + }, + "node_modules/vscode-languageserver-types": { + "version": "3.17.5", + "resolved": "https://registry.npmjs.org/vscode-languageserver-types/-/vscode-languageserver-types-3.17.5.tgz", + "integrity": "sha512-Ld1VelNuX9pdF39h2Hgaeb5hEZM2Z3jUrrMgWQAu82jMtZp7p3vJT3BzToKtZI7NgQssZje5o0zryOrhQvzQAg==" + }, + "node_modules/yallist": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/yallist/-/yallist-4.0.0.tgz", + "integrity": "sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==" + } + } +} diff --git a/editors/vscode/package.json b/editors/vscode/package.json new file mode 100644 index 0000000..8c53371 --- /dev/null +++ b/editors/vscode/package.json @@ -0,0 +1,23 @@ +{ + "name": "lsp-ai", + "version": "0.1.0", + "description": "", + "main": "/out/index.js", + "scripts": { + "build": "npx tsc" + }, + "author": "", + "license": "MIT", + "activationEvents": ["onLanguage"], + "engines": { + "vscode": "^1.75.0" + }, + "devDependencies": { + "@types/node": "^20.11.0", + "typescript": "^5.3.3" + }, + "dependencies": { + "@types/vscode": "^1.85.0", + "vscode-languageclient": "^9.0.1" + } +} diff --git a/editors/vscode/src/index.ts b/editors/vscode/src/index.ts new file mode 100644 index 0000000..06b7a68 --- /dev/null +++ b/editors/vscode/src/index.ts @@ -0,0 +1,49 @@ +import { workspace, ExtensionContext } from 'vscode'; + +import { + LanguageClient, + LanguageClientOptions, + ServerOptions, + TransportKind +} from 'vscode-languageclient/node'; + +let client: LanguageClient; + +export function activate(_context: ExtensionContext) { + console.log("\n\nIN THE ACTIVATE FUNCTION\n\n"); + + // Configure the server options + let serverOptions: ServerOptions = { + command: "lsp-ai", + transport: TransportKind.stdio, + }; + + // Options to control the language client + let clientOptions: LanguageClientOptions = { + documentSelector: [{ scheme: 'file', language: 'python' }], + synchronize: { + // Notify the server about file changes to '.clientrc files contained in the workspace + fileEvents: workspace.createFileSystemWatcher('**/.clientrc') + } + }; + + // Create the language client and start the client + client = new LanguageClient( + 'lsp-ai', + 'lsp-ai', + serverOptions, + clientOptions + ); + + console.log("\n\nSTARTING THE CLIENT\n\n"); + + // Start the client. This will also launch the server + client.start(); +} + +export function deactivate(): Thenable | undefined { + if (!client) { + return undefined; + } + return client.stop(); +} diff --git a/editors/vscode/tsconfig.json b/editors/vscode/tsconfig.json new file mode 100644 index 0000000..926f96c --- /dev/null +++ b/editors/vscode/tsconfig.json @@ -0,0 +1,13 @@ +{ + "compilerOptions": { + "module": "Node16", + "moduleResolution": "Node16", + "target": "ES2021", + "outDir": "out", + "lib": ["ES2021"], + "sourceMap": true, + "rootDir": "src", + }, + "exclude": ["node_modules"], + "include": ["src", "tests"] +} diff --git a/run.sh b/run.sh deleted file mode 100755 index 82f9831..0000000 --- a/run.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash - -/home/silas/Projects/lsp-ai/target/release/lsp-ai diff --git a/src/main.rs b/src/main.rs index 9bb8002..41d699c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -14,8 +14,15 @@ use std::collections::HashMap; use std::sync::Arc; use std::thread; -mod models; -use models::{Model, ModelParams}; +use once_cell::sync::Lazy; +use pyo3::prelude::*; + +pub static PY_MODULE: Lazy>> = Lazy::new(|| { + pyo3::Python::with_gil(|py| -> Result> { + let src = include_str!("python/transformers.py"); + Ok(pyo3::types::PyModule::from_code(py, src, "transformers.py", "transformers")?.into()) + }) +}); // Taken directly from: https://github.com/rust-lang/rust-analyzer fn notification_is(notification: &Notification) -> bool { @@ -40,16 +47,25 @@ fn main() -> Result<()> { ..Default::default() })?; let initialization_params = connection.initialize(server_capabilities)?; + + // Activate the python venv + Python::with_gil(|py| -> Result<()> { + let activate: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "activate_venv")?; + + activate.call1(py, ("/Users/silas/Projects/lsp-ai/venv",))?; + Ok(()) + })?; + main_loop(connection, initialization_params)?; io_threads.join()?; Ok(()) } #[derive(Deserialize)] -struct Params { - // We may want to put other non-model related parameters here in the future - model_params: Option, -} +struct Params {} struct CompletionRequest { id: RequestId, @@ -69,6 +85,16 @@ impl CompletionRequest { fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { let params: Params = serde_json::from_value(params)?; + // Set the model + Python::with_gil(|py| -> Result<()> { + let activate: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "set_model")?; + activate.call1(py, ("",))?; + Ok(()) + })?; + // Prep variables let connection = Arc::new(connection); let mut file_map: HashMap = HashMap::new(); @@ -79,12 +105,7 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { // Thread local variables let thread_last_completion_request = last_completion_request.clone(); let thread_connection = connection.clone(); - // We need to allow unreachabel to be able to use the question mark operators here - // We could probably restructure this to not require it - #[allow(unreachable_code)] thread::spawn(move || { - // Build the model from the params - let mut model: Box = params.model_params.unwrap_or_default().try_into()?; loop { // I think we need this drop, not 100% sure though let mut completion_request = thread_last_completion_request.lock(); @@ -98,16 +119,39 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { { let filter_text = rope .get_line(params.text_document_position.position.line as usize) - .context("Error getting line with ropey")? + .expect("Error getting line with ropey") .to_string(); // Convert rope to correct prompt for llm - let start_index = rope + let cursor_index = rope .line_to_char(params.text_document_position.position.line as usize) + params.text_document_position.position.character as usize; - rope.insert(start_index, ""); - let prompt = format!("{}", rope); - let insert_text = model.run(&prompt)?; + + // We will want to have some kind of infill support we add + // rope.insert(cursor_index, "<|fim_hole|>"); + // rope.insert(0, "<|fim_start|>"); + // rope.insert(rope.len_chars(), "<|fim_end|>"); + // let prompt = rope.to_string(); + + let prompt = rope + .get_slice((0..cursor_index)) + .expect("Error getting rope slice") + .to_string(); + + eprintln!("\n\n****{prompt}****\n\n"); + + let insert_text = Python::with_gil(|py| -> Result { + let transform: Py = PY_MODULE + .as_ref() + .map_err(anyhow::Error::msg)? + .getattr(py, "transform")?; + + let out: String = transform.call1(py, (prompt,))?.extract(py)?; + Ok(out) + }) + .expect("Error during transform"); + + eprintln!("\n{insert_text}\n"); // Create and return the completion let completion_text_edit = TextEdit::new( @@ -141,11 +185,13 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { result: Some(result), error: None, }; - thread_connection.sender.send(Message::Response(resp))?; + thread_connection + .sender + .send(Message::Response(resp)) + .expect("Error sending response"); } thread::sleep(std::time::Duration::from_millis(5)); } - anyhow::Ok(()) }); for msg in &connection.receiver { @@ -171,7 +217,6 @@ fn main_loop(connection: Connection, params: serde_json::Value) -> Result<()> { }; } Message::Notification(not) => { - eprintln!("got notification: {not:?}"); if notification_is::(¬) { let params: DidOpenTextDocumentParams = serde_json::from_value(not.params)?; let rope = Rope::from_str(¶ms.text_document.text); diff --git a/src/models/mod.rs b/src/models/mod.rs deleted file mode 100644 index 5a691bf..0000000 --- a/src/models/mod.rs +++ /dev/null @@ -1,32 +0,0 @@ -use anyhow::Result; -use serde::Deserialize; - -mod starcoder; - -pub trait Model { - fn run(&mut self, prompt: &str) -> Result; -} - -#[derive(Deserialize, Default)] -pub struct ModelParams { - model: Option, - model_file: Option, - model_type: Option, - max_length: Option, -} - -impl TryFrom for Box { - type Error = anyhow::Error; - - fn try_from(value: ModelParams) -> Result { - let model_type = value.model_type.unwrap_or("starcoder".to_string()); - let max_length = value.max_length.unwrap_or(12); - Ok(Box::new(match model_type.as_str() { - "starcoder" => starcoder::build_model(value.model, value.model_file, max_length)?, - _ => anyhow::bail!( - "Model type: {} not supported. Feel free to make a pr or create a github issue.", - model_type - ), - })) - } -} diff --git a/src/models/starcoder.rs b/src/models/starcoder.rs deleted file mode 100644 index 2cc79c6..0000000 --- a/src/models/starcoder.rs +++ /dev/null @@ -1,119 +0,0 @@ -use anyhow::{Error as E, Result}; -use candle_core::{DType, Device, Tensor}; -use candle_nn::VarBuilder; -use candle_transformers::generation::LogitsProcessor; -use candle_transformers::models::bigcode::{Config, GPTBigCode}; -use hf_hub::{api::sync::ApiBuilder, Repo, RepoType}; -use tokenizers::Tokenizer; - -pub struct Model { - model: GPTBigCode, - device: Device, - tokenizer: Tokenizer, - logits_processor: LogitsProcessor, - max_length: usize, -} - -impl super::Model for Model { - fn run(&mut self, prompt: &str) -> Result { - eprintln!("Starting to generate tokens"); - let mut tokens = self - .tokenizer - .encode(prompt, true) - .map_err(E::msg)? - .get_ids() - .to_vec(); - let mut new_tokens = vec![]; - let mut outputs = vec![]; - let start_gen = std::time::Instant::now(); - for index in 0..self.max_length { - let (context_size, past_len) = if self.model.config().use_cache && index > 0 { - (1, tokens.len().saturating_sub(1)) - } else { - (tokens.len(), 0) - }; - let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; - let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; - let logits = self.model.forward(&input, past_len)?; - let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; - - let next_token = self.logits_processor.sample(&logits)?; - tokens.push(next_token); - new_tokens.push(next_token); - let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?; - outputs.push(token); - } - let dt = start_gen.elapsed(); - self.model.clear_cache(); - eprintln!( - "GENERATED {} tokens in {} milliseconds", - outputs.len(), - dt.as_millis() - ); - Ok(outputs.join("")) - } -} - -impl Model { - fn new( - model: GPTBigCode, - tokenizer: Tokenizer, - seed: u64, - temp: Option, - top_p: Option, - device: &Device, - max_length: usize, - ) -> Self { - let logits_processor = LogitsProcessor::new(seed, temp, top_p); - Self { - model, - tokenizer, - logits_processor, - device: device.clone(), - max_length, - } - } -} - -pub fn build_model( - model: Option, - model_file: Option, - max_length: usize, -) -> Result { - let start = std::time::Instant::now(); - eprintln!("Loading in model"); - let api = ApiBuilder::new() - .with_token(Some(std::env::var("HF_TOKEN")?.to_string())) - .build()?; - let repo = api.repo(Repo::with_revision( - "bigcode/starcoderbase-1b".to_string(), - RepoType::Model, - "main".to_string(), - )); - let tokenizer_filename = repo.get("tokenizer.json")?; - let filenames = ["model.safetensors"] - .iter() - .map(|f| repo.get(f)) - .collect::, _>>()?; - let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; - - // Set the device - #[cfg(feature = "cuda")] - let device = Device::new_cuda(0)?; - #[cfg(not(feature = "cuda"))] - let device = Device::Cpu; - - let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; - let config = Config::starcoder_1b(); - let model = GPTBigCode::load(vb, config)?; - eprintln!("loaded the model in {:?}", start.elapsed()); - Ok(Model::new( - model, - tokenizer, - 0, - Some(0.85), - None, - &device, - max_length, - )) -} diff --git a/src/python/transformers.py b/src/python/transformers.py new file mode 100644 index 0000000..de2daeb --- /dev/null +++ b/src/python/transformers.py @@ -0,0 +1,44 @@ +import sys +import os + +from llama_cpp import Llama + + +model = None + + +def activate_venv(venv): + if sys.platform in ('win32', 'win64', 'cygwin'): + activate_this = os.path.join(venv, 'Scripts', 'activate_this.py') + else: + activate_this = os.path.join(venv, 'bin', 'activate_this.py') + + if os.path.exists(activate_this): + exec(open(activate_this).read(), dict(__file__=activate_this)) + return True + else: + print(f"Virtualenv not found: {venv}", file=sys.stderr) + return False + + + +def set_model(filler): + global model + model = Llama( + # model_path="./tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf", # Download the model file first + model_path="/Users/silas/Projects/Tests/lsp-ai-tests/deepseek-coder-6.7b-base.Q4_K_M.gguf", # Download the model file first + n_ctx=2048, # The max sequence length to use - note that longer sequence lengths require much more resources + n_threads=8, # The number of CPU threads to use, tailor to your system and the resulting performance + n_gpu_layers=35 # The number of layers to offload to GPU, if you have GPU acceleration available + ) + + +def transform(input): + # Simple inference example + output = model( + input, # Prompt + max_tokens=32, # Generate up to 512 tokens + stop=["<|EOT|>"], # Example stop token - not necessarily correct for this specific model! Please check before using. + echo=False # Whether to echo the prompt + ) + return output["choices"][0]["text"]