from huggingface_hub import InferenceClient
import os
import base64
import json

class LoRAInferenceWrapper:
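    """Wrapper around huggingface_hub.InferenceClient that reads a local LoRA
    file and bundles it with the text prompt into a single JSON payload."""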
    def __init__(self, model_id, token):
        # Initialize the InferenceClient
        self.client = InferenceClient(model_id, token=token)

    def load_lora_weights(self):
        # Define the path to the LoRA model
        lora_model_path = "./lora_file.pth"  # Assuming the file is saved locally

        # Check if the file exists
        if not os.path.exists(lora_model_path):
            raise FileNotFoundError(f"LoRA file not found at path: {lora_model_path}")

        # Load the LoRA weights from the local file
        with open(lora_model_path, "rb") as f:
            return f.read()  # Return the raw bytes of the LoRA file

    def preprocess_lora_weights(self, lora_weights):
        # Preprocess the LoRA weights (e.g., Base64 encoding for JSON compatibility)
        return base64.b64encode(lora_weights).decode("utf-8")

    def generate_with_lora(self, prompt):
        # Load and preprocess the LoRA weights
        lora_weights = self.load_lora_weights()
        processed_lora = self.preprocess_lora_weights(lora_weights)

        # Combine the prompt and LoRA data as a single input
        extended_prompt = json.dumps({
            "prompt": prompt,
            "lora": processed_lora
        })

        # Call the text-to-image endpoint. The serialized JSON is sent as the
        # prompt string itself, so the receiving endpoint must know how to
        # parse this payload and apply the LoRA weights.
        result = self.client.text_to_image(prompt=extended_prompt)
        return result

# Example usage
model_id = "stabilityai/stable-diffusion-3.5-large"
token = "hf_YOUR_HF_API_TOKEN"  # Replace with your Hugging Face token

# Initialize the wrapper
lora_client = LoRAInferenceWrapper(model_id, token)

# Generate an image with the LoRA file applied
prompt = "The same woman, smiling at the beach."
try:
    result = lora_client.generate_with_lora(prompt)
    print("Generated image:", result)
except Exception as e:
    print("Error:", str(e))