DonImages committed · verified
Commit dd28fa5 · 1 parent: bb2b425

Update app.py

Files changed (1):
  1. app.py (+37 -35)
app.py CHANGED
@@ -1,53 +1,55 @@
-import os
 from huggingface_hub import InferenceClient
+import os
 import base64
 import json
 
 class LoRAInferenceWrapper:
-    def __init__(self, model_id, lora_file_path):
-        """
-        Initialize the LoRA Inference Wrapper.
-
-        Args:
-            model_id (str): Hugging Face model ID (e.g., stabilityai/stable-diffusion-3.5-large).
-            lora_file_path (str): Path to the local LoRA file.
-        """
-        # Retrieve the API token from environment variables
-        token = os.getenv("HF_TOKEN")
-        if not token:
-            raise ValueError("HF_TOKEN is not set. Add it as a secret in the Hugging Face Space.")
+    def __init__(self, model_id, token):
+        # Initialize the InferenceClient
         self.client = InferenceClient(model_id, token=token)
-        self.lora_file_path = lora_file_path
 
     def load_lora_weights(self):
-        # Load LoRA weights (same as before)
-        try:
-            with open(self.lora_file_path, "rb") as f:
-                lora_weights = f.read()
-            return base64.b64encode(lora_weights).decode('utf-8')
-        except FileNotFoundError:
-            raise Exception(f"LoRA file not found at path: {self.lora_file_path}")
+        # Define the path to the LoRA model
+        lora_model_path = "./lora_file.pth"  # Assuming the file is saved locally
+
+        # Check if the file exists
+        if not os.path.exists(lora_model_path):
+            raise FileNotFoundError(f"LoRA file not found at path: {lora_model_path}")
+
+        # Load the LoRA weights from the local file
+        with open(lora_model_path, "rb") as f:
+            return f.read()  # Return the raw bytes of the LoRA file
+
+    def preprocess_lora_weights(self, lora_weights):
+        # Preprocess the LoRA weights (e.g., Base64 encoding for JSON compatibility)
+        return base64.b64encode(lora_weights).decode("utf-8")
 
     def generate_with_lora(self, prompt):
-        # Generate image (same as before)
-        processed_lora = self.load_lora_weights()
+        # Load and preprocess the LoRA weights
+        lora_weights = self.load_lora_weights()
+        processed_lora = self.preprocess_lora_weights(lora_weights)
+
+        # Combine the prompt and LoRA data as a single input
         extended_prompt = json.dumps({
             "prompt": prompt,
             "lora": processed_lora
         })
+
+        # Generate the output using the InferenceClient
         result = self.client.text_to_image(prompt=extended_prompt)
         return result
 
 # Example usage
-if __name__ == "__main__":
-    model_id = "stabilityai/stable-diffusion-3.5-large"
-    lora_file_path = "./lora_file.pth"
-
-    lora_client = LoRAInferenceWrapper(model_id, lora_file_path)
-
-    prompt = "A futuristic city skyline with neon lights."
-    try:
-        result = lora_client.generate_with_lora(prompt)
-        print("Generated image result:", result)
-    except Exception as e:
-        print("Error:", str(e))
+model_id = "stabilityai/stable-diffusion-3.5-large"
+token = "hf_YOUR_HF_API_TOKEN"  # Replace with your Hugging Face token
+
+# Initialize the wrapper
+lora_client = LoRAInferenceWrapper(model_id, token)
+
+# Generate an image with the LoRA file applied
+prompt = "The same woman, smiling at the beach."
+try:
+    result = lora_client.generate_with_lora(prompt)
+    print("Generated image:", result)
+except Exception as e:
+    print("Error:", str(e))