import requests
from pydantic.v1 import SecretStr
from typing_extensions import override

from langflow.base.models.groq_constants import GROQ_MODELS
from langflow.base.models.model import LCModelComponent
from langflow.field_typing import LanguageModel
from langflow.inputs.inputs import HandleInput
from langflow.io import DropdownInput, FloatInput, IntInput, MessageTextInput, SecretStrInput

# Public Groq endpoint used whenever the "Groq API Base" field is left blank.
DEFAULT_GROQ_API_BASE = "https://api.groq.com"


class GroqModel(LCModelComponent):
    """Langflow component that builds a LangChain ``ChatGroq`` chat model."""

    display_name: str = "Groq"
    description: str = "Generate text using Groq."
    icon = "Groq"
    name = "GroqModel"

    inputs = [
        *LCModelComponent._base_inputs,
        SecretStrInput(name="groq_api_key", display_name="Groq API Key", info="API key for the Groq API."),
        MessageTextInput(
            name="groq_api_base",
            display_name="Groq API Base",
            info="Base URL path for API requests, leave blank if not using a proxy or service emulator.",
            advanced=True,
            value=DEFAULT_GROQ_API_BASE,
        ),
        IntInput(
            name="max_tokens",
            display_name="Max Output Tokens",
            info="The maximum number of tokens to generate.",
            advanced=True,
        ),
        FloatInput(
            name="temperature",
            display_name="Temperature",
            info="Run inference with this temperature. Must be in the closed interval [0.0, 1.0].",
            value=0.1,
        ),
        IntInput(
            name="n",
            display_name="N",
            info="Number of chat completions to generate for each prompt. "
            "Note that the API may not return the full n completions if duplicates are generated.",
            advanced=True,
        ),
        DropdownInput(
            name="model_name",
            display_name="Model",
            info="The name of the model to use.",
            options=GROQ_MODELS,
            value="llama-3.1-8b-instant",
            refresh_button=True,
        ),
        HandleInput(
            name="output_parser",
            display_name="Output Parser",
            info="The parser to use to parse the output of the model",
            advanced=True,
            input_types=["OutputParser"],
        ),
    ]

    def _resolve_api_base(self) -> str:
        """Return the configured API base URL, or ``DEFAULT_GROQ_API_BASE`` when blank.

        Surrounding whitespace and trailing slashes are removed so that API paths
        such as ``/openai/v1/models`` can be appended without producing ``//``.
        Both ``get_models`` and ``build_model`` use this so they always agree on
        which endpoint is contacted.
        """
        base_url = str(self.groq_api_base or "").strip().rstrip("/")
        return base_url or DEFAULT_GROQ_API_BASE

    def get_models(self) -> list[str]:
        """Fetch the IDs of the models available to the configured API key.

        Sends ``GET {api_base}/openai/v1/models`` with the key as a Bearer token
        and collects the ``id`` of every entry in the response's ``data`` list.

        Falls back to a copy of the static ``GROQ_MODELS`` list when:
          * no API key is set (no request is made),
          * the request or JSON decoding fails (the error is written to
            ``self.status``),
          * the response contains no usable model IDs.

        Returns:
            A new list of model ID strings; never the shared ``GROQ_MODELS`` object.
        """
        api_key = self.groq_api_key
        if not api_key:
            # Without a key the request can only be rejected; skip the round trip
            # instead of sending "Bearer None".
            return list(GROQ_MODELS)

        url = f"{self._resolve_api_base()}/openai/v1/models"
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            payload = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decode errors on older ``requests`` releases,
            # where they are not a RequestException subclass.
            self.status = f"Error fetching models: {e}"
            return list(GROQ_MODELS)

        # Validate the payload shape rather than assuming it, so a malformed
        # response degrades to the static list instead of raising.
        entries = payload.get("data") if isinstance(payload, dict) else None
        if not isinstance(entries, list):
            entries = []
        model_ids = [entry["id"] for entry in entries if isinstance(entry, dict) and "id" in entry]
        # An empty dropdown is unusable, so prefer the static list in that case.
        return model_ids or list(GROQ_MODELS)

    @override
    def update_build_config(self, build_config: dict, field_value: str, field_name: str | None = None) -> dict:
        """Refresh the model dropdown options when a relevant field is updated.

        When ``field_name`` is the API key, the API base, or ``model_name``, the
        model list is re-fetched via ``get_models`` and written into
        ``build_config["model_name"]["options"]``. Other fields leave
        ``build_config`` untouched.

        Returns:
            The (possibly updated) ``build_config``.
        """
        if field_name in {"groq_api_key", "groq_api_base", "model_name"}:
            models = self.get_models()
            build_config["model_name"]["options"] = models
        return build_config

    def build_model(self) -> LanguageModel:  # type: ignore[type-var]
        """Construct a ``langchain_groq.ChatGroq`` from the component's inputs.

        ``max_tokens`` is passed as ``None`` when unset/zero and ``n`` defaults
        to 1 when unset/zero.

        Raises:
            ImportError: if the ``langchain-groq`` package is not installed.
        """
        try:
            from langchain_groq import ChatGroq
        except ImportError as e:
            msg = "langchain-groq is not installed. Please install it with `pip install langchain-groq`."
            raise ImportError(msg) from e

        return ChatGroq(
            model=self.model_name,
            max_tokens=self.max_tokens or None,
            temperature=self.temperature,
            # Resolve exactly like get_models: passing a blank field straight
            # through would hand the client an empty base URL instead of the
            # documented default endpoint.
            base_url=self._resolve_api_base(),
            n=self.n or 1,
            api_key=SecretStr(self.groq_api_key).get_secret_value(),
            streaming=self.stream,
        )