Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
b1c0860
1
Parent(s):
eeb2755
add init
Browse files- app.py +68 -41
- llm_inference_video.py +66 -22
app.py
CHANGED
@@ -1,23 +1,73 @@
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
-
|
4 |
-
|
5 |
-
# Initialize the VLMCaptioning model once at startup
|
6 |
-
print("Initializing Video Prompt Generator...")
|
7 |
-
vlm_captioner = VLMCaptioning()
|
8 |
-
print("Video Prompt Generator initialized successfully!")
|
9 |
-
|
10 |
-
# Import VideoLLMInferenceNode after VLMCaptioning initialization
|
11 |
-
from llm_inference_video import VideoLLMInferenceNode
|
12 |
|
|
|
13 |
title = """<h1 align="center">AI Video Prompt Generator</h1>
|
14 |
<p align="center">Generate creative video prompts with technical specifications</p>
|
15 |
<p align="center">You can use prompts with Kling, MiniMax, Hunyuan, Haiper, CogVideoX, Luma, LTX, Runway, PixVerse. </p>"""
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
with gr.Blocks(theme='bethecloud/storj_theme') as demo:
|
22 |
gr.HTML(title)
|
23 |
|
@@ -128,7 +178,7 @@ def create_video_interface():
|
|
128 |
provider.change(update_models, inputs=provider, outputs=model)
|
129 |
|
130 |
generate_btn.click(
|
131 |
-
|
132 |
inputs=[input_concept, style, camera_style, camera_direction, pacing, special_effects,
|
133 |
custom_elements, provider, model, prompt_length],
|
134 |
outputs=output
|
@@ -151,45 +201,22 @@ def create_video_interface():
|
|
151 |
analyze_video_btn = gr.Button("Analyze Video")
|
152 |
video_output = gr.Textbox(label="Video Analysis", lines=10)
|
153 |
|
154 |
-
# Use
|
155 |
analyze_image_btn.click(
|
156 |
-
|
157 |
inputs=[image_input, image_question],
|
158 |
outputs=image_output
|
159 |
)
|
160 |
|
161 |
analyze_video_btn.click(
|
162 |
-
|
163 |
inputs=video_input,
|
164 |
outputs=video_output
|
165 |
)
|
166 |
|
167 |
return demo
|
168 |
|
169 |
-
# Define these functions at the module level to avoid pickling issues
|
170 |
-
def describe_image_interface(image, question="Describe this image in detail."):
|
171 |
-
"""Interface function for image description"""
|
172 |
-
if image is None:
|
173 |
-
return "Please upload an image."
|
174 |
-
|
175 |
-
if not question or question.strip() == "":
|
176 |
-
question = "Describe this image in detail."
|
177 |
-
|
178 |
-
return vlm_captioner.describe_image(
|
179 |
-
image=image,
|
180 |
-
question=question
|
181 |
-
)
|
182 |
-
|
183 |
-
def describe_video_interface(video, frame_interval=30):
|
184 |
-
"""Interface function for video description"""
|
185 |
-
if video is None:
|
186 |
-
return "Please upload a video."
|
187 |
-
|
188 |
-
return vlm_captioner.describe_video(
|
189 |
-
video_path=video,
|
190 |
-
frame_interval=frame_interval
|
191 |
-
)
|
192 |
-
|
193 |
if __name__ == "__main__":
|
194 |
demo = create_video_interface()
|
195 |
-
|
|
|
|
1 |
import torch
|
2 |
import gradio as gr
|
3 |
+
import spaces
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
+
# Create Gradio UI without loading models first
|
6 |
title = """<h1 align="center">AI Video Prompt Generator</h1>
|
7 |
<p align="center">Generate creative video prompts with technical specifications</p>
|
8 |
<p align="center">You can use prompts with Kling, MiniMax, Hunyuan, Haiper, CogVideoX, Luma, LTX, Runway, PixVerse. </p>"""
|
9 |
|
10 |
+
# Import these at global scope but don't instantiate yet
|
11 |
+
from vlm_captions import VLMCaptioning
|
12 |
+
from llm_inference_video import VideoLLMInferenceNode
|
13 |
|
14 |
+
# Global singleton instances - we'll initialize them only when needed
|
15 |
+
vlm_captioner = None
|
16 |
+
llm_node = None
|
17 |
+
|
18 |
+
# Initialize only once on first use
|
19 |
+
def get_vlm_captioner():
|
20 |
+
global vlm_captioner
|
21 |
+
if vlm_captioner is None:
|
22 |
+
print("Initializing Video Prompt Generator...")
|
23 |
+
vlm_captioner = VLMCaptioning()
|
24 |
+
print("Video Prompt Generator initialized successfully!")
|
25 |
+
return vlm_captioner
|
26 |
+
|
27 |
+
def get_llm_node():
|
28 |
+
global llm_node
|
29 |
+
if llm_node is None:
|
30 |
+
llm_node = VideoLLMInferenceNode()
|
31 |
+
return llm_node
|
32 |
+
|
33 |
+
# Wrapper functions that avoid passing the model between processes
|
34 |
+
@spaces.GPU()
|
35 |
+
def describe_image_wrapper(image, question="Describe this image in detail."):
|
36 |
+
"""GPU-decorated function for image description"""
|
37 |
+
if image is None:
|
38 |
+
return "Please upload an image."
|
39 |
+
|
40 |
+
if not question or question.strip() == "":
|
41 |
+
question = "Describe this image in detail."
|
42 |
+
|
43 |
+
# Get the captioner inside this GPU-decorated function
|
44 |
+
vlm = get_vlm_captioner()
|
45 |
+
return vlm.describe_image(image=image, question=question)
|
46 |
+
|
47 |
+
@spaces.GPU()
|
48 |
+
def describe_video_wrapper(video, frame_interval=30):
|
49 |
+
"""GPU-decorated function for video description"""
|
50 |
+
if video is None:
|
51 |
+
return "Please upload a video."
|
52 |
+
|
53 |
+
# Get the captioner inside this GPU-decorated function
|
54 |
+
vlm = get_vlm_captioner()
|
55 |
+
return vlm.describe_video(video_path=video, frame_interval=frame_interval)
|
56 |
+
|
57 |
+
def generate_video_prompt_wrapper(
|
58 |
+
concept, style, camera_style, camera_direction,
|
59 |
+
pacing, special_effects, custom_elements,
|
60 |
+
provider, model, prompt_length
|
61 |
+
):
|
62 |
+
"""Wrapper for LLM prompt generation"""
|
63 |
+
node = get_llm_node()
|
64 |
+
return node.generate_video_prompt(
|
65 |
+
concept, style, camera_style, camera_direction,
|
66 |
+
pacing, special_effects, custom_elements,
|
67 |
+
provider, model, prompt_length
|
68 |
+
)
|
69 |
+
|
70 |
+
def create_video_interface():
|
71 |
with gr.Blocks(theme='bethecloud/storj_theme') as demo:
|
72 |
gr.HTML(title)
|
73 |
|
|
|
178 |
provider.change(update_models, inputs=provider, outputs=model)
|
179 |
|
180 |
generate_btn.click(
|
181 |
+
generate_video_prompt_wrapper,
|
182 |
inputs=[input_concept, style, camera_style, camera_direction, pacing, special_effects,
|
183 |
custom_elements, provider, model, prompt_length],
|
184 |
outputs=output
|
|
|
201 |
analyze_video_btn = gr.Button("Analyze Video")
|
202 |
video_output = gr.Textbox(label="Video Analysis", lines=10)
|
203 |
|
204 |
+
# Use GPU-decorated wrapper functions directly
|
205 |
analyze_image_btn.click(
|
206 |
+
describe_image_wrapper,
|
207 |
inputs=[image_input, image_question],
|
208 |
outputs=image_output
|
209 |
)
|
210 |
|
211 |
analyze_video_btn.click(
|
212 |
+
describe_video_wrapper,
|
213 |
inputs=video_input,
|
214 |
outputs=video_output
|
215 |
)
|
216 |
|
217 |
return demo
|
218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
if __name__ == "__main__":
|
220 |
demo = create_video_interface()
|
221 |
+
# Don't use share=True on Hugging Face Spaces
|
222 |
+
demo.launch()
|
llm_inference_video.py
CHANGED
@@ -7,29 +7,34 @@ import tempfile
|
|
7 |
from PIL import Image
|
8 |
from groq import Groq
|
9 |
from openai import OpenAI
|
10 |
-
|
11 |
|
12 |
class VideoLLMInferenceNode:
|
13 |
-
def __init__(self
|
14 |
"""
|
15 |
-
Initialize the VideoLLMInferenceNode
|
16 |
-
|
17 |
-
Args:
|
18 |
-
vlm_captioner: The already initialized VLMCaptioning instance to use
|
19 |
"""
|
20 |
-
self.vlm = vlm_captioner
|
21 |
self.sambanova_api_key = os.environ.get("SAMBANOVA_API_KEY", "")
|
22 |
self.groq_api_key = os.environ.get("GROQ_API_KEY", "")
|
23 |
|
24 |
-
|
25 |
-
self.
|
26 |
-
api_key=self.
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
|
|
|
30 |
def analyze_image(self, image_path: str, question: Optional[str] = None) -> str:
|
31 |
"""
|
32 |
-
Analyze an image using
|
33 |
|
34 |
Args:
|
35 |
image_path: Path to the image file
|
@@ -45,14 +50,17 @@ class VideoLLMInferenceNode:
|
|
45 |
question = "Describe this image in detail."
|
46 |
|
47 |
try:
|
48 |
-
#
|
49 |
-
|
|
|
|
|
50 |
except Exception as e:
|
51 |
return f"Error analyzing image: {str(e)}"
|
52 |
|
|
|
53 |
def analyze_video(self, video_path: str) -> str:
|
54 |
"""
|
55 |
-
Analyze a video using
|
56 |
|
57 |
Args:
|
58 |
video_path: Path to the video file
|
@@ -64,8 +72,10 @@ class VideoLLMInferenceNode:
|
|
64 |
return "Please upload a video."
|
65 |
|
66 |
try:
|
67 |
-
#
|
68 |
-
|
|
|
|
|
69 |
except Exception as e:
|
70 |
return f"Error analyzing video: {str(e)}"
|
71 |
|
@@ -147,16 +157,36 @@ The prompt should be detailed and technical, specifically mentioning camera angl
|
|
147 |
# Call the appropriate API based on provider
|
148 |
try:
|
149 |
if provider == "SambaNova":
|
150 |
-
|
|
|
|
|
|
|
151 |
elif provider == "Groq":
|
152 |
-
|
|
|
|
|
|
|
153 |
else:
|
154 |
return "Unsupported provider. Please select SambaNova or Groq."
|
155 |
except Exception as e:
|
156 |
return f"Error generating prompt: {str(e)}"
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
def _call_sambanova_api(self, system_message: str, user_message: str, model: str) -> str:
|
159 |
-
"""Call the SambaNova API
|
160 |
if not self.sambanova_api_key:
|
161 |
return "SambaNova API key not configured. Please set the SAMBANOVA_API_KEY environment variable."
|
162 |
|
@@ -182,8 +212,22 @@ The prompt should be detailed and technical, specifically mentioning camera angl
|
|
182 |
else:
|
183 |
return f"Error from SambaNova API: {response.status_code} - {response.text}"
|
184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
185 |
def _call_groq_api(self, system_message: str, user_message: str, model: str) -> str:
|
186 |
-
"""Call the Groq API
|
187 |
if not self.groq_api_key:
|
188 |
return "Groq API key not configured. Please set the GROQ_API_KEY environment variable."
|
189 |
|
|
|
7 |
from PIL import Image
|
8 |
from groq import Groq
|
9 |
from openai import OpenAI
|
10 |
+
import spaces
|
11 |
|
12 |
class VideoLLMInferenceNode:
|
13 |
+
def __init__(self):
|
14 |
"""
|
15 |
+
Initialize the VideoLLMInferenceNode without VLM captioning dependency
|
|
|
|
|
|
|
16 |
"""
|
|
|
17 |
self.sambanova_api_key = os.environ.get("SAMBANOVA_API_KEY", "")
|
18 |
self.groq_api_key = os.environ.get("GROQ_API_KEY", "")
|
19 |
|
20 |
+
# Initialize API clients if keys are available
|
21 |
+
if self.groq_api_key:
|
22 |
+
self.groq_client = Groq(api_key=self.groq_api_key)
|
23 |
+
else:
|
24 |
+
self.groq_client = None
|
25 |
+
|
26 |
+
if self.sambanova_api_key:
|
27 |
+
self.sambanova_client = OpenAI(
|
28 |
+
api_key=self.sambanova_api_key,
|
29 |
+
base_url="https://api.sambanova.ai/v1",
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
self.sambanova_client = None
|
33 |
|
34 |
+
@spaces.GPU()
|
35 |
def analyze_image(self, image_path: str, question: Optional[str] = None) -> str:
|
36 |
"""
|
37 |
+
Analyze an image using VLM model directly
|
38 |
|
39 |
Args:
|
40 |
image_path: Path to the image file
|
|
|
50 |
question = "Describe this image in detail."
|
51 |
|
52 |
try:
|
53 |
+
# Import and use VLMCaptioning within this GPU-scoped function
|
54 |
+
from app import get_vlm_captioner
|
55 |
+
vlm = get_vlm_captioner()
|
56 |
+
return vlm.describe_image(image_path, question)
|
57 |
except Exception as e:
|
58 |
return f"Error analyzing image: {str(e)}"
|
59 |
|
60 |
+
@spaces.GPU()
|
61 |
def analyze_video(self, video_path: str) -> str:
|
62 |
"""
|
63 |
+
Analyze a video using VLM model directly
|
64 |
|
65 |
Args:
|
66 |
video_path: Path to the video file
|
|
|
72 |
return "Please upload a video."
|
73 |
|
74 |
try:
|
75 |
+
# Import and use VLMCaptioning within this GPU-scoped function
|
76 |
+
from app import get_vlm_captioner
|
77 |
+
vlm = get_vlm_captioner()
|
78 |
+
return vlm.describe_video(video_path)
|
79 |
except Exception as e:
|
80 |
return f"Error analyzing video: {str(e)}"
|
81 |
|
|
|
157 |
# Call the appropriate API based on provider
|
158 |
try:
|
159 |
if provider == "SambaNova":
|
160 |
+
if self.sambanova_client:
|
161 |
+
return self._call_sambanova_client(system_message, user_message, model)
|
162 |
+
else:
|
163 |
+
return self._call_sambanova_api(system_message, user_message, model)
|
164 |
elif provider == "Groq":
|
165 |
+
if self.groq_client:
|
166 |
+
return self._call_groq_client(system_message, user_message, model)
|
167 |
+
else:
|
168 |
+
return self._call_groq_api(system_message, user_message, model)
|
169 |
else:
|
170 |
return "Unsupported provider. Please select SambaNova or Groq."
|
171 |
except Exception as e:
|
172 |
return f"Error generating prompt: {str(e)}"
|
173 |
|
174 |
+
def _call_sambanova_client(self, system_message: str, user_message: str, model: str) -> str:
|
175 |
+
"""Call the SambaNova API using the client library"""
|
176 |
+
try:
|
177 |
+
chat_completion = self.sambanova_client.chat.completions.create(
|
178 |
+
model=model,
|
179 |
+
messages=[
|
180 |
+
{"role": "system", "content": system_message},
|
181 |
+
{"role": "user", "content": user_message}
|
182 |
+
]
|
183 |
+
)
|
184 |
+
return chat_completion.choices[0].message.content
|
185 |
+
except Exception as e:
|
186 |
+
return f"Error from SambaNova API: {str(e)}"
|
187 |
+
|
188 |
def _call_sambanova_api(self, system_message: str, user_message: str, model: str) -> str:
|
189 |
+
"""Call the SambaNova API using direct HTTP requests"""
|
190 |
if not self.sambanova_api_key:
|
191 |
return "SambaNova API key not configured. Please set the SAMBANOVA_API_KEY environment variable."
|
192 |
|
|
|
212 |
else:
|
213 |
return f"Error from SambaNova API: {response.status_code} - {response.text}"
|
214 |
|
215 |
+
def _call_groq_client(self, system_message: str, user_message: str, model: str) -> str:
|
216 |
+
"""Call the Groq API using the client library"""
|
217 |
+
try:
|
218 |
+
chat_completion = self.groq_client.chat.completions.create(
|
219 |
+
model=model,
|
220 |
+
messages=[
|
221 |
+
{"role": "system", "content": system_message},
|
222 |
+
{"role": "user", "content": user_message}
|
223 |
+
]
|
224 |
+
)
|
225 |
+
return chat_completion.choices[0].message.content
|
226 |
+
except Exception as e:
|
227 |
+
return f"Error from Groq API: {str(e)}"
|
228 |
+
|
229 |
def _call_groq_api(self, system_message: str, user_message: str, model: str) -> str:
|
230 |
+
"""Call the Groq API using direct HTTP requests"""
|
231 |
if not self.groq_api_key:
|
232 |
return "Groq API key not configured. Please set the GROQ_API_KEY environment variable."
|
233 |
|