Commit Graph

266 Commits

Author SHA1 Message Date
guillaume-be
1f4d344668
Change generate return type to Result (#437)
* - Changed the return type of generate method to be `Result`, removed fallible unwraps

* Fix doctests
2023-12-04 17:58:21 +00:00
guillaume-be
9f2cd17e91
tch 0.14.0 update (#435)
* updated tch version

* Addition of casting operation for cpu compat

* Fix ONNX resource path

* Fix GPT-J bias bool loading

* Updated changelog

* Fix Clippy warnings

* Updated readme
2023-11-26 09:02:00 +00:00
guillaume-be
fd1e66b1c7
Support for HF Tokenizers (#408)
* tokenizers output type conversion

* WIP hf Tokenizers support (2)

* finalize interface methods for hf tokenizers

* Addition of GPT2 example with hf tokenizers

* Made hf-tokenizers optional, added doc for HFTokenizer

* Addition of tests for hf tokenizers, addition to CI

* Updated changelog, extended documentation

* Fix Clippy warnings
2023-08-13 11:09:02 +01:00
guillaume-be
540c9268e7
ONNX Support (#346)
* Fixed Clippy warnings

* Revert "Shallow clone optimization (#243)"

This reverts commit ba584653bc.

* updated dependencies

* tryouts

* GPT2 tryouts

* WIP GPT2

* input mapping

* Cache storage

* Initial GPT2 prototype

* Initial ONNX Config and decoder implementation

* ONNXDecoder first draft

* Use Decoders in example

* Automated tch-ort conversion, decoder implementation

* ONNXCausalDecoder implementation

* Refactored _get_var_store to be optional, added get_device to gen trait

* updated example

* Added decoder_start_token_id to ConfigOption

* Addition of ONNXModelConfig, make max_position_embeddigs optional

* Addition of forward pass function for ONNXModel

* Working ONNX causal decoder

* Simplify tensor conversion

* refactor translation to facilitate ONNX integration

* Implementation of ONNXEncoder

* Implementation of ONNXConditionalGenerator

* working ONNXCausalGenerator

* - Reworked model resources type for pipelines and generators

* Aligned ONNXConditionalGenerator with other generators to use GenerateConfig for creation

* Moved force_token_id_generation to common utils function, fixed tests, Translation implementation

* generalized forced_bos and forced_eos tokens generation

* Aligned the `encode_prompt_text` method across language models

* Fix prompt encoding for causal generation

* Fix prompt encoding for causal generation

* Support for ONNX models for SequenceClassification

* Support for ONNX models for TokenClassification

* Support for ONNX models for POS and NER pipelines

* Support for ONNX models for ZeroShotClassification pipeline

* Support for ONNX models for QuestionAnswering pipeline

* Support for ONNX models for MaskedLM pipeline

* Added token_type_ids , updated layer cache i/o parsing for ONNX pipelines

* Support for ONNX models for TextGenerationPipeline, updated examples for remote resources

* Remove ONNX zero-shot classification example (lack of correct pretrained model)

* Addition of tests for ONNX pipelines support

* Made onnx feature optional

* Fix device lookup with onnx feature enabled

* Updates from main branch

* Flexible tokenizer creation for M2M100 (NLLB support), make NLLB test optional du to their size

* Fixed Clippy warnings

* Addition of documentation for ONNX

* Added documentation for ONNX support

* upcoming tch 1.12 fixes

* Fix merge conflicts

* Fix merge conflicts (2)

* Add download libtorch feature to ONNX tests

* Add download-onnx feature

* attempt to enable onnx download

* add remote resources feature

* onnx download

* pin ort version

* Update ort version
2023-05-30 07:20:25 +01:00
Matt Weber
ba57704c6f
Introduce in-memory resource abstraction (#375)
* Introduce in-memory resource abstraction

This follows from discussion in #366.

The goal of this change is to allow for weights to be loaded from a copy
of `rust_model.ot` that is already present in memory. There are two ways
in which that data might be present:

1. As a `HashMap<String, Tensor>` from previous interaction with `tch`
2. As a contiguous buffer of the file data

One or the other mechanism might be preferable depending on how user
code is using the model data. In some sense, implementing a provider
based on the second option is more of a convenience method for the user
to avoid the `tch::nn::VarStore::load_from_stream` interaction.

I've changed the definition of the `ResourceProvider` trait to require
that it be both `Send` and `Sync`. There are currently certain contexts
where `dyn ResourceProvider + Send` is required, but in theory before
this change an implementation might not be `Send` (or `Sync`). The
existing providers are both `Send` and `Sync`, and it seems reasonable
(if technically incorrect) for user code to assume this to be true. I
don't see a downside to making this explicit, but that part of this
change might be better suited for separate discussion. I am not trying
to sneak it in.

The `enum Resource` data type is used here as a means to abstract over
the possible ways a `ResourceProvider` might represent an underlying
resource. Without this, it would be necessary to either call different
trait methods until one succeeded or implement `as_any` and downcast in
order to implement `load_weights` similarly to how it is now. Those
options seemed less preferable to creating a wrapper.

While it would be possible to replace all calls to `get_local_path` with
the `get_resource` API, removal of the existing function would be a very
big breaking change. As such, this change also introduces
`RustBertError::UnsupportedError` to allow for the different methods to
coexist. An alternative would be for the new `ResourceProvider`s to
write their resources to a temporary disk location and return an
appropriate path, but that is counter to the purpose of the new
`ResourceProvider`s and so I chose not to implement that.

* - Add `impl<T: ResourceProvider + ?Sized> ResourceProvider for Box<T>`
- Remove `Resource::NamedTensors`
- Change `BufferResource` to contain a `&[u8]` rather than `Vec<u8>`

* Further rework proposal for resources

* Use mutable references and locks

* Make model resources mutable in tests/examples

* Remove unnecessary mutability and TensorResource references

* Add `BufferResource` example

---------

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2023-05-26 18:23:28 +01:00
Joseph Hajduk
2bff63b2ee
updated tch-rs to 0.13.0 (#380)
* updated tch-rs to 0.13.0
find replaced of_slice to from_slice as per
008fff6cc0/CHANGELOG.md

* fixed formatting

* Add download feature and update CI

* add build script, update CI

* updated chanelog, readme, convert script

* fixed wrong position for build script

* added libtorch download to dependencies download test script

* args reordering

---------

Co-authored-by: josephhajduk <joseph@solidys.dev>
Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2023-05-21 18:41:18 +01:00
guillaume-be
b05ec7b24f
Generation traits simplification (#339)
* - Remove LMHeadModel trait (integrate with PrivateLanguageGenerator)
- Simplify PrivateLanguageGenerator trait definition (no longer requires defined by objects implementing `LMHeadModel`, `Vocab` and `Tokenizer` traits)

* - Removed BART duplicated code, updated docs

* - Fixed BART-based model incorrect order of generation arguments

* - Updated changelog

* Fixed Clippy warning
2023-03-17 16:21:37 +00:00
Romain Leroux
c448862185
Add GPT-J support (#285) (#288)
* Add GPT-J support (#285)

* Improve GPT-J implementation

* Improve GPT-J tests

* Adapt GPT-J to latest master branch

* Specify how to convert GPT-J weights instead of providing them
2023-02-15 19:10:47 +00:00
guillaume-be
84561ec82b
Tokenizer special token map update (#330)
* Updates for compatibility with tokenizers special token rework

* Updated mask pipline methods

* Bumped version

* Fix clippy warnings
2023-01-30 17:53:18 +00:00
guillaume-be
0fc5ce6ad4
CodeBERT Pretrained models and examples (#322)
* Addition of Codebert examples

* Addition of CodeBERT pretrained models, CodeBERT example
2023-01-20 19:02:33 +00:00
guillaume-be
f12e8ef475
Aligned ModelForTokenClassification and ModelForSequenceClassification APIs (#323) 2023-01-15 11:10:38 +00:00
guillaume-be
fdf5503163
Fixed Clippy warnings (#309)
* - Fixed Clippy warnings
- Updated `tch` dependency
- Updated README to avoid confusion with respect to the required `LIBTORCH` version for the repository and published package versions

* Fixed Clippy warnings (2)

* Fixed Clippy warnings (3)
2022-12-21 17:52:26 +00:00
Vincent Xiao
dae899fea6
Add pipelines::masked_language and codebert support (#282)
* ad support for loading local moddel in SequenceClassificationConfig

* adjust config to match the SequenceClassificationConfig

* add piplines::masked_language

* add support and example for codebert

* provide an optional mask_token String field for asked_language pipline

* update example for masked_language pipeline

* codebert support revocation

* revoke support for loading local moddel

* solve conflicts

* update MaskedLanguageConfig

* fix doctest error in zero_shot_classification.rs

* MaskedLM pipeline updates

* fix multiple masked token, added test

* Updated changelog and docs

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2022-12-21 16:58:02 +00:00
Anna Melnikov
a34cf9f8e4
Make predict methods in ZeroShot pipeline return Result instead of panicking on unwrap (#301)
* Add checked prediction methods

- Add checked prediction methods to ZeroShotClassificationModel.
These methods return Option and convert any underlying errors into None,
to allow callers to implement appropriate error handling logic.

* Update ZeroShot example to use checked method.

* Add tests for ZeroShot checked methods

* Change checked prediction methods to return Result

* refactor: rename *_checked into try_*

Rename *_checked methods into try_* methods.
This is more idiomatic vis-a-vis the Rust standard library.

* refactor: remove try_ prefix from predict methods

* refactor: change return from Option to Result

Change return type of ZeroShotClassificationModel.prepare_for_model
from option into Result. This simplifies the code, and returns
the error closer to its origin.

This addresses comments from @guillaume-be.

* refactor: address clippy lints in tests

Co-authored-by: guillaume-be <guillaume.becquin@gmail.com>
2022-12-04 09:10:01 +00:00
guillaume-be
05367b4df2
Make max_length optional (#296)
* Made `max_length` an optional argument for generation methods and pipelines

* Updated changelog
2022-11-15 19:20:51 +00:00
guillaume-be
5d2b107e99
Keyword/Keyphrase extraction (#295)
* stop word tokenizer implementation

* - Addition of all-mini-lm-l6-v2

* initial implementation of keyword scorer

* Cosine Similarity keyword extraction

* Added lower case parsing from tokenizer config for sentence embeddings

* Initial draft of pipeline complete

* Addition of Maximal Marginal relevance scorer

* Addition of Max Sum scorer

* Lowercase and ngrams handling

* Improved n-gram handling

* Skip n-grams containing stopwords

* Fixed short sentence input and added documentation

* Updated documentation and defaults, added example

* Addition of tests for keywords extractions

* Updated changelog

* Fixed Clippy warnings
2022-11-13 08:51:10 +00:00
guillaume-be
c6771d3992
Update to tch=0.9.0 (#293)
* Fixed short sentence input and added documentation

* Fixed Clippy warnings

* Updated CI Python version

* cleaner dim specification
2022-11-07 17:45:52 +00:00
guillaume-be
340be36ed9
Mixed resources (#291)
* - made `merges` resource optional for all pipelines
- allow mixing local and remote resources for pipelines

* Updated changelog

* Fixed Clippy warnings
2022-10-30 07:39:52 +00:00
guillaume-be
cce1e2707d
Prepare for 0.19 release (#272) 2022-07-25 06:36:02 +01:00
guillaume-be
a1595e6dfd
Updated sentence embeddings example (#263)
* Added conversion information for Distil-based sentence embedding models

* Fix Clippy warnings
2022-07-03 08:48:31 +01:00
Romain Leroux
4d8a298586
Add sbert implementation for inference (#250)
* Add sbert implementation for inference

* Fix clippy warnings

* Refactor sentence embeddings into a dedicated pipeline

* Add output_attentions and output_hidden_states to T5Config

* Add sbert implementation for inference

* Fix clippy warnings

* Refactor sentence embeddings into a dedicated pipeline

* Add output_attentions and output_hidden_states to T5Config

* Improve sentence_embeddings implementation

* Dedicated tokenizer config for strip_accents and add_prefix_space

* Rename forward to encode_as_tensor

* Remove _conf from Dense layer

* Add sentence embeddings docs

* Addition of remote resources and tests update

* Merge feature branch and fix doctests

* Add SentenceEmbeddingsBuilder<Remote> and improve remote resources

* Use tch::no_grad in sentence embeddings

* Updated changelog, registration of sentence embeddings integration tests

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2022-06-21 20:24:09 +01:00
Jonas Hedman Engström
9b22c2482a
Refactor: Feature gate remote resource (#223)
* get_local_path as trait LocalPathProvider

* Remove config default impls

* Feature gate RemoteResource

* translation_builder refactoring to have remote fetching grouped

* Include dirs crate in remote feature gate

* Examples fixes

* Benches fixes

* Tests fix

* Remove Box from constructor parameters

* Fix examples no-Box

* Fix benches no-Box

* Fix tests no-Box

* Fix doc comment code

* Fix documentation `Resource` -> `ResourceProvider`

* moved remote local at same level

* moved ResourceProvider to resources mod

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2022-02-25 21:24:03 +00:00
Flix
23c5d9112a
add async example, documentation and fix clippy (#217) 2022-01-30 11:51:58 +00:00
guillaume-be
e71712816e Addition of DeBERTa MNLI example 2021-12-12 20:14:36 +01:00
guillaume-be
4175942cc4
Fixed Clippy warnings (#204) 2021-12-09 09:33:27 +01:00
Guillaume Becquin
d84b2819d9 Merge remote-tracking branch 'origin/master' into entity_consolidation
# Conflicts:
#	src/pipelines/ner.rs
2021-11-20 11:03:05 +01:00
Guillaume Becquin
61e5d2d563 Addition of FNet model resource for sentiment analysis and registration in pipelines 2021-11-13 09:39:57 +01:00
Guillaume Becquin
73f017d0f7 Merge remote-tracking branch 'origin/master' into kind_reword
# Conflicts:
#	Cargo.toml
#	src/t5/layer_norm.rs
2021-11-09 16:00:21 +01:00
sftse
e297f395af
Make generics less generic. (#189)
* Make generics less generic.

Fix examples, tests and docs.

* Address outstanding issues

* Take less ownership where possible

* Fixup some clippy warnings

* Updated tokenizer crate version

Co-authored-by: Guillaume Becquin <guillaume.becquin@gmail.com>
2021-11-07 09:42:56 +01:00
Guillaume Becquin
de89e2d165 Updated XLNet for FP16 compatibility 2021-10-06 17:52:25 +02:00
Guillaume Becquin
889f509e6c Updated T5 for FP16 compatibility 2021-10-05 18:23:34 +02:00
Guillaume Becquin
fc2b2972f9 Updated Albert for Half precision support 2021-09-30 16:04:42 +02:00
Guillaume Becquin
72fabcdbd1 Updated GPT-Neo, working half precision greedy generation 2021-09-26 11:20:05 +02:00
guillaume-be
cb6bc34eb4 Updated borrowing for XLNet, integration tests 2021-08-20 11:08:37 +02:00
guillaume-be
3ff5199376 Updated offsets fixing overlapping spans 2021-08-18 09:07:23 +02:00
guillaume-be
9cadc5d15f Tested and fixed POS tagging on long inputs (requiring breaking input in multiple features) 2021-08-17 09:53:26 +02:00
Guillaume B
466c6b6922 Updated doctests 2021-07-28 18:10:20 +02:00
Guillaume B
ce90d8901d Updated examples and integration tests 2021-07-11 11:13:00 +02:00
Guillaume B
89b3a327fa Moved builder to own module, simplified Marian resource retrieval 2021-07-11 10:08:27 +02:00
Guillaume B
5dc7f33c39 Added MBart50 and M2M100 to supported translation models 2021-07-10 11:34:51 +02:00
Guillaume B
3b72b7cc9b Added example for translation builder, language checks for MBart and M2M100 2021-07-10 10:32:31 +02:00
Guillaume B
450fe0d533 Merge branch 'm2m100_implementation' into translation_rework 2021-07-09 15:41:28 +02:00
Guillaume B
85c05cbe13 Updated Marian translation example 2021-07-07 15:54:28 +02:00
Guillaume B
1c375a817e Use of new language enum in TranslationModel 2021-07-04 12:56:34 +02:00
Guillaume B
58eef0785f Merged master changes 2021-06-28 18:57:27 +02:00
Guillaume B
0b2e339e87 Merge remote-tracking branch 'origin/master' into m2m100_implementation
# Conflicts:
#	Cargo.toml
2021-06-28 18:53:46 +02:00
Guillaume B
2f6b26bb88 Addition of tests for M2M100 2021-06-27 18:18:41 +02:00
Guillaume B
f024350dee Fixed various documentation typos 2021-06-26 11:07:17 +02:00
Guillaume B
9a04d1527a Working example for M2M100 Translation 2021-06-26 10:49:42 +02:00
Guillaume B
f29e02ecbc Addition of TextOutput and IndicesOutput, updated pipelines and tests 2021-06-16 18:15:22 +02:00