gokaygokay committed
Commit b1c0860 · 1 Parent(s): eeb2755
Files changed (2):
  1. app.py +68 -41
  2. llm_inference_video.py +66 -22
app.py CHANGED
@@ -1,23 +1,73 @@
 import torch
 import gradio as gr
-from vlm_captions import VLMCaptioning
-
-# Initialize the VLMCaptioning model once at startup
-print("Initializing Video Prompt Generator...")
-vlm_captioner = VLMCaptioning()
-print("Video Prompt Generator initialized successfully!")
-
-# Import VideoLLMInferenceNode after VLMCaptioning initialization
-from llm_inference_video import VideoLLMInferenceNode
+import spaces
 
+# Create Gradio UI without loading models first
 title = """<h1 align="center">AI Video Prompt Generator</h1>
 <p align="center">Generate creative video prompts with technical specifications</p>
 <p align="center">You can use prompts with Kling, MiniMax, Hunyuan, Haiper, CogVideoX, Luma, LTX, Runway, PixVerse. </p>"""
 
-def create_video_interface():
-    # Pass the already initialized vlm_captioner to avoid serialization issues
-    llm_node = VideoLLMInferenceNode(vlm_captioner)
+# Import these at global scope but don't instantiate yet
+from vlm_captions import VLMCaptioning
+from llm_inference_video import VideoLLMInferenceNode
 
+# Global singleton instances - we'll initialize them only when needed
+vlm_captioner = None
+llm_node = None
+
+# Initialize only once on first use
+def get_vlm_captioner():
+    global vlm_captioner
+    if vlm_captioner is None:
+        print("Initializing Video Prompt Generator...")
+        vlm_captioner = VLMCaptioning()
+        print("Video Prompt Generator initialized successfully!")
+    return vlm_captioner
+
+def get_llm_node():
+    global llm_node
+    if llm_node is None:
+        llm_node = VideoLLMInferenceNode()
+    return llm_node
+
+# Wrapper functions that avoid passing the model between processes
+@spaces.GPU()
+def describe_image_wrapper(image, question="Describe this image in detail."):
+    """GPU-decorated function for image description"""
+    if image is None:
+        return "Please upload an image."
+
+    if not question or question.strip() == "":
+        question = "Describe this image in detail."
+
+    # Get the captioner inside this GPU-decorated function
+    vlm = get_vlm_captioner()
+    return vlm.describe_image(image=image, question=question)
+
+@spaces.GPU()
+def describe_video_wrapper(video, frame_interval=30):
+    """GPU-decorated function for video description"""
+    if video is None:
+        return "Please upload a video."
+
+    # Get the captioner inside this GPU-decorated function
+    vlm = get_vlm_captioner()
+    return vlm.describe_video(video_path=video, frame_interval=frame_interval)
+
+def generate_video_prompt_wrapper(
+    concept, style, camera_style, camera_direction,
+    pacing, special_effects, custom_elements,
+    provider, model, prompt_length
+):
+    """Wrapper for LLM prompt generation"""
+    node = get_llm_node()
+    return node.generate_video_prompt(
+        concept, style, camera_style, camera_direction,
+        pacing, special_effects, custom_elements,
+        provider, model, prompt_length
+    )
+
+def create_video_interface():
     with gr.Blocks(theme='bethecloud/storj_theme') as demo:
         gr.HTML(title)
 
@@ -128,7 +178,7 @@ def create_video_interface():
         provider.change(update_models, inputs=provider, outputs=model)
 
         generate_btn.click(
-            llm_node.generate_video_prompt,
+            generate_video_prompt_wrapper,
             inputs=[input_concept, style, camera_style, camera_direction, pacing, special_effects,
                     custom_elements, provider, model, prompt_length],
             outputs=output
@@ -151,45 +201,22 @@ def create_video_interface():
         analyze_video_btn = gr.Button("Analyze Video")
         video_output = gr.Textbox(label="Video Analysis", lines=10)
 
-        # Use direct function calls to avoid serialization issues
+        # Use GPU-decorated wrapper functions directly
         analyze_image_btn.click(
-            describe_image_interface,
+            describe_image_wrapper,
            inputs=[image_input, image_question],
            outputs=image_output
        )
 
        analyze_video_btn.click(
-            describe_video_interface,
+            describe_video_wrapper,
            inputs=video_input,
            outputs=video_output
        )
 
    return demo
 
-# Define these functions at the module level to avoid pickling issues
-def describe_image_interface(image, question="Describe this image in detail."):
-    """Interface function for image description"""
-    if image is None:
-        return "Please upload an image."
-
-    if not question or question.strip() == "":
-        question = "Describe this image in detail."
-
-    return vlm_captioner.describe_image(
-        image=image,
-        question=question
-    )
-
-def describe_video_interface(video, frame_interval=30):
-    """Interface function for video description"""
-    if video is None:
-        return "Please upload a video."
-
-    return vlm_captioner.describe_video(
-        video_path=video,
-        frame_interval=frame_interval
-    )
-
 if __name__ == "__main__":
     demo = create_video_interface()
-    demo.launch(share=True)
+    # Don't use share=True on Hugging Face Spaces
+    demo.launch()
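
Note: the core pattern in the new app.py is worth isolating. Module-level singletons stay None until a getter builds them on first use, and the heavy VLM work runs inside @spaces.GPU()-decorated wrappers, so the loaded model itself never has to be serialized across the process boundary that ZeroGPU introduces. Below is a minimal, self-contained sketch of that shape; HeavyModel is a hypothetical stand-in for VLMCaptioning, and the try/except shim only exists so the sketch also runs outside Spaces:

# Minimal sketch of the lazy-singleton + GPU-decorator pattern used above.
# HeavyModel is a hypothetical stand-in for VLMCaptioning.
try:
    import spaces  # provided on Hugging Face ZeroGPU Spaces
except ImportError:  # local fallback: make the decorator a no-op
    class spaces:
        @staticmethod
        def GPU(*d_args, **d_kwargs):
            return lambda fn: fn

class HeavyModel:
    """Hypothetical stand-in for a large vision-language model."""
    def describe(self, item):
        return f"description of {item}"

_model = None  # module-level singleton, created on first use

def get_model():
    global _model
    if _model is None:
        _model = HeavyModel()  # the expensive load happens exactly once
    return _model

@spaces.GPU()
def describe(item):
    # Acquire the model *inside* the decorated function: only small
    # inputs and outputs cross the GPU-worker boundary, never the model.
    return get_model().describe(item)

if __name__ == "__main__":
    print(describe("a sunset over the sea"))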
llm_inference_video.py CHANGED
@@ -7,29 +7,34 @@ import tempfile
 from PIL import Image
 from groq import Groq
 from openai import OpenAI
-from vlm_captions import VLMCaptioning
+import spaces
 
 class VideoLLMInferenceNode:
-    def __init__(self, vlm_captioner=None):
+    def __init__(self):
         """
-        Initialize the VideoLLMInferenceNode
-
-        Args:
-            vlm_captioner: The already initialized VLMCaptioning instance to use
+        Initialize the VideoLLMInferenceNode without VLM captioning dependency
         """
-        self.vlm = vlm_captioner
         self.sambanova_api_key = os.environ.get("SAMBANOVA_API_KEY", "")
         self.groq_api_key = os.environ.get("GROQ_API_KEY", "")
 
-        self.groq_client = Groq(api_key=self.groq_api_key)
-        self.sambanova_client = OpenAI(
-            api_key=self.sambanova_api_key,
-            base_url="https://api.sambanova.ai/v1",
-        )
+        # Initialize API clients if keys are available
+        if self.groq_api_key:
+            self.groq_client = Groq(api_key=self.groq_api_key)
+        else:
+            self.groq_client = None
+
+        if self.sambanova_api_key:
+            self.sambanova_client = OpenAI(
+                api_key=self.sambanova_api_key,
+                base_url="https://api.sambanova.ai/v1",
+            )
+        else:
+            self.sambanova_client = None
 
+    @spaces.GPU()
     def analyze_image(self, image_path: str, question: Optional[str] = None) -> str:
         """
-        Analyze an image using the VLM model
+        Analyze an image using VLM model directly
 
         Args:
             image_path: Path to the image file
@@ -45,14 +50,17 @@ class VideoLLMInferenceNode:
             question = "Describe this image in detail."
 
         try:
-            # Use the passed vlm_captioner instance
-            return self.vlm.describe_image(image_path, question)
+            # Import and use VLMCaptioning within this GPU-scoped function
+            from app import get_vlm_captioner
+            vlm = get_vlm_captioner()
+            return vlm.describe_image(image_path, question)
         except Exception as e:
             return f"Error analyzing image: {str(e)}"
 
+    @spaces.GPU()
     def analyze_video(self, video_path: str) -> str:
         """
-        Analyze a video using the VLM model
+        Analyze a video using VLM model directly
 
         Args:
             video_path: Path to the video file
@@ -64,8 +72,10 @@ class VideoLLMInferenceNode:
             return "Please upload a video."
 
         try:
-            # Use the passed vlm_captioner instance
-            return self.vlm.describe_video(video_path)
+            # Import and use VLMCaptioning within this GPU-scoped function
+            from app import get_vlm_captioner
+            vlm = get_vlm_captioner()
+            return vlm.describe_video(video_path)
         except Exception as e:
             return f"Error analyzing video: {str(e)}"
 
@@ -147,16 +157,36 @@ The prompt should be detailed and technical, specifically mentioning camera angl
         # Call the appropriate API based on provider
         try:
             if provider == "SambaNova":
-                return self._call_sambanova_api(system_message, user_message, model)
+                if self.sambanova_client:
+                    return self._call_sambanova_client(system_message, user_message, model)
+                else:
+                    return self._call_sambanova_api(system_message, user_message, model)
             elif provider == "Groq":
-                return self._call_groq_api(system_message, user_message, model)
+                if self.groq_client:
+                    return self._call_groq_client(system_message, user_message, model)
+                else:
+                    return self._call_groq_api(system_message, user_message, model)
             else:
                 return "Unsupported provider. Please select SambaNova or Groq."
         except Exception as e:
             return f"Error generating prompt: {str(e)}"
 
+    def _call_sambanova_client(self, system_message: str, user_message: str, model: str) -> str:
+        """Call the SambaNova API using the client library"""
+        try:
+            chat_completion = self.sambanova_client.chat.completions.create(
+                model=model,
+                messages=[
+                    {"role": "system", "content": system_message},
+                    {"role": "user", "content": user_message}
+                ]
+            )
+            return chat_completion.choices[0].message.content
+        except Exception as e:
+            return f"Error from SambaNova API: {str(e)}"
+
     def _call_sambanova_api(self, system_message: str, user_message: str, model: str) -> str:
-        """Call the SambaNova API for prompt generation"""
+        """Call the SambaNova API using direct HTTP requests"""
         if not self.sambanova_api_key:
             return "SambaNova API key not configured. Please set the SAMBANOVA_API_KEY environment variable."
 
@@ -182,8 +212,22 @@ The prompt should be detailed and technical, specifically mentioning camera angl
         else:
             return f"Error from SambaNova API: {response.status_code} - {response.text}"
 
+    def _call_groq_client(self, system_message: str, user_message: str, model: str) -> str:
+        """Call the Groq API using the client library"""
+        try:
+            chat_completion = self.groq_client.chat.completions.create(
+                model=model,
+                messages=[
+                    {"role": "system", "content": system_message},
+                    {"role": "user", "content": user_message}
+                ]
+            )
+            return chat_completion.choices[0].message.content
+        except Exception as e:
+            return f"Error from Groq API: {str(e)}"
+
     def _call_groq_api(self, system_message: str, user_message: str, model: str) -> str:
-        """Call the Groq API for prompt generation"""
+        """Call the Groq API using direct HTTP requests"""
         if not self.groq_api_key:
             return "Groq API key not configured. Please set the GROQ_API_KEY environment variable."
233