Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Commit 
							
							·
						
						046eafc
	
1
								Parent(s):
							
							56980e9
								
fix secrets handling
Browse files
    	
        app.py
    CHANGED
    
    | @@ -50,6 +50,8 @@ if "verbose" not in st.session_state: | |
| 50 | 
             
                st.session_state.verbose = verbose
         | 
| 51 | 
             
            if "max_tokens" not in st.session_state:
         | 
| 52 | 
             
                st.session_state.max_tokens = max_tokens
         | 
|  | |
|  | |
| 53 | 
             
            if "temperature" not in st.session_state:
         | 
| 54 | 
             
                st.session_state.temperature = temperature
         | 
| 55 | 
             
            if "next_prompts" not in st.session_state:
         | 
| @@ -272,6 +274,8 @@ try: | |
| 272 | 
             
                            num_turns=st.session_state.num_turns,
         | 
| 273 | 
             
                            temperature=st.session_state.temperature,
         | 
| 274 | 
             
                            max_tokens=st.session_state.max_tokens,
         | 
|  | |
|  | |
| 275 | 
             
                            verbose=st.session_state.verbose,
         | 
| 276 | 
             
                        )
         | 
| 277 | 
             
                    chunk = next(st.session_state.generator)
         | 
|  | |
| 50 | 
             
                st.session_state.verbose = verbose
         | 
| 51 | 
             
            if "max_tokens" not in st.session_state:
         | 
| 52 | 
             
                st.session_state.max_tokens = max_tokens
         | 
| 53 | 
            +
            if "seed" not in st.session_state:
         | 
| 54 | 
            +
                st.session_state.seed = 0
         | 
| 55 | 
             
            if "temperature" not in st.session_state:
         | 
| 56 | 
             
                st.session_state.temperature = temperature
         | 
| 57 | 
             
            if "next_prompts" not in st.session_state:
         | 
|  | |
| 274 | 
             
                            num_turns=st.session_state.num_turns,
         | 
| 275 | 
             
                            temperature=st.session_state.temperature,
         | 
| 276 | 
             
                            max_tokens=st.session_state.max_tokens,
         | 
| 277 | 
            +
                            seed=st.session_state.seed,
         | 
| 278 | 
            +
                            secrets=st.session_state.secrets,
         | 
| 279 | 
             
                            verbose=st.session_state.verbose,
         | 
| 280 | 
             
                        )
         | 
| 281 | 
             
                    chunk = next(st.session_state.generator)
         | 
    	
        cli.py
    CHANGED
    
    | @@ -1,4 +1,5 @@ | |
| 1 | 
             
            import argparse
         | 
|  | |
| 2 | 
             
            import time
         | 
| 3 |  | 
| 4 | 
             
            from src.open_strawberry import get_defaults, manage_conversation
         | 
| @@ -54,6 +55,7 @@ def go_cli(): | |
| 54 | 
             
                                                temperature=args.temperature,
         | 
| 55 | 
             
                                                max_tokens=args.max_tokens,
         | 
| 56 | 
             
                                                seed=args.seed,
         | 
|  | |
| 57 | 
             
                                                cli_mode=True)
         | 
| 58 | 
             
                response = ''
         | 
| 59 | 
             
                conversation_history = []
         | 
|  | |
| 1 | 
             
            import argparse
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
             
            import time
         | 
| 4 |  | 
| 5 | 
             
            from src.open_strawberry import get_defaults, manage_conversation
         | 
|  | |
| 55 | 
             
                                                temperature=args.temperature,
         | 
| 56 | 
             
                                                max_tokens=args.max_tokens,
         | 
| 57 | 
             
                                                seed=args.seed,
         | 
| 58 | 
            +
                                                secrets=dict(os.environ),
         | 
| 59 | 
             
                                                cli_mode=True)
         | 
| 60 | 
             
                response = ''
         | 
| 61 | 
             
                conversation_history = []
         | 
    	
        models.py
    CHANGED
    
    | @@ -25,6 +25,7 @@ def get_anthropic(model: str, | |
| 25 | 
             
                              max_tokens: int = 4096,
         | 
| 26 | 
             
                              system: str = '',
         | 
| 27 | 
             
                              chat_history: List[Dict] = None,
         | 
|  | |
| 28 | 
             
                              verbose=False) -> \
         | 
| 29 | 
             
                    Generator[dict, None, None]:
         | 
| 30 | 
             
                model = model.replace('anthropic:', '')
         | 
