|
import os |
|
import base64 |
|
import requests |
|
from typing import Dict, Any, Optional |
|
from PIL import Image |
|
import io |
|
|
|
class HuggingFaceInferenceClient:
    """
    Client for a Hugging Face Inference Endpoint that generates image/video media.

    ## Core Features
    - Bearer-token authentication (token from the constructor or ``HF_API_TOKEN``)
    - Image encoding to base64 data URIs for request payloads
    - Request failures reported as ``{"error": ..., "status_code": ...}`` dicts
      instead of raised exceptions
    - Saving base64-encoded image/video results to disk

    ## Technical Design Considerations
    - Environment-based configuration (``HF_INFERENCE_ENDPOINT``, ``HF_API_TOKEN``)
    - Type-hinted method signatures
    - Every HTTP request is bounded by a timeout so a stalled endpoint cannot
      block the caller indefinitely
    - Errors while saving media are printed to stdout and reported as ``None``
    """

    # Default seconds to wait on the endpoint; generous because media generation can be slow.
    DEFAULT_TIMEOUT: float = 600.0

    # Class-level fallback so subclasses that override ``__init__`` without calling
    # ``super().__init__`` still get a bounded request timeout in ``generate_image``.
    timeout: Optional[float] = DEFAULT_TIMEOUT

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_token: Optional[str] = None,
        timeout: Optional[float] = DEFAULT_TIMEOUT,
    ):
        """
        Initialize Hugging Face Inference API client.

        Args:
            api_url (str, optional): Inference endpoint URL. Falls back to the
                ``HF_INFERENCE_ENDPOINT`` environment variable.
            api_token (str, optional): Authentication token. Falls back to the
                ``HF_API_TOKEN`` environment variable.
            timeout (float, optional): Seconds to wait for the endpoint on each
                request; ``None`` waits indefinitely.

        Raises:
            ValueError: If no endpoint URL or no API token could be resolved.
        """
        self.api_url = api_url or os.getenv('HF_INFERENCE_ENDPOINT')
        self.api_token = api_token or os.getenv('HF_API_TOKEN')
        self.timeout = timeout

        if not self.api_url or not self.api_token:
            raise ValueError(
                "Missing Hugging Face Inference endpoint or API token. "
                "Please provide via parameters or environment variables."
            )

    def encode_image(
        self,
        image_path: str,
        format: str = 'JPEG',  # shadows the builtin, but is part of the public signature
    ) -> str:
        """
        Encode an image file to a base64 data URI.

        The image is always converted to RGB before saving.

        Args:
            image_path (str): Path to input image.
            format (str): Pillow output format name (e.g. ``'JPEG'``, ``'PNG'``).

        Returns:
            str: ``data:image/<format>;base64,<payload>`` string.

        Raises:
            ValueError: If the image cannot be opened, converted or saved; the
                original exception is chained as ``__cause__``.
        """
        try:
            with Image.open(image_path) as img:
                # JPEG (the default format) cannot store alpha/palette modes.
                if img.mode != "RGB":
                    img = img.convert("RGB")
                buffer = io.BytesIO()
                img.save(buffer, format=format)
        except Exception as e:
            # Chain the cause so the real failure (missing file, bad format, ...) stays visible.
            raise ValueError(f"Image encoding failed: {e}") from e

        base64_encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f"data:image/{format.lower()};base64,{base64_encoded}"

    def generate_image(
        self,
        payload: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        POST a generation request to the endpoint and return the decoded JSON.

        Used for any media type the endpoint produces (image or video).

        Args:
            payload (Dict): Generation configuration payload, sent as JSON.

        Returns:
            Dict: The endpoint's JSON response on success. On a request, HTTP or
            JSON-decoding failure, ``{"error": <message>, "status_code": <int or None>}``.
        """
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_token}",
            "Content-Type": "application/json",
        }

        # Initialised up front so the error path can tell whether a response arrived.
        response = None
        try:
            response = requests.post(
                self.api_url,
                headers=headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers non-JSON bodies on requests versions whose
            # JSONDecodeError does not subclass RequestException.
            return {
                "error": f"API request failed: {e}",
                "status_code": response.status_code if response is not None else None,
            }

    def save_generated_media(
        self,
        response: Dict[str, Any],
        output_filename: str,
    ) -> Optional[str]:
        """
        Save generated media from an API response to ``output_filename``.

        ``'image'`` is checked before ``'video'``; the first key present is saved.

        Args:
            response (Dict): Response from :meth:`generate_image`.
            output_filename (str): Output file path.

        Returns:
            Optional[str]: Path to the saved file, or ``None`` if the response
            carried an error, had no supported media, or saving failed (the
            reason is printed to stdout).
        """
        media_types = {
            'image': self._save_image,
            'video': self._save_video,
        }

        # Deliberate best-effort boundary: any failure is reported and mapped to None.
        try:
            if 'error' in response:
                print(f"Generation Error: {response['error']}")
                return None

            for media_type, save_func in media_types.items():
                if media_type in response:
                    return save_func(response[media_type], output_filename)

            raise ValueError("No supported media found in response")

        except Exception as e:
            print(f"Media saving failed: {e}")
            return None

    @staticmethod
    def _write_base64_media(data: str, output_path: str) -> str:
        """
        Decode base64 ``data`` and write the bytes to ``output_path``.

        Accepts either a ``data:<mime>;base64,<payload>`` URI or bare base64.
        """
        # partition() tolerates a missing data-URI header, where split(",")[1]
        # raised IndexError; base64 never contains ',' so valid URIs decode the same.
        header, sep, body = data.partition(",")
        decoded = base64.b64decode(body if sep else header)

        # Decode before opening so a bad payload does not leave an empty file behind.
        with open(output_path, "wb") as f:
            f.write(decoded)

        return output_path

    def _save_image(
        self,
        image_data_uri: str,
        output_path: str,
    ) -> str:
        """
        Save base64 encoded image data.

        Args:
            image_data_uri (str): Base64 image data URI (or bare base64).
            output_path (str): Output image file path.

        Returns:
            str: Path to saved image.
        """
        return self._write_base64_media(image_data_uri, output_path)

    def _save_video(
        self,
        video_data_uri: str,
        output_path: str,
    ) -> str:
        """
        Save base64 encoded video data.

        Args:
            video_data_uri (str): Base64 video data URI (or bare base64).
            output_path (str): Output video file path.

        Returns:
            str: Path to saved video.
        """
        return self._write_base64_media(video_data_uri, output_path)
|
|
|
def main():
    """
    Run an example image-to-video generation request.

    The endpoint URL and token are read from the ``HF_INFERENCE_ENDPOINT`` and
    ``HF_API_TOKEN`` environment variables. They are not hard-coded here, so
    credentials stay out of source. The generated media is written to
    ``output_media.mp4``.
    """
    try:
        client = HuggingFaceInferenceClient()
    except ValueError as e:
        print(e)
        return

    try:
        encoded_image = client.encode_image("input_image.jpg")
    except ValueError as e:
        print(e)
        return

    image_generation_config = {
        "inputs": {
            "image": encoded_image,
            "prompt": "Enhance and expand the scene creatively",
        },
        "parameters": {
            "width": 768,
            "height": 480,
            "num_frames": 129,
            "num_inference_steps": 50,
            "guidance_scale": 4.0,
            "double_num_frames": True,
            "fps": 60,
            "super_resolution": True,
            "grain_amount": 12,
        },
    }

    generation_output = client.generate_image(image_generation_config)

    output_filename = client.save_generated_media(
        generation_output,
        "output_media.mp4",
    )

    # save_generated_media returns None on failure (reason already printed).
    if output_filename is None:
        print("No media was saved.")
        return

    print(f"Media saved to: {output_filename}")
|
|
|
# Run the example only when executed as a script, never on import.
if __name__ == "__main__":
    main()