codewithdark committed
Commit 2e7e9a5 · verified · 1 parent: d86d806

Update app.py

Files changed (1): app.py (+33, -21)
app.py CHANGED
@@ -1,40 +1,52 @@
 import gradio as gr
 import torch
-from transformers import AutoModel, AutoTokenizer
+from transformers import AutoTokenizer, AutoModel

-# Load the local model
+# Load tokenizer and model
 model_name = "codewithdark/latent-recurrent-depth-lm"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model.to(device).eval()  # Set to evaluation mode
+model.to(device).eval()

-# Define inference function
-def chat_with_model(input_text, model_choice):
-    if model_choice == "Latent Recurrent Depth LM":
-        input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
-        with torch.no_grad():
-            output = model.generate(input_ids, max_length=512)
-        response = tokenizer.decode(output[0], skip_special_tokens=True)
-        return response
-    return "Model not available yet!"
+# Define function for inference
+def chat_with_model(input_text, num_iterations, max_tokens, temperature, top_k):
+    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

-# Create Gradio Interface
+    with torch.no_grad():
+        generated_ids = model.generate(
+            input_ids,
+            max_length=max_tokens,
+            num_iterations=num_iterations,  # Assuming the model supports it
+            temperature=temperature,
+            top_k=top_k
+        )
+
+    response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
+    return response
+
+# Gradio UI
 with gr.Blocks() as demo:
     gr.Markdown("# 🤖 Chat with Latent Recurrent Depth LM")
+
+    with gr.Row():
+        text_input = gr.Textbox(label="Enter your message")

-    model_choice = gr.Radio(
-        ["Latent Recurrent Depth LM"],  # Add more models if needed
-        label="Select Model",
-        value="Latent Recurrent Depth LM"
-    )
+    with gr.Row():
+        num_iterations = gr.Slider(1, 20, step=1, value=10, label="Number of Iterations")
+        max_tokens = gr.Slider(10, 200, step=10, value=50, label="Max Tokens")
+        temperature = gr.Slider(0.1, 1.0, step=0.1, value=0.5, label="Temperature")
+        top_k = gr.Slider(10, 100, step=10, value=50, label="Top-K Sampling")

-    text_input = gr.Textbox(label="Enter your message")
     submit_button = gr.Button("Generate Response")
     output_text = gr.Textbox(label="Model Response")

-    submit_button.click(fn=chat_with_model, inputs=[text_input, model_choice], outputs=output_text)
+    submit_button.click(
+        fn=chat_with_model,
+        inputs=[text_input, num_iterations, max_tokens, temperature, top_k],
+        outputs=output_text
+    )

-# Launch the Gradio app
+# Launch Gradio app
 if __name__ == "__main__":
     demo.launch()
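
A note on the new generate() call: it forwards num_iterations, and the inline comment already flags that as an assumption. Because trust_remote_code=True loads a custom model class, its generate() signature is not guaranteed to match the standard transformers GenerationMixin API, and an unexpected keyword would raise a TypeError at click time. A minimal defensive sketch (safe_generate is a hypothetical helper, not part of this commit) that drops unsupported kwargs:

import inspect

def safe_generate(model, input_ids, **kwargs):
    # Hypothetical helper: forward only the kwargs that the loaded
    # model's generate() signature actually accepts.
    params = inspect.signature(model.generate).parameters
    # If generate() declares **kwargs, pass everything through unchanged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return model.generate(input_ids, **kwargs)
    # Otherwise drop anything the signature does not name (e.g. num_iterations).
    supported = {k: v for k, v in kwargs.items() if k in params}
    return model.generate(input_ids, **supported)

Inside chat_with_model this would stand in for the direct model.generate(...) call, so the UI keeps working even if the remote code's signature changes.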
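Separately, in the stock Hugging Face generate() API, max_length counts the prompt tokens as well as the generated ones, so the slider's default of 50 can leave a long prompt with no room to generate anything. If the custom generate() follows that convention (an assumption worth checking against the model's remote code), reserving headroom explicitly is safer:

# Sketch, assuming max_length includes the prompt (standard transformers behavior):
# budget `max_tokens` *new* tokens on top of the prompt length.
prompt_len = input_ids.shape[1]
generated_ids = model.generate(input_ids, max_length=prompt_len + max_tokens)

Likewise, in the stock API temperature and top_k only take effect with do_sample=True; whether the custom implementation samples by default is also worth verifying.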