from huggingface_hub import InferenceClient
import os
import base64
import json


class LoRAInferenceWrapper:
    def __init__(self, model_id, token):
        # Inference API client bound to the target model.
        self.client = InferenceClient(model_id, token=token)

    def load_lora_weights(self):
        # Read the raw LoRA checkpoint bytes from a local file.
        lora_model_path = "./lora.model.pth"
        if os.path.exists(lora_model_path):
            print(f"Found LoRA model at: {lora_model_path}")
            with open(lora_model_path, "rb") as f:
                return f.read()
        else:
            raise FileNotFoundError(f"LoRA model not found at path: {lora_model_path}")
    def preprocess_lora_weights(self, lora_weights):
        # Base64-encode the raw bytes so they can be embedded in a JSON string.
        return base64.b64encode(lora_weights).decode("utf-8")
    def generate_with_lora(self, prompt):
        # Load and encode the LoRA weights, then bundle them with the text
        # prompt into a single JSON payload string.
        lora_weights = self.load_lora_weights()
        processed_lora = self.preprocess_lora_weights(lora_weights)
        extended_prompt = json.dumps({
            "prompt": prompt,
            "lora": processed_lora,
        })
        # Pass the combined payload as the prompt for text-to-image generation.
        result = self.client.text_to_image(prompt=extended_prompt)
        return result


model_id = "stabilityai/stable-diffusion-3.5-large"
token = "hf_YOUR_HF_API_TOKEN"
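# A minimal alternative sketch (an assumption, not part of the original):
# read the token from the environment instead of hard-coding it, e.g.
#   token = os.environ["HF_TOKEN"]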

lora_client = LoRAInferenceWrapper(model_id, token)

prompt = "The same woman, smiling at the beach."
try:
    result = lora_client.generate_with_lora(prompt)
    print("Generated image:", result)
except Exception as e:
    print("Error:", str(e))