# 3Luik / huggingfaceinferenceclient.py
# Committed by K00B404
# Commit 8ae56e8 (verified): "Create huggingfaceinferenceclient.py"
import os
import base64
import requests
from typing import Dict, Any, Optional
from PIL import Image
import io
class HuggingFaceInferenceClient:
    """
    Client for calling a Hugging Face Inference Endpoint over HTTP.

    ## Core Features
    - Bearer-token authentication via the ``Authorization`` header
    - Encoding local images as base64 ``data:`` URIs for request payloads
    - Request failures returned as error dicts (``generate_image``) rather
      than raised
    - Saving base64-encoded image/video results to disk

    ## Technical Design Considerations
    - Endpoint URL and token come from constructor arguments, falling back
      to the ``HF_INFERENCE_ENDPOINT`` / ``HF_API_TOKEN`` environment
      variables
    - Type-hinted method signatures
    - ``save_generated_media`` reports problems with ``print`` and signals
      failure by returning ``None``
    """

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_token: Optional[str] = None
    ):
        """
        Initialize the Hugging Face Inference API client.

        Args:
            api_url: Inference endpoint URL. Falls back to the
                ``HF_INFERENCE_ENDPOINT`` environment variable.
            api_token: Authentication token. Falls back to the
                ``HF_API_TOKEN`` environment variable.

        Raises:
            ValueError: If the URL or the token is missing or empty after
                the environment fallback.
        """
        self.api_url = api_url or os.getenv('HF_INFERENCE_ENDPOINT')
        self.api_token = api_token or os.getenv('HF_API_TOKEN')
        if not self.api_url or not self.api_token:
            raise ValueError(
                "Missing Hugging Face Inference endpoint or API token. "
                "Please provide via parameters or environment variables."
            )

    def encode_image(
        self,
        image_path: str,
        format: str = 'JPEG'
    ) -> str:
        """
        Encode an image file as a base64 ``data:`` URI.

        Non-RGB images are converted to RGB before saving, so they can be
        written in formats such as JPEG.

        Args:
            image_path: Path to the input image.
            format: Pillow output format name (e.g. ``'JPEG'``, ``'PNG'``).
                Its lowercase form is used as the MIME subtype.

        Returns:
            A string of the form ``data:image/<format>;base64,<payload>``.

        Raises:
            ValueError: If the image cannot be opened, converted or
                re-encoded. The original exception is chained as
                ``__cause__`` so the traceback is preserved.
        """
        try:
            with Image.open(image_path) as img:
                # Ensure RGB compatibility (e.g. RGBA/P images cannot be
                # written as JPEG).
                rgb_img = img if img.mode == "RGB" else img.convert("RGB")
                img_byte_arr = io.BytesIO()
                rgb_img.save(img_byte_arr, format=format)
        except Exception as e:
            raise ValueError(f"Image encoding failed: {e}") from e
        base64_encoded = base64.b64encode(
            img_byte_arr.getvalue()
        ).decode('utf-8')
        return f"data:image/{format.lower()};base64,{base64_encoded}"

    def generate_image(
        self,
        payload: Dict[str, Any],
        *,
        timeout: Optional[float] = None
    ) -> Dict[str, Any]:
        """
        POST a generation payload to the endpoint and return its JSON reply.

        Whatever JSON the endpoint produces is returned unchanged. Despite
        the method name, that may be an image or a video.

        Args:
            payload: JSON-serialisable request body.
            timeout: Passed straight to ``requests.post``. ``None`` (the
                default, matching earlier behaviour) waits indefinitely.
                Set it to bound how long a stalled request can block.

        Returns:
            The decoded JSON response on success. On any
            ``requests.RequestException``, returns
            ``{"error": <message>, "status_code": <int or None>}`` instead
            of raising. ``status_code`` is ``None`` when no HTTP response
            was received (e.g. connection failure or timeout).
        """
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json"
        }
        # Bind explicitly so the error path can tell whether an HTTP
        # response arrived. Compare with ``is not None``, never truthiness:
        # ``bool(Response)`` is ``Response.ok``, which is False for 4xx/5xx.
        response = None
        try:
            response = requests.post(
                self.api_url,
                headers=headers,
                json=payload,
                timeout=timeout
            )
            response.raise_for_status()
            return response.json()
        except requests.RequestException as e:
            return {
                "error": f"API request failed: {e}",
                "status_code": (
                    response.status_code if response is not None else None
                )
            }

    def save_generated_media(
        self,
        response: Dict[str, Any],
        output_filename: str
    ) -> Optional[str]:
        """
        Write the media contained in a generation response to disk.

        The ``'image'`` key is checked first, then ``'video'``. The first
        one present is base64-decoded and its bytes are written to
        ``output_filename`` as-is. No format conversion is done, so the file
        extension is the caller's responsibility.

        Args:
            response: Response dict, e.g. from ``generate_image``.
            output_filename: Destination file path.

        Returns:
            The output path on success, otherwise ``None``. An ``'error'``
            key, a missing media key, or any exception raised while
            decoding or writing is printed, not raised.
        """
        media_types = {
            'image': self._save_image,
            'video': self._save_video
        }
        try:
            # Check for errors
            if 'error' in response:
                print(f"Generation Error: {response['error']}")
                return None
            # Detect media type and save
            for media_type, save_func in media_types.items():
                if media_type in response:
                    return save_func(response[media_type], output_filename)
            raise ValueError("No supported media found in response")
        except Exception as e:
            # Deliberate best-effort: failures are reported, not propagated.
            print(f"Media saving failed: {e}")
            return None

    def _save_image(
        self,
        image_data_uri: str,
        output_path: str
    ) -> str:
        """
        Save base64-encoded image data to ``output_path``.

        Args:
            image_data_uri: ``data:`` URI or bare base64 string.
            output_path: Output image file path.

        Returns:
            ``output_path``.
        """
        return self._save_base64_media(image_data_uri, output_path)

    def _save_video(
        self,
        video_data_uri: str,
        output_path: str
    ) -> str:
        """
        Save base64-encoded video data to ``output_path``.

        Args:
            video_data_uri: ``data:`` URI or bare base64 string.
            output_path: Output video file path.

        Returns:
            ``output_path``.
        """
        return self._save_base64_media(video_data_uri, output_path)

    @staticmethod
    def _save_base64_media(data: str, output_path: str) -> str:
        """
        Decode base64 media and write the raw bytes to ``output_path``.

        Accepts either a data URI (``data:<mime>;base64,<payload>``) or a
        bare base64 string.

        Returns:
            ``output_path``.

        Raises:
            binascii.Error: If the payload is not valid base64.
            OSError: If the file cannot be written.
        """
        # Everything after the first comma is the payload. With no comma,
        # the whole string is treated as bare base64.
        prefix, separator, encoded = data.partition(",")
        if not separator:
            encoded = prefix
        media_bytes = base64.b64decode(encoded)
        with open(output_path, "wb") as f:
            f.write(media_bytes)
        return output_path
