Fixed error handling
Browse files- L3Score.py +29 -28
    	
        L3Score.py
    CHANGED
    
    | @@ -123,35 +123,28 @@ class L3Score(evaluate.Metric): | |
| 123 |  | 
| 124 | 
             
                    # Check whether the model is available
         | 
| 125 | 
             
                    print(provider,model)
         | 
| 126 | 
            -
                     | 
| 127 | 
            -
             | 
| 128 | 
            -
             | 
| 129 | 
            -
             | 
| 130 | 
            -
             | 
| 131 | 
            -
             | 
| 132 | 
            -
                            
         | 
| 133 | 
            -
                        elif provider == "deepseek":
         | 
| 134 | 
            -
                            print("Checking DeepSeek model")
         | 
| 135 | 
            -
                            client = openai.OpenAI(api_key=api_key,base_url="https://api.deepseek.com")
         | 
| 136 | 
            -
                            model_names = [model.id for model in client.models.list()]
         | 
| 137 | 
            -
                            print(model_names)  
         | 
| 138 | 
            -
                            if model not in model_names:
         | 
| 139 | 
            -
                                raise ValueError(f"Model {model} not found for provider {provider}, available models: {model_names}")
         | 
| 140 |  | 
| 141 | 
            -
             | 
| 142 | 
            -
             | 
| 143 | 
            -
             | 
| 144 | 
            -
             | 
| 145 | 
            -
             | 
| 146 | 
            -
             | 
|  | |
| 147 |  | 
| 148 | 
            -
                     | 
| 149 | 
            -
                         | 
| 150 | 
            -
             | 
| 151 | 
            -
                         | 
| 152 | 
            -
                         | 
| 153 | 
            -
             | 
| 154 | 
            -
             | 
| 155 |  | 
| 156 | 
             
                    assert len(questions) == len(predictions) == len(references), "Questions, predictions and references must have the same length"
         | 
| 157 |  | 
| @@ -175,7 +168,15 @@ class L3Score(evaluate.Metric): | |
| 175 | 
             
                    """Returns the scores"""
         | 
| 176 |  | 
| 177 | 
             
                    # Check whether llm can be initialized
         | 
| 178 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 179 |  | 
| 180 | 
             
                    # Initialize the LLM
         | 
| 181 | 
             
                    llm = self._get_llm(model, api_key)
         | 
|  | |
| 123 |  | 
| 124 | 
             
                    # Check whether the model is available
         | 
| 125 | 
             
                    print(provider,model)
         | 
| 126 | 
            +
                    
         | 
| 127 | 
            +
                    if provider == "openai":
         | 
| 128 | 
            +
                        client = openai.OpenAI(api_key=api_key)
         | 
| 129 | 
            +
                        model_names = set([model.id for model in client.models.list()])
         | 
| 130 | 
            +
                        if model not in model_names:
         | 
| 131 | 
            +
                            raise ValueError(f"Model {model} not found for provider {provider}, available models: {model_names}")
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 132 |  | 
| 133 | 
            +
                    elif provider == "deepseek":
         | 
| 134 | 
            +
                        print("Checking DeepSeek model")
         | 
| 135 | 
            +
                        client = openai.OpenAI(api_key=api_key,base_url="https://api.deepseek.com")
         | 
| 136 | 
            +
                        model_names = [model.id for model in client.models.list()]
         | 
| 137 | 
            +
                        print(model_names)  
         | 
| 138 | 
            +
                        if model not in model_names:
         | 
| 139 | 
            +
                            raise ValueError(f"Model {model} not found for provider {provider}, available models: {model_names}")
         | 
| 140 |  | 
| 141 | 
            +
                    elif provider == "xai":
         | 
| 142 | 
            +
                        client = openai.OpenAI(api_key=api_key, base_url="https://api.xai.com")
         | 
| 143 | 
            +
                        model_names = [model.id for model in client.models.list()]
         | 
| 144 | 
            +
                        print(model_names)
         | 
| 145 | 
            +
                        if model not in model_names:
         | 
| 146 | 
            +
                            raise ValueError(f"Model {model} not found for provider {provider}, available models: {model_names}")
         | 
| 147 | 
            +
                
         | 
| 148 |  | 
| 149 | 
             
                    assert len(questions) == len(predictions) == len(references), "Questions, predictions and references must have the same length"
         | 
| 150 |  | 
|  | |
| 168 | 
             
                    """Returns the scores"""
         | 
| 169 |  | 
| 170 | 
             
                    # Check whether llm can be initialized
         | 
| 171 | 
            +
                    try:
         | 
| 172 | 
            +
                        self._verify_input(questions, predictions, references, provider, api_key, model)
         | 
| 173 | 
            +
                    except ValueError as e:
         | 
| 174 | 
            +
                        return {"error": str(e)}
         | 
| 175 | 
            +
                    except openai.AuthenticationError as e:
         | 
| 176 | 
            +
                        message = e.body["message"]
         | 
| 177 | 
            +
                        return {"error": f"Authentication failed: {message}"}
         | 
| 178 | 
            +
                    except Exception as e:
         | 
| 179 | 
            +
                        return {"error": f"An error occurred when verifying the provider/model match: {e}"}
         | 
| 180 |  | 
| 181 | 
             
                    # Initialize the LLM
         | 
| 182 | 
             
                    llm = self._get_llm(model, api_key)
         |