The prefix function has been removed

This commit is contained in:
kqlio67 2024-10-30 16:25:55 +02:00
parent 6e72483617
commit e6627d8d30
2 changed files with 8 additions and 51 deletions

View File

@ -184,12 +184,8 @@ class Completions:
ignore_stream: bool = False, ignore_stream: bool = False,
**kwargs **kwargs
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]: ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
# We use ModelUtils to obtain the model object. model, provider = get_model_and_provider(
model_instance = ModelUtils.get_model(model) model,
# We receive the model and the provider.
model_name, provider = get_model_and_provider(
model_instance.name, # We use the model name from the object.
self.provider if provider is None else provider, self.provider if provider is None else provider,
stream, stream,
ignored, ignored,
@ -200,8 +196,9 @@ class Completions:
stop = [stop] if isinstance(stop, str) else stop stop = [stop] if isinstance(stop, str) else stop
if asyncio.iscoroutinefunction(provider.create_completion): if asyncio.iscoroutinefunction(provider.create_completion):
# Run the asynchronous function in an event loop
response = asyncio.run(provider.create_completion( response = asyncio.run(provider.create_completion(
model_name, # We use a model based on the object. model,
messages, messages,
stream=stream, stream=stream,
**filter_none( **filter_none(
@ -214,7 +211,7 @@ class Completions:
)) ))
else: else:
response = provider.create_completion( response = provider.create_completion(
model_name, # We use a model from the object. model,
messages, messages,
stream=stream, stream=stream,
**filter_none( **filter_none(
@ -228,19 +225,21 @@ class Completions:
if stream: if stream:
if hasattr(response, '__aiter__'): if hasattr(response, '__aiter__'):
# It's an async generator, wrap it into a sync iterator
response = to_sync_iter(response) response = to_sync_iter(response)
# Now 'response' is an iterator
response = iter_response(response, stream, response_format, max_tokens, stop) response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response) response = iter_append_model_and_provider(response)
return response return response
else: else:
if hasattr(response, '__aiter__'): if hasattr(response, '__aiter__'):
# If response is an async generator, collect it into a list
response = list(to_sync_iter(response)) response = list(to_sync_iter(response))
response = iter_response(response, stream, response_format, max_tokens, stop) response = iter_response(response, stream, response_format, max_tokens, stop)
response = iter_append_model_and_provider(response) response = iter_append_model_and_provider(response)
return next(response) return next(response)
async def async_create( async def async_create(
self, self,
messages: Messages, messages: Messages,

View File

@ -891,17 +891,6 @@ any_dark = Model(
) )
class ModelVersions:
# Global Prefixes for All Models
GLOBAL_PREFIXES = [":latest"]
# Specific Prefixes for Particular Models
MODEL_SPECIFIC_PREFIXES = {
#frozenset(["gpt-3.5-turbo", "gpt-4"]): [":custom1", ":custom2"]
#frozenset(["gpt-3.5-turbo"]): [":custom"],
}
class ModelUtils: class ModelUtils:
""" """
Utility class for mapping string identifiers to Model instances. Utility class for mapping string identifiers to Model instances.
@ -1174,35 +1163,4 @@ class ModelUtils:
'any-dark': any_dark, 'any-dark': any_dark,
} }
@classmethod
def get_model(cls, model_name: str) -> Model:
# Checking for specific prefixes
for model_set, specific_prefixes in ModelVersions.MODEL_SPECIFIC_PREFIXES.items():
for prefix in specific_prefixes:
if model_name.endswith(prefix):
base_name = model_name[:-len(prefix)]
if base_name in model_set:
return cls.convert.get(base_name, None)
# Check for global prefixes
for prefix in ModelVersions.GLOBAL_PREFIXES:
if model_name.endswith(prefix):
base_name = model_name[:-len(prefix)]
return cls.convert.get(base_name, None)
# Check without prefix
if model_name in cls.convert:
return cls.convert[model_name]
raise KeyError(f"Model {model_name} not found")
@classmethod
def get_available_versions(cls, model_name: str) -> list[str]:
# Obtaining prefixes for a specific model
prefixes = ModelVersions.GLOBAL_PREFIXES.copy()
for model_set, specific_prefixes in ModelVersions.MODEL_SPECIFIC_PREFIXES.items():
if model_name in model_set:
prefixes.extend(specific_prefixes)
return prefixes
_all_models = list(ModelUtils.convert.keys()) _all_models = list(ModelUtils.convert.keys())