def main():
    """
    Example usage: encode a local image, request a generation, save it.

    The endpoint URL and token below are placeholders and must be replaced
    with real values before running.
    """
    # Initialize client with endpoint and token
    client = HuggingFaceInferenceClient(
        api_url="https://your-endpoint.endpoints.huggingface.cloud",
        api_token="hf_your_token_here"
    )
    # Prepare generation payload (encode_image raises ValueError if
    # input_image.jpg cannot be read)
    image_generation_config = {
        "inputs": {
            "image": client.encode_image("input_image.jpg"),
            "prompt": "Enhance and expand the scene creatively"
        },
        "parameters": {
            # Configurable generation parameters
            "width": 768,
            "height": 480,
            "num_frames": 129,  # 8*16 + 1
            "num_inference_steps": 50,
            "guidance_scale": 4.0,
            "double_num_frames": True,
            "fps": 60,
            "super_resolution": True,
            "grain_amount": 12
        }
    }
    # Generate media
    generation_output = client.generate_image(image_generation_config)
    # Save generated media. NOTE(review): the bytes are written as-is, so
    # if the endpoint returns an 'image' key the .mp4 extension is wrong.
    output_filename = client.save_generated_media(
        generation_output,
        "output_media.mp4"
    )
    # save_generated_media returns None on failure; don't print a bogus path.
    if output_filename is None:
        print("No media was saved.")
    else:
        print(f"Media saved to: {output_filename}")


if __name__ == "__main__":
    main()