LiteLLM configuration for caching and streaming #904

Open

Snikch63200 opened this issue Mar 4, 2025 · 7 comments

Labels: bug (Something isn't working), question (Further information is requested)

@Snikch63200

Hi,

I'm trying to configure LiteLLM to enable caching and streaming.

# -*- coding: utf-8 -*-

from paperqa.settings import Settings, AgentSettings, AnswerSettings
from paperqa.agents import ask, build_index
from paperqa import (
    Docs,
    HybridEmbeddingModel,
    SparseEmbeddingModel,
    LiteLLMEmbeddingModel,
    SentenceTransformerEmbeddingModel,
    NumpyVectorStore,
    embedding_model_factory,
)
from prompts_french import (
    summary_prompt,
    summary_json_prompt,
    qa_prompt,
    select_paper_prompt,
    citation_prompt,
    structured_citation_prompt,
    default_system_prompt,
    summary_json_system_prompt,
)
import os
import csv
import litellm
import glob
import compress_pickle

from paperqa.agents.main import agent_query
from paperqa.agents.search import get_directory_index

import asyncio
from datetime import datetime

from litellm.caching.caching import Cache

litellm.set_verbose = True
litellm.stream = True
litellm.cache = True
litellm.request_timeout = 600

os.environ["OPENAI_API_KEY"] = "sk-no-key-required"
os.environ["MISTRAL_API_KEY"] = "sk-no-key-required"

local_llm_config = dict(
    model_list=[
        dict(
            model_name="ollama/dolphin-mixtral:8x7b",
            litellm_params=dict(
                model="ollama/dolphin-mixtral:8x7b",
                api_type="ollama",
                api_base="http://localhost:11434",
                stream=True,
                caching=True,
                cache=True,
                # format='json',
                mirostat=0,
                mirostat_eta=0.1,
                mirostat_tau=5.0,
                num_ctx=8192,  # Default 2048
                repeat_last_n=64,  # How far back the model looks to prevent repetition (default 64, 0 = disabled, -1 = num_ctx)
                repeat_penalty=1.1,
                temperature=0.5,
                seed=0,
                tfs_z=1,
                num_predict=8192,  # Maximum number of tokens to predict (default 128, -1 = infinite, -2 = fill context)
                max_tokens=8192,
                # low_vram=True,
                top_k=20,
                top_p=0.9,
                min_p=0.0,
                timeout=600,
            ),
        ),
    ]
)

The 'stream' parameter seems to be unrecognized as a kwarg by litellm.

litellm.cache = True leads to an error indicating this argument cannot be a bool. I also tried passing a Cache() object, but that leads to an error as well.

What is the correct way to enable streaming and caching?


dosubot bot commented Mar 4, 2025

To correctly enable streaming and caching in the litellm library, you need to adjust the parameters as follows:

  1. Streaming: The stream parameter should be set within the litellm_params dictionary for the specific model configuration. Ensure that it is correctly recognized by the library. If it is not recognized, it might be due to a version mismatch or incorrect usage. Double-check the library's documentation or source code to confirm the correct parameter name and usage.

  2. Caching: Instead of setting litellm.cache=True, you should instantiate a Cache() object and assign it to litellm.cache. Here's an example of how to do this:

    import litellm
    from litellm.caching.caching import Cache

    # Create a Cache object and assign it at module level
    litellm.cache = Cache()

    Ensure that the Cache() object is correctly imported and used. This should resolve the error related to the cache parameter not accepting a boolean value.

By following these steps, you should be able to configure litellm with caching and streaming enabled correctly. If issues persist, consider checking for any updates or changes in the library's API that might affect these configurations.
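
For reference, here is a minimal sketch (not paper-qa specific; it assumes a locally running Ollama model at the default port and litellm's standard completion API) of how these two settings are typically combined when calling litellm directly:

import litellm
from litellm.caching.caching import Cache

# The module-level cache must be a Cache instance, not a bool
litellm.cache = Cache()  # defaults to an in-memory local cache

response = litellm.completion(
    model="ollama/dolphin-mixtral:8x7b",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="http://localhost:11434",
    stream=True,   # streaming is a per-request kwarg
    caching=True,  # opt this request into the configured cache
)

# With stream=True, completion() returns an iterator of chunks
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")

Note that stream and caching are passed per request here; only litellm.cache is set at module level.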


dosubot bot added labels: bug (Something isn't working), question (Further information is requested) — Mar 4, 2025
@Snikch63200 (Author)

  • The 'stream' parameter is always set to 'True' in litellm_params, but I can see it is still False in LiteLLM verbose mode.
  • I passed a Cache() object (litellm.cache = Cache()) and got this error:
[12:47:07] Trajectory failed.
           ╭──────────────────────────────────── Traceback (most recent call last) ────────────────────────────────────╮
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\paperqa\agents\main.py:158 in                 │
           │ _run_with_timeout_failure                                                                                 │
           │                                                                                                           │
           │   155 ) -> tuple[PQASession, AgentStatus]:                                                                │
           │   156 │   try:                                                                                            │
           │   157 │   │   async with asyncio.timeout(settings.agent.timeout):                                         │
           │ ❱ 158 │   │   │   status = await rollout()                                                                │
           │   159 │   except TimeoutError:                                                                            │
           │   160 │   │   logger.warning(                                                                             │
           │   161 │   │   │   f"Agent timeout after {settings.agent.timeout}-sec, just answering."                    │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\paperqa\agents\main.py:309 in rollout         │
           │                                                                                                           │
           │   306 │   │   │   │   )                                                                                   │
           │   307 │   │   │   │   return AgentStatus.TRUNCATED                                                        │
           │   308 │   │   │   agent_state.messages += obs                                                             │
           │ ❱ 309 │   │   │   for attempt in Retrying(                                                                │
           │   310 │   │   │   │   stop=stop_after_attempt(5),                                                         │
           │   311 │   │   │   │   retry=retry_if_exception_type(MalformedMessageError),                               │
           │   312 │   │   │   │   before_sleep=before_sleep_log(logger, logging.WARNING),                             │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\tenacity\__init__.py:443 in __iter__          │
           │                                                                                                           │
           │   440 │   │                                                                                               │
           │   441 │   │   retry_state = RetryCallState(self, fn=None, args=(), kwargs={})                             │
           │   442 │   │   while True:                                                                                 │
           │ ❱ 443 │   │   │   do = self.iter(retry_state=retry_state)                                                 │
           │   444 │   │   │   if isinstance(do, DoAttempt):                                                           │
           │   445 │   │   │   │   yield AttemptManager(retry_state=retry_state)                                       │
           │   446 │   │   │   elif isinstance(do, DoSleep):                                                           │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\tenacity\__init__.py:376 in iter              │
           │                                                                                                           │
           │   373 │   │   self._begin_iter(retry_state)                                                               │
           │   374 │   │   result = None                                                                               │
           │   375 │   │   for action in self.iter_state.actions:                                                      │
           │ ❱ 376 │   │   │   result = action(retry_state)                                                            │
           │   377 │   │   return result                                                                               │
           │   378 │                                                                                                   │
           │   379 │   def _begin_iter(self, retry_state: "RetryCallState") -> None:  # noqa                           │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\tenacity\__init__.py:398 in <lambda>          │
           │                                                                                                           │
           │   395 │                                                                                                   │
           │   396 │   def _post_retry_check_actions(self, retry_state: "RetryCallState") -> None:                     │
           │   397 │   │   if not (self.iter_state.is_explicit_retry or self.iter_state.retry_run_result):             │
           │ ❱ 398 │   │   │   self._add_action_func(lambda rs: rs.outcome.result())                                   │
           │   399 │   │   │   return                                                                                  │
           │   400 │   │                                                                                               │
           │   401 │   │   if self.after is not None:                                                                  │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\concurrent\futures\_base.py:449 in result                   │
           │                                                                                                           │
           │   446 │   │   │   │   if self._state in [CANCELLED, CANCELLED_AND_NOTIFIED]:                              │
           │   447 │   │   │   │   │   raise CancelledError()                                                          │
           │   448 │   │   │   │   elif self._state == FINISHED:                                                       │
           │ ❱ 449 │   │   │   │   │   return self.__get_result()                                                      │
           │   450 │   │   │   │                                                                                       │
           │   451 │   │   │   │   self._condition.wait(timeout)                                                       │
           │   452                                                                                                     │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\concurrent\futures\_base.py:401 in __get_result             │
           │                                                                                                           │
           │   398 │   def __get_result(self):                                                                         │
           │   399 │   │   if self._exception:                                                                         │
           │   400 │   │   │   try:                                                                                    │
           │ ❱ 401 │   │   │   │   raise self._exception                                                               │
           │   402 │   │   │   finally:                                                                                │
           │   403 │   │   │   │   # Break a reference cycle with the exception in self._exception                     │
           │   404 │   │   │   │   self = None                                                                         │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\paperqa\agents\main.py:316 in rollout         │
           │                                                                                                           │
           │   313 │   │   │   │   reraise=True,                                                                       │
           │   314 │   │   │   ):                                                                                      │
           │   315 │   │   │   │   with attempt:  # Retrying if ToolSelector fails to select a tool                    │
           │ ❱ 316 │   │   │   │   │   action = await agent(agent_state.messages, tools)                               │
           │   317 │   │   │   agent_state.messages = [*agent_state.messages, action]                                  │
           │   318 │   │   │   if on_agent_action_callback:                                                            │
           │   319 │   │   │   │   await on_agent_action_callback(action, agent_state)                                 │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\aviary\tools\utils.py:84 in __call__          │
           │                                                                                                           │
           │    81 │   │   │   self._ledger.messages.extend(messages)                                                  │
           │    82 │   │   │   messages = self._ledger.messages                                                        │
           │    83 │   │                                                                                               │
           │ ❱  84 │   │   model_response = await self._bound_acompletion(                                             │
           │    85 │   │   │   messages=MessagesAdapter.dump_python(                                                   │
           │    86 │   │   │   │   messages, exclude_none=True, by_alias=True                                          │
           │    87 │   │   │   ),                                                                                      │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:907 in acompletion          │
           │                                                                                                           │
           │    904 │   │   │   │   │   original_exception=e,                                                          │
           │    905 │   │   │   │   )                                                                                  │
           │    906 │   │   │   )                                                                                      │
           │ ❱  907 │   │   │   raise e                                                                                │
           │    908 │                                                                                                  │
           │    909 │   async def _acompletion(                                                                        │
           │    910 │   │   self, model: str, messages: List[Dict[str, str]], **kwargs                                 │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:883 in acompletion          │
           │                                                                                                           │
           │    880 │   │   │   if request_priority is not None and isinstance(request_priority, int):                 │
           │    881 │   │   │   │   response = await self.schedule_acompletion(**kwargs)                               │
           │    882 │   │   │   else:                                                                                  │
           │ ❱  883 │   │   │   │   response = await self.async_function_with_fallbacks(**kwargs)                      │
           │    884 │   │   │   end_time = time.time()                                                                 │
           │    885 │   │   │   _duration = end_time - start_time                                                      │
           │    886 │   │   │   asyncio.create_task(                                                                   │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3079 in                     │
           │ async_function_with_fallbacks                                                                             │
           │                                                                                                           │
           │   3076 │   │   │   │   │   │   )                                                                          │
           │   3077 │   │   │   │   │   )                                                                              │
           │   3078 │   │   │                                                                                          │
           │ ❱ 3079 │   │   │   raise original_exception                                                               │
           │   3080 │                                                                                                  │
           │   3081 │   def _handle_mock_testing_fallbacks(                                                            │
           │   3082 │   │   self,                                                                                      │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:2893 in                     │
           │ async_function_with_fallbacks                                                                             │
           │                                                                                                           │
           │   2890 │   │   │   │   │   *args, **kwargs, mock_timeout=mock_timeout                                     │
           │   2891 │   │   │   │   )                                                                                  │
           │   2892 │   │   │   else:                                                                                  │
           │ ❱ 2893 │   │   │   │   response = await self.async_function_with_retries(*args, **kwargs)                 │
           │   2894 │   │   │   verbose_router_logger.debug(f"Async Response: {response}")                             │
           │   2895 │   │   │   response = add_fallback_headers_to_response(                                           │
           │   2896 │   │   │   │   response=response,                                                                 │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3269 in                     │
           │ async_function_with_retries                                                                               │
           │                                                                                                           │
           │   3266 │   │   │   │   setattr(original_exception, "max_retries", num_retries)                            │
           │   3267 │   │   │   │   setattr(original_exception, "num_retries", current_attempt)                        │
           │   3268 │   │   │                                                                                          │
           │ ❱ 3269 │   │   │   raise original_exception                                                               │
           │   3270 │                                                                                                  │
           │   3271 │   async def make_call(self, original_function: Any, *args, **kwargs):                            │
           │   3272 │   │   """                                                                                        │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3162 in                     │
           │ async_function_with_retries                                                                               │
           │                                                                                                           │
           │   3159 │   │   │   │   model_group=model_group, kwargs=kwargs                                             │
           │   3160 │   │   │   )                                                                                      │
           │   3161 │   │   │   # if the function call is successful, no exception will be raised and we'll            │
           │        break out of the loop                                                                              │
           │ ❱ 3162 │   │   │   response = await self.make_call(original_function, *args, **kwargs)                    │
           │   3163 │   │   │   response = add_retry_headers_to_response(                                              │
           │   3164 │   │   │   │   response=response, attempted_retries=0, max_retries=None                           │
           │   3165 │   │   │   )                                                                                      │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:3278 in make_call           │
           │                                                                                                           │
           │   3275 │   │   model_group = kwargs.get("model")                                                          │
           │   3276 │   │   response = original_function(*args, **kwargs)                                              │
           │   3277 │   │   if inspect.iscoroutinefunction(response) or inspect.isawaitable(response):                 │
           │ ❱ 3278 │   │   │   response = await response                                                              │
           │   3279 │   │   ## PROCESS RESPONSE HEADERS                                                                │
           │   3280 │   │   response = await self.set_response_headers(                                                │
           │   3281 │   │   │   response=response, model_group=model_group                                             │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:1045 in _acompletion        │
           │                                                                                                           │
           │   1042 │   │   │   )                                                                                      │
           │   1043 │   │   │   if model_name is not None:                                                             │
           │   1044 │   │   │   │   self.fail_calls[model_name] += 1                                                   │
           │ ❱ 1045 │   │   │   raise e                                                                                │
           │   1046 │                                                                                                  │
           │   1047 │   def _update_kwargs_before_fallbacks(self, model: str, kwargs: dict) -> None:                   │
           │   1048 │   │   """                                                                                        │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\router.py:1004 in _acompletion        │
           │                                                                                                           │
           │   1001 │   │   │   │   │   parent_otel_span=parent_otel_span,                                             │
           │   1002 │   │   │   │   )                                                                                  │
           │   1003 │   │   │   │                                                                                      │
           │ ❱ 1004 │   │   │   │   response = await _response                                                         │
           │   1005 │   │   │                                                                                          │
           │   1006 │   │   │   ## CHECK CONTENT FILTER ERROR ##                                                       │
           │   1007 │   │   │   if isinstance(response, ModelResponse):                                                │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\utils.py:1397 in wrapper_async        │
           │                                                                                                           │
           │   1394 │   │   │                                                                                          │
           │   1395 │   │   │   timeout = _get_wrapper_timeout(kwargs=kwargs, exception=e)                             │
           │   1396 │   │   │   setattr(e, "timeout", timeout)                                                         │
           │ ❱ 1397 │   │   │   raise e                                                                                │
           │   1398 │                                                                                                  │
           │   1399 │   is_coroutine = inspect.iscoroutinefunction(original_function)                                  │
           │   1400                                                                                                    │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\utils.py:1287 in wrapper_async        │
           │                                                                                                           │
           │   1284 │   │   │   )                                                                                      │
           │   1285 │   │   │                                                                                          │
           │   1286 │   │   │   ## Add response to cache                                                               │
           │ ❱ 1287 │   │   │   await _llm_caching_handler.async_set_cache(                                            │
           │   1288 │   │   │   │   result=result,                                                                     │
           │   1289 │   │   │   │   original_function=original_function,                                               │
           │   1290 │   │   │   │   kwargs=kwargs,                                                                     │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\caching\caching_handler.py:684 in     │
           │ async_set_cache                                                                                           │
           │                                                                                                           │
           │   681 │   │   │   )                                                                                       │
           │   682 │   │   )                                                                                           │
           │   683 │   │   # [OPTIONAL] ADD TO CACHE                                                                   │
           │ ❱ 684 │   │   if self._should_store_result_in_cache(                                                      │
           │   685 │   │   │   original_function=original_function, kwargs=new_kwargs                                  │
           │   686 │   │   ):                                                                                          │
           │   687 │   │   │   if (                                                                                    │
           │                                                                                                           │
           │ C:\ProgramData\anaconda3\envs\paperQA_env\Lib\site-packages\litellm\caching\caching_handler.py:758 in     │
           │ _should_store_result_in_cache                                                                             │
           │                                                                                                           │
           │   755 │   │   │   (litellm.cache is not None)                                                             │
           │   756 │   │   │   and litellm.cache.supported_call_types is not None                                      │
           │   757 │   │   │   and (str(original_function.__name__) in litellm.cache.supported_call_types)             │
           │ ❱ 758 │   │   │   and (kwargs.get("cache", {}).get("no-store", False) is not True)                        │
           │   759 │   │   )                                                                                           │
           │   760 │                                                                                                   │
           │   761 │   def _is_call_type_supported_by_cache(                                                           │
           ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────╯
           AttributeError: 'bool' object has no attribute 'get'

Any solution for these problems?


dosubot bot commented Mar 4, 2025

To address the issues you're encountering with the LiteLLM library:

  1. Streaming Parameter: Ensure that the stream parameter is correctly set within the litellm_params dictionary. If it appears as False in verbose mode despite being set to True, it might be due to an incorrect configuration or a bug in the library. Double-check that the parameter is correctly spelled and placed within the dictionary. If the issue persists, consider checking the library's documentation or source code for any updates or known issues related to this parameter.

  2. Caching Configuration: The error you're encountering (AttributeError: 'bool' object has no attribute 'get') suggests that the cache parameter is being treated as a boolean instead of a dictionary. Ensure that you are correctly instantiating the Cache object and assigning it to litellm.cache. Here's how you should do it:

    import litellm
    from litellm.caching.caching import Cache

    # Create and assign a Cache object (not a bool)
    litellm.cache = Cache()

    Make sure that the Cache object is properly imported and used. If the error persists, it might be due to an incorrect implementation or a bug in the library's caching mechanism. You may need to review the library's documentation or source code for further guidance on the correct usage of the caching feature.

By ensuring these configurations are correctly set, you should be able to resolve the issues with streaming and caching in the LiteLLM library. If problems continue, consider checking for any updates or changes in the library's API that might affect these configurations.
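
As a side note on the AttributeError in the traceback above: `_should_store_result_in_cache` calls `kwargs.get("cache", {}).get(...)`, so the per-request `cache` kwarg is expected to be a dict of cache-control options, while the module-level `litellm.cache` must be a Cache instance. A hedged sketch of that distinction follows (parameter names follow litellm's documented cache controls; adjust to your installed version):

import litellm
from litellm.caching.caching import Cache

# Module level: a Cache instance enables caching globally
litellm.cache = Cache()

# Per request: `caching` is a boolean opt-in, while `cache` (if used at all)
# is a dict of cache controls. Passing cache=True is what later fails on
# .get("no-store", ...) in the traceback above.
response = litellm.completion(
    model="ollama/dolphin-mixtral:8x7b",
    messages=[{"role": "user", "content": "Hello"}],
    api_base="http://localhost:11434",
    caching=True,
    cache={"no-store": False},  # dict, not bool
)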


@Snikch63200 (Author)

I guess the Cache() object is not correctly instantiated. Is there any documentation for that?

@maykcaldas (Collaborator)

Hello @Snikch63200

Please note that the LLMs we use in paperqa come from lmi, which is a wrapper over litellm. The usage is very similar, but there are a few differences.

That said, I am assuming you want to cache locally, is that right? This page shows how to do it with litellm.

Please check this code snippet to get it done in lmi:

from lmi import LiteLLMModel
from aviary.core import Message
import litellm
from litellm.caching.caching import Cache

# Set up caching -- I did not need this setup. But since it's in the litellm page, I left it here
# litellm.cache = Cache()

llm = LiteLLMModel(name="gpt-4o-mini")  # Change it to your llm

r = await llm.call_single(
    messages=[Message(role="user", content="Tell me a very very very random joke")],
    caching=True
)

print(r.text)

I could rerun it multiple times and get the same output. Let me know if it works for you.
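
One practical note on the snippet above: it awaits at the top level, so it assumes an async context (for example a Jupyter notebook). In a plain script it can be wrapped with asyncio.run, roughly like this (same calls as above, just wrapped):

import asyncio

from lmi import LiteLLMModel
from aviary.core import Message


async def main():
    llm = LiteLLMModel(name="gpt-4o-mini")  # change to your model
    result = await llm.call_single(
        messages=[Message(role="user", content="Tell me a very very very random joke")],
        caching=True,
    )
    print(result.text)


asyncio.run(main())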

@Snikch63200 (Author)

Hello @maykcaldas

Thanks for your answer. Your snippet works perfectly with llama3.3.

So, I think there could be two explanations:

  • PQA overrides these parameters in code
  • the parameters are not passed through the LiteLLM API ('stream' appears as 'False' in LiteLLM verbose mode, but 'caching' does not)

Best regards.

@maykcaldas (Collaborator)

Hey @Snikch63200

I don't think PQA uses caching anywhere. But you're right about streaming.
PQA uses LLMs from lmi. In lmi, you can see an example here of how we override the stream option depending on the use case.

Could you tell me why you need the answer to be streamed back to you? In our ecosystem, we use lmi to abstract all the LLM operations, and PQA only needs to deal with the curated LLMResults object.

Best,
