# Spaces:
# Sleeping
# Sleeping
import os
from typing import Any, Dict, List, Optional

import gradio as gr
from huggingface_hub import InferenceClient
from langchain.llms.base import LLM
from pydantic import BaseModel, Field
# Hugging Face API token taken from the environment; ``None`` if the
# HUGGINGFACEHUB_API_TOKEN variable is not set.
hf_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
class KwArgsModel(BaseModel):
    """Pydantic mixin that carries a free-form ``kwargs`` mapping.

    The field is built by a default factory, so every instance starts with
    its own empty dict rather than sharing one mutable default.
    """

    # Arbitrary keyword arguments, keyed by argument name.
    kwargs: Dict[str, Any] = Field(default_factory=dict)
class CustomInferenceClient(LLM, KwArgsModel):
    """LangChain ``LLM`` backed by a Hugging Face ``InferenceClient``.

    The generation parameters stored in ``self.kwargs`` (inherited from
    ``KwArgsModel``) are forwarded to ``InferenceClient.text_generation`` on
    every call.
    """

    model_name: str
    inference_client: InferenceClient

    def __init__(
        self,
        model_name: str,
        hf_token: Optional[str],
        kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Build the underlying client and initialise the model fields.

        Args:
            model_name: Passed as ``model`` to ``InferenceClient``.
            hf_token: Passed as ``token`` to ``InferenceClient``; ``None`` is
                forwarded unchanged.
            kwargs: Default keyword arguments for ``text_generation``;
                ``None`` means no extra arguments.
        """
        inference_client = InferenceClient(model=model_name, token=hf_token)
        super().__init__(
            model_name=model_name,
            # The ``kwargs`` field is typed Dict[str, Any]; an explicit None
            # would fail pydantic validation, so normalise it to {}.
            kwargs={} if kwargs is None else kwargs,
            inference_client=inference_client,
            # ``hf_token`` is intentionally not passed: it is not a declared
            # field, and storing a secret on the model could leak it through
            # repr/serialisation. The client above already holds it.
        )

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Input text sent to ``text_generation``.
            stop: Not supported; must be ``None``.
            run_manager: Optional LangChain callback manager for this run;
                when given, each streamed token is reported to it via
                ``on_llm_new_token``.
            **kwargs: Per-call generation parameters; they override the
                defaults held in ``self.kwargs``.

        Returns:
            The generated text, joined from the streamed tokens.

        Raises:
            ValueError: If ``stop`` is provided.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        # ``stream`` is set last so a value coming from self.kwargs or the
        # per-call kwargs cannot collide with it (duplicate keyword argument).
        params = {**self.kwargs, **kwargs, "stream": True}
        pieces = []
        for token in self.inference_client.text_generation(prompt, **params):
            if run_manager is not None:
                run_manager.on_llm_new_token(token)
            pieces.append(token)
        return "".join(pieces)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain reads (as an attribute) for this LLM type."""
        return "custom"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Parameters LangChain reads (as an attribute) to identify this LLM."""
        return {"model_name": self.model_name}
# Text-generation settings (sampling enabled via do_sample). Presumably
# intended as the ``kwargs`` argument of CustomInferenceClient; confirm at
# the call site.
kwargs = dict(
    max_new_tokens=256,
    temperature=0.9,
    top_p=0.6,
    repetition_penalty=1.3,
    do_sample=True,
)