# Testing3 / app.py
# (Hugging Face Space file — commit bb2b425, "Create app.py", 1.89 kB)
import os
from huggingface_hub import InferenceClient
import base64
import json
class LoRAInferenceWrapper:
    """Wrap a Hugging Face InferenceClient and attach local LoRA weights to prompts.

    The LoRA file is read from disk, base64-encoded, and embedded in the
    request payload alongside the text prompt.
    """

    def __init__(self, model_id, lora_file_path):
        """
        Initialize the LoRA Inference Wrapper.

        Args:
            model_id (str): Hugging Face model ID (e.g., stabilityai/stable-diffusion-3.5-large).
            lora_file_path (str): Path to the local LoRA file.

        Raises:
            ValueError: If the HF_TOKEN environment variable is not set.
        """
        # Retrieve the API token from environment variables; fail fast if missing
        # so the error surfaces at construction time rather than on first request.
        token = os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("HF_TOKEN is not set. Add it as a secret in the Hugging Face Space.")
        self.client = InferenceClient(model_id, token=token)
        self.lora_file_path = lora_file_path

    def load_lora_weights(self):
        """Read the LoRA file and return its contents base64-encoded.

        Returns:
            str: Base64 (ASCII) encoding of the raw file bytes.

        Raises:
            FileNotFoundError: If the LoRA file does not exist at the configured path.
        """
        try:
            with open(self.lora_file_path, "rb") as f:
                lora_weights = f.read()
            return base64.b64encode(lora_weights).decode('utf-8')
        except FileNotFoundError as exc:
            # Preserve the exception type and chain the original cause instead of
            # collapsing it into a bare Exception (which hid the traceback).
            raise FileNotFoundError(
                f"LoRA file not found at path: {self.lora_file_path}"
            ) from exc

    def generate_with_lora(self, prompt):
        """Generate an image from *prompt* with the LoRA payload attached.

        Args:
            prompt (str): The text prompt describing the desired image.

        Returns:
            The result object returned by ``InferenceClient.text_to_image``.
        """
        processed_lora = self.load_lora_weights()
        # NOTE(review): the base64 weights are embedded in the *prompt* string as
        # JSON. The serverless text-to-image endpoint will most likely treat this
        # as literal prompt text rather than applying the LoRA — verify against
        # the Inference API docs; LoRA usually must be loaded server-side via the
        # model/adapter configuration, not smuggled through the prompt.
        extended_prompt = json.dumps({
            "prompt": prompt,
            "lora": processed_lora
        })
        result = self.client.text_to_image(prompt=extended_prompt)
        return result
# Example usage: build the wrapper and run a single generation request.
if __name__ == "__main__":
    MODEL_ID = "stabilityai/stable-diffusion-3.5-large"
    LORA_PATH = "./lora_file.pth"

    wrapper = LoRAInferenceWrapper(MODEL_ID, LORA_PATH)
    demo_prompt = "A futuristic city skyline with neon lights."

    # Top-level boundary: report any failure instead of crashing with a traceback.
    try:
        image = wrapper.generate_with_lora(demo_prompt)
    except Exception as err:
        print("Error:", str(err))
    else:
        print("Generated image result:", image)