|
import os |
|
from huggingface_hub import InferenceClient |
|
import base64 |
|
import json |
|
|
|
class LoRAInferenceWrapper:
    """Wrapper around ``huggingface_hub.InferenceClient`` that attaches a
    local LoRA weight file (base64-encoded) to text-to-image requests."""

    def __init__(self, model_id, lora_file_path):
        """
        Initialize the LoRA Inference Wrapper.

        Args:
            model_id (str): Hugging Face model ID (e.g., stabilityai/stable-diffusion-3.5-large).
            lora_file_path (str): Path to the local LoRA file.

        Raises:
            ValueError: If the HF_TOKEN environment variable is not set.
        """
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("HF_TOKEN is not set. Add it as a secret in the Hugging Face Space.")
        self.client = InferenceClient(model_id, token=token)
        self.lora_file_path = lora_file_path

    def load_lora_weights(self):
        """
        Read the LoRA file from disk and return it base64-encoded.

        Returns:
            str: The file contents encoded as a base64 ASCII string.

        Raises:
            FileNotFoundError: If no file exists at ``self.lora_file_path``.
        """
        try:
            with open(self.lora_file_path, "rb") as f:
                lora_weights = f.read()
        except FileNotFoundError as err:
            # Preserve the specific exception type instead of masking it
            # behind a generic Exception; chain the original for context.
            raise FileNotFoundError(
                f"LoRA file not found at path: {self.lora_file_path}"
            ) from err
        return base64.b64encode(lora_weights).decode('utf-8')

    def generate_with_lora(self, prompt):
        """
        Generate an image, embedding the base64 LoRA payload in the prompt.

        NOTE(review): the JSON string (prompt + base64 weights) is passed as
        the plain ``prompt`` argument of ``text_to_image``; nothing here makes
        the endpoint actually load the LoRA weights server-side — confirm
        against the Inference API documentation before relying on this.

        Args:
            prompt (str): Text prompt describing the desired image.

        Returns:
            The result of ``InferenceClient.text_to_image`` (a PIL image per
            the huggingface_hub API).
        """
        processed_lora = self.load_lora_weights()
        extended_prompt = json.dumps({
            "prompt": prompt,
            "lora": processed_lora
        })
        result = self.client.text_to_image(prompt=extended_prompt)
        return result
|
|
|
|
|
if __name__ == "__main__":
    # Demo: generate one image with the LoRA-augmented prompt.
    demo_model = "stabilityai/stable-diffusion-3.5-large"
    demo_lora_path = "./lora_file.pth"

    # Construction validates that HF_TOKEN is present in the environment.
    wrapper = LoRAInferenceWrapper(demo_model, demo_lora_path)

    demo_prompt = "A futuristic city skyline with neon lights."
    try:
        image = wrapper.generate_with_lora(demo_prompt)
    except Exception as exc:
        # Top-level boundary: report the failure instead of crashing.
        print("Error:", str(exc))
    else:
        print("Generated image result:", image)