gokaygokay committed on
Commit
1e26e1c
·
1 Parent(s): 4269cb2

fix models

Browse files
Files changed (2) hide show
  1. app.py +6 -13
  2. llm_inference_video.py +3 -12
app.py CHANGED
@@ -46,18 +46,17 @@ def create_video_interface():
46
  custom_elements = gr.Textbox(label="Custom Elements", scale=3)
47
  with gr.Column(scale=1):
48
  provider = gr.Dropdown(
49
- choices=["Hugging Face", "Groq", "SambaNova"],
50
- value="Hugging Face",
51
  label="LLM Provider"
52
  )
53
  model = gr.Dropdown(
54
  choices=[
55
- "Qwen/Qwen2.5-72B-Instruct",
56
- "meta-llama/Meta-Llama-3.1-70B-Instruct",
57
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
58
- "mistralai/Mistral-7B-Instruct-v0.3"
59
  ],
60
- value="Qwen/Qwen2.5-72B-Instruct",
61
  label="Model"
62
  )
63
 
@@ -66,12 +65,6 @@ def create_video_interface():
66
 
67
  def update_models(provider):
68
  models = {
69
- "Hugging Face": [
70
- "Qwen/Qwen2.5-72B-Instruct",
71
- "meta-llama/Meta-Llama-3.1-70B-Instruct",
72
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
73
- "mistralai/Mistral-7B-Instruct-v0.3"
74
- ],
75
  "Groq": ["llama-3.3-70b-versatile"],
76
  "SambaNova": [
77
  "Meta-Llama-3.1-70B-Instruct",
 
46
  custom_elements = gr.Textbox(label="Custom Elements", scale=3)
47
  with gr.Column(scale=1):
48
  provider = gr.Dropdown(
49
+ choices=["SambaNova", "Groq"],
50
+ value="SambaNova",
51
  label="LLM Provider"
52
  )
53
  model = gr.Dropdown(
54
  choices=[
55
+ "Meta-Llama-3.1-70B-Instruct",
56
+ "Meta-Llama-3.1-405B-Instruct",
57
+ "Meta-Llama-3.1-8B-Instruct"
 
58
  ],
59
+ value="Meta-Llama-3.1-70B-Instruct",
60
  label="Model"
61
  )
62
 
 
65
 
66
  def update_models(provider):
67
  models = {
 
 
 
 
 
 
68
  "Groq": ["llama-3.3-70b-versatile"],
69
  "SambaNova": [
70
  "Meta-Llama-3.1-70B-Instruct",
llm_inference_video.py CHANGED
@@ -6,14 +6,9 @@ from gradio_client import Client
6
 
7
  class VideoLLMInferenceNode:
8
  def __init__(self):
9
- self.huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
10
  self.groq_api_key = os.getenv("GROQ_API_KEY")
11
  self.sambanova_api_key = os.getenv("SAMBANOVA_API_KEY")
12
 
13
- self.huggingface_client = OpenAI(
14
- base_url="https://api-inference.huggingface.co/v1/",
15
- api_key=self.huggingface_token,
16
- )
17
  self.groq_client = Groq(api_key=self.groq_api_key)
18
  self.sambanova_client = OpenAI(
19
  api_key=self.sambanova_api_key,
@@ -28,8 +23,7 @@ class VideoLLMInferenceNode:
28
  pacing,
29
  special_effects,
30
  custom_elements,
31
- provider="Hugging Face",
32
-
33
  model=None
34
  ):
35
  try:
@@ -57,13 +51,10 @@ class VideoLLMInferenceNode:
57
  Avoid technical specifications or shot-by-shot breakdowns. Instead, create a flowing, descriptive narrative that captures the essence of the video."""
58
 
59
  # Select provider
60
- if provider == "Hugging Face":
61
- client = self.huggingface_client
62
- model = model or "meta-llama/Meta-Llama-3.1-70B-Instruct"
63
- elif provider == "Groq":
64
  client = self.groq_client
65
  model = model or "llama-3.3-70b-versatile"
66
- elif provider == "SambaNova":
67
  client = self.sambanova_client
68
  model = model or "Meta-Llama-3.1-70B-Instruct"
69
 
 
6
 
7
  class VideoLLMInferenceNode:
8
  def __init__(self):
 
9
  self.groq_api_key = os.getenv("GROQ_API_KEY")
10
  self.sambanova_api_key = os.getenv("SAMBANOVA_API_KEY")
11
 
 
 
 
 
12
  self.groq_client = Groq(api_key=self.groq_api_key)
13
  self.sambanova_client = OpenAI(
14
  api_key=self.sambanova_api_key,
 
23
  pacing,
24
  special_effects,
25
  custom_elements,
26
+ provider="SambaNova",
 
27
  model=None
28
  ):
29
  try:
 
51
  Avoid technical specifications or shot-by-shot breakdowns. Instead, create a flowing, descriptive narrative that captures the essence of the video."""
52
 
53
  # Select provider
54
+ if provider == "Groq":
 
 
 
55
  client = self.groq_client
56
  model = model or "llama-3.3-70b-versatile"
57
+ else: # SambaNova as default
58
  client = self.sambanova_client
59
  model = model or "Meta-Llama-3.1-70B-Instruct"
60