DonImages commited on
Commit
bb2b425
·
verified ·
1 Parent(s): 18da866

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from huggingface_hub import InferenceClient
3
+ import base64
4
+ import json
5
+
6
+ class LoRAInferenceWrapper:
7
+ def __init__(self, model_id, lora_file_path):
8
+ """
9
+ Initialize the LoRA Inference Wrapper.
10
+
11
+ Args:
12
+ model_id (str): Hugging Face model ID (e.g., stabilityai/stable-diffusion-3.5-large).
13
+ lora_file_path (str): Path to the local LoRA file.
14
+ """
15
+ # Retrieve the API token from environment variables
16
+ token = os.getenv("HF_TOKEN")
17
+ if not token:
18
+ raise ValueError("HF_TOKEN is not set. Add it as a secret in the Hugging Face Space.")
19
+ self.client = InferenceClient(model_id, token=token)
20
+ self.lora_file_path = lora_file_path
21
+
22
+ def load_lora_weights(self):
23
+ # Load LoRA weights (same as before)
24
+ try:
25
+ with open(self.lora_file_path, "rb") as f:
26
+ lora_weights = f.read()
27
+ return base64.b64encode(lora_weights).decode('utf-8')
28
+ except FileNotFoundError:
29
+ raise Exception(f"LoRA file not found at path: {self.lora_file_path}")
30
+
31
+ def generate_with_lora(self, prompt):
32
+ # Generate image (same as before)
33
+ processed_lora = self.load_lora_weights()
34
+ extended_prompt = json.dumps({
35
+ "prompt": prompt,
36
+ "lora": processed_lora
37
+ })
38
+ result = self.client.text_to_image(prompt=extended_prompt)
39
+ return result
40
+
41
+ # Example usage
42
+ if __name__ == "__main__":
43
+ model_id = "stabilityai/stable-diffusion-3.5-large"
44
+ lora_file_path = "./lora_file.pth"
45
+
46
+ lora_client = LoRAInferenceWrapper(model_id, lora_file_path)
47
+
48
+ prompt = "A futuristic city skyline with neon lights."
49
+ try:
50
+ result = lora_client.generate_with_lora(prompt)
51
+ print("Generated image result:", result)
52
+ except Exception as e:
53
+ print("Error:", str(e))