| @@ -32,7 +33,7 @@ def get_anthropic(model: str, | |
| 32 | 
             
                # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
         | 
| 33 | 
             
                import anthropic
         | 
| 34 |  | 
| 35 | 
            -
                clawd_key =  | 
| 36 | 
             
                clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None
         | 
| 37 |  | 
| 38 | 
             
                if chat_history is None:
         | 
| @@ -118,16 +119,16 @@ def get_openai(model: str, | |
| 118 | 
             
                           max_tokens: int = 4096,
         | 
| 119 | 
             
                           system: str = '',
         | 
| 120 | 
             
                           chat_history: List[Dict] = None,
         | 
|  | |
| 121 | 
             
                           verbose=False) -> Generator[dict, None, None]:
         | 
| 122 | 
            -
                 | 
| 123 | 
            -
                if model in ollama:
         | 
| 124 | 
             
                    model = model.replace('ollama:', '')
         | 
| 125 | 
            -
                    openai_key =  | 
| 126 | 
            -
                    openai_base_url =  | 
| 127 | 
             
                else:
         | 
| 128 | 
             
                    model = model.replace('openai:', '')
         | 
| 129 | 
            -
                    openai_key =  | 
| 130 | 
            -
                    openai_base_url =  | 
| 131 |  | 
| 132 | 
             
                from openai import OpenAI
         | 
| 133 |  | 
| @@ -206,12 +207,13 @@ def get_google(model: str, | |
| 206 | 
             
                           max_tokens: int = 4096,
         | 
| 207 | 
             
                           system: str = '',
         | 
| 208 | 
             
                           chat_history: List[Dict] = None,
         | 
|  | |
| 209 | 
             
                           verbose=False) -> Generator[dict, None, None]:
         | 
| 210 | 
             
                model = model.replace('google:', '').replace('gemini:', '')
         | 
| 211 |  | 
| 212 | 
             
                import google.generativeai as genai
         | 
| 213 |  | 
| 214 | 
            -
                gemini_key =  | 
| 215 | 
             
                genai.configure(api_key=gemini_key)
         | 
| 216 | 
             
                # Create the model
         | 
| 217 | 
             
                generation_config = {
         | 
| @@ -308,12 +310,13 @@ def get_groq(model: str, | |
| 308 | 
             
                         max_tokens: int = 4096,
         | 
| 309 | 
             
                         system: str = '',
         | 
| 310 | 
             
                         chat_history: List[Dict] = None,
         | 
|  | |
| 311 | 
             
                         verbose=False) -> Generator[dict, None, None]:
         | 
| 312 | 
             
                model = model.replace('groq:', '')
         | 
| 313 |  | 
| 314 | 
             
                from groq import Groq
         | 
| 315 |  | 
| 316 | 
            -
                groq_key =  | 
| 317 | 
             
                client = Groq(api_key=groq_key)
         | 
| 318 |  | 
| 319 | 
             
                if chat_history is None:
         | 
| @@ -352,15 +355,16 @@ def get_openai_azure(model: str, | |
| 352 | 
             
                                 max_tokens: int = 4096,
         | 
| 353 | 
             
                                 system: str = '',
         | 
| 354 | 
             
                                 chat_history: List[Dict] = None,
         | 
|  | |
| 355 | 
             
                                 verbose=False) -> Generator[dict, None, None]:
         | 
| 356 | 
             
                model = model.replace('azure:', '').replace('openai_azure:', '')
         | 
| 357 |  | 
| 358 | 
             
                from openai import AzureOpenAI
         | 
| 359 |  | 
| 360 | 
            -
                azure_endpoint =  | 
| 361 | 
            -
                azure_key =  | 
| 362 | 
            -
                azure_deployment =  | 
| 363 | 
            -
                azure_api_version =  | 
| 364 | 
             
                assert azure_endpoint is not None, "Azure OpenAI endpoint not set"
         | 
| 365 | 
             
                assert azure_key is not None, "Azure OpenAI API key not set"
         | 
| 366 | 
             
                assert azure_deployment is not None, "Azure OpenAI deployment not set"
         | 
| @@ -420,15 +424,15 @@ def get_model_names(secrets, on_hf_spaces=False): | |
| 420 | 
             
                else:
         | 
| 421 | 
             
                    anthropic_models = []
         | 
| 422 | 
             
                if secrets.get('OPENAI_API_KEY'):
         | 
| 423 | 
            -
                    if  | 
| 424 | 
            -
                        openai_models = to_list( | 
| 425 | 
             
                    else:
         | 
| 426 | 
             
                        openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
         | 
| 427 | 
             
                else:
         | 
| 428 | 
             
                    openai_models = []
         | 
| 429 | 
             
                if secrets.get('AZURE_OPENAI_API_KEY'):
         | 
| 430 | 
            -
                    if  | 
| 431 | 
            -
                        azure_models = to_list( | 
| 432 | 
             
                    else:
         | 
| 433 | 
             
                        azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
         | 
| 434 | 
             
                else:
         | 
|  | |
| 25 | 
             
                              max_tokens: int = 4096,
         | 
| 26 | 
             
                              system: str = '',
         | 
| 27 | 
             
                              chat_history: List[Dict] = None,
         | 
| 28 | 
            +
                              secrets: Dict = {},
         | 
| 29 | 
             
                              verbose=False) -> \
         | 
| 30 | 
             
                    Generator[dict, None, None]:
         | 
| 31 | 
             
                model = model.replace('anthropic:', '')
         | 
|  | |
| 33 | 
             
                # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
         | 
| 34 | 
             
                import anthropic
         | 
| 35 |  | 
| 36 | 
            +
                clawd_key = secrets.get('ANTHROPIC_API_KEY')
         | 
| 37 | 
             
                clawd_client = anthropic.Anthropic(api_key=clawd_key) if clawd_key else None
         | 
| 38 |  | 
| 39 | 
             
                if chat_history is None:
         | 
|  | |
| 119 | 
             
                           max_tokens: int = 4096,
         | 
| 120 | 
             
                           system: str = '',
         | 
| 121 | 
             
                           chat_history: List[Dict] = None,
         | 
| 122 | 
            +
                           secrets: Dict = {},
         | 
| 123 | 
             
                           verbose=False) -> Generator[dict, None, None]:
         | 
| 124 | 
            +
                if model.startswith('ollama:'):
         | 
|  | |
| 125 | 
             
                    model = model.replace('ollama:', '')
         | 
| 126 | 
            +
                    openai_key = secrets.get('OLLAMA_OPENAI_API_KEY')
         | 
| 127 | 
            +
                    openai_base_url = secrets.get('OLLAMA_OPENAI_BASE_URL', 'http://localhost:11434/v1/')
         | 
| 128 | 
             
                else:
         | 
| 129 | 
             
                    model = model.replace('openai:', '')
         | 
| 130 | 
            +
                    openai_key = secrets.get('OPENAI_API_KEY')
         | 
| 131 | 
            +
                    openai_base_url = secrets.get('OPENAI_BASE_URL', 'https://api.openai.com/v1')
         | 
| 132 |  | 
| 133 | 
             
                from openai import OpenAI
         | 
| 134 |  | 
|  | |
| 207 | 
             
                           max_tokens: int = 4096,
         | 
| 208 | 
             
                           system: str = '',
         | 
| 209 | 
             
                           chat_history: List[Dict] = None,
         | 
| 210 | 
            +
                           secrets: Dict = {},
         | 
| 211 | 
             
                           verbose=False) -> Generator[dict, None, None]:
         | 
| 212 | 
             
                model = model.replace('google:', '').replace('gemini:', '')
         | 
| 213 |  | 
| 214 | 
             
                import google.generativeai as genai
         | 
| 215 |  | 
| 216 | 
            +
                gemini_key = secrets.get("GEMINI_API_KEY")
         | 
| 217 | 
             
                genai.configure(api_key=gemini_key)
         | 
| 218 | 
             
                # Create the model
         | 
| 219 | 
             
                generation_config = {
         | 
|  | |
| 310 | 
             
                         max_tokens: int = 4096,
         | 
| 311 | 
             
                         system: str = '',
         | 
| 312 | 
             
                         chat_history: List[Dict] = None,
         | 
| 313 | 
            +
                         secrets: Dict = {},
         | 
| 314 | 
             
                         verbose=False) -> Generator[dict, None, None]:
         | 
| 315 | 
             
                model = model.replace('groq:', '')
         | 
| 316 |  | 
| 317 | 
             
                from groq import Groq
         | 
| 318 |  | 
| 319 | 
            +
                groq_key = secrets.get("GROQ_API_KEY")
         | 
| 320 | 
             
                client = Groq(api_key=groq_key)
         | 
| 321 |  | 
| 322 | 
             
                if chat_history is None:
         | 
|  | |
| 355 | 
             
                                 max_tokens: int = 4096,
         | 
| 356 | 
             
                                 system: str = '',
         | 
| 357 | 
             
                                 chat_history: List[Dict] = None,
         | 
| 358 | 
            +
                                 secrets: Dict = {},
         | 
| 359 | 
             
                                 verbose=False) -> Generator[dict, None, None]:
         | 
| 360 | 
             
                model = model.replace('azure:', '').replace('openai_azure:', '')
         | 
| 361 |  | 
| 362 | 
             
                from openai import AzureOpenAI
         | 
| 363 |  | 
| 364 | 
            +
                azure_endpoint = secrets.get("AZURE_OPENAI_ENDPOINT")  # e.g. https://project.openai.azure.com
         | 
| 365 | 
            +
                azure_key = secrets.get("AZURE_OPENAI_API_KEY")
         | 
| 366 | 
            +
                azure_deployment = secrets.get("AZURE_OPENAI_DEPLOYMENT")  # i.e. deployment name with some models deployed
         | 
| 367 | 
            +
                azure_api_version = secrets.get('AZURE_OPENAI_API_VERSION', '2024-07-01-preview')
         | 
| 368 | 
             
                assert azure_endpoint is not None, "Azure OpenAI endpoint not set"
         | 
| 369 | 
             
                assert azure_key is not None, "Azure OpenAI API key not set"
         | 
| 370 | 
             
                assert azure_deployment is not None, "Azure OpenAI deployment not set"
         | 
|  | |
| 424 | 
             
                else:
         | 
| 425 | 
             
                    anthropic_models = []
         | 
| 426 | 
             
                if secrets.get('OPENAI_API_KEY'):
         | 
| 427 | 
            +
                    if secrets.get('OPENAI_MODEL_NAME'):
         | 
| 428 | 
            +
                        openai_models = to_list(secrets.get('OPENAI_MODEL_NAME'))
         | 
| 429 | 
             
                    else:
         | 
| 430 | 
             
                        openai_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
         | 
| 431 | 
             
                else:
         | 
| 432 | 
             
                    openai_models = []
         | 
| 433 | 
             
                if secrets.get('AZURE_OPENAI_API_KEY'):
         | 
| 434 | 
            +
                    if secrets.get('AZURE_OPENAI_MODEL_NAME'):
         | 
| 435 | 
            +
                        azure_models = to_list(secrets.get('AZURE_OPENAI_MODEL_NAME'))
         | 
| 436 | 
             
                    else:
         | 
| 437 | 
             
                        azure_models = ['gpt-4o', 'gpt-4-turbo-2024-04-09', 'gpt-4o-mini']
         | 
| 438 | 
             
                else:
         | 
    	
        open_strawberry.py
    CHANGED
    
    | @@ -290,7 +290,8 @@ def manage_conversation(model: str, | |
| 290 | 
             
                                    temperature: float = 0.3,
         | 
| 291 | 
             
                                    max_tokens: int = 4096,
         | 
| 292 | 
             
                                    seed: int = 1234,
         | 
| 293 | 
            -
                                     | 
|  | |
| 294 | 
             
                                    ) -> Generator[Dict, None, list]:
         | 
| 295 | 
             
                if seed == 0:
         | 
| 296 | 
             
                    seed = random.randint(0, 1000000)
         | 
| @@ -344,7 +345,9 @@ def manage_conversation(model: str, | |
| 344 | 
             
                    thinking_time = time.time()
         | 
| 345 | 
             
                    response_text = ''
         | 
| 346 | 
             
                    for chunk in get_model_func(model, prompt, system=system, chat_history=chat_history,
         | 
| 347 | 
            -
                                                temperature=temperature, max_tokens=max_tokens, | 
|  | |
|  | |
| 348 | 
             
                        if 'text' in chunk and chunk['text']:
         | 
| 349 | 
             
                            response_text += chunk['text']
         | 
| 350 | 
             
                            yield {"role": "assistant", "content": chunk['text'], "streaming": True, "chat_history": chat_history,
         | 
|  | |
| 290 | 
             
                                    temperature: float = 0.3,
         | 
| 291 | 
             
                                    max_tokens: int = 4096,
         | 
| 292 | 
             
                                    seed: int = 1234,
         | 
| 293 | 
            +
                                    secrets: Dict = {},
         | 
| 294 | 
            +
                                    verbose: bool = False,
         | 
| 295 | 
             
                                    ) -> Generator[Dict, None, list]:
         | 
| 296 | 
             
                if seed == 0:
         | 
| 297 | 
             
                    seed = random.randint(0, 1000000)
         | 
|  | |
| 345 | 
             
                    thinking_time = time.time()
         | 
| 346 | 
             
                    response_text = ''
         | 
| 347 | 
             
                    for chunk in get_model_func(model, prompt, system=system, chat_history=chat_history,
         | 
| 348 | 
            +
                                                temperature=temperature, max_tokens=max_tokens,
         | 
| 349 | 
            +
                                                secrets=secrets,
         | 
| 350 | 
            +
                                                verbose=verbose):
         | 
| 351 | 
             
                        if 'text' in chunk and chunk['text']:
         | 
| 352 | 
             
                            response_text += chunk['text']
         | 
| 353 | 
             
                            yield {"role": "assistant", "content": chunk['text'], "streaming": True, "chat_history": chat_history,
         | 
