kimhyunwoo committed (verified)
Commit d7fcb3e
1 Parent(s): ecde4ca

Update app.py

Files changed (1)
  1. app.py +108 -63
app.py CHANGED
@@ -1,64 +1,109 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ import torch  # Needed for the CUDA availability check below
+
+ # --- Model Loading (Do this only once, outside the function) ---
+
+ # Option 1: Pipeline (High-Level, Easier)
+ use_pipeline = True  # Set to False to use the manual method
+
+ if use_pipeline:
+     pipe = pipeline("text-generation", model="kakaocorp/kanana-nano-2.1b-base")
+ else:
+     # Option 2: Manual Tokenizer and Model (More Control)
+     tokenizer = AutoTokenizer.from_pretrained("kakaocorp/kanana-nano-2.1b-base")
+     model = AutoModelForCausalLM.from_pretrained("kakaocorp/kanana-nano-2.1b-base")
+     # Move model to GPU if available
+     if model.device.type != 'cuda' and torch.cuda.is_available():
+         model = model.to("cuda")
+         print("Model moved to CUDA")
+
+ # --- Generation Function ---
+
+ def generate_text(prompt, max_length=50, temperature=1.0, top_k=50, top_p=1.0, no_repeat_ngram_size=0, num_return_sequences=1):
+     """Generates text based on the given prompt and parameters."""
+
+     if use_pipeline:
+         try:
+             # Pass the raw prompt string; base checkpoints typically have no chat template
+             result = pipe(
+                 prompt,
+                 max_length=max_length,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+                 num_return_sequences=num_return_sequences,
+                 do_sample=True,  # Sampling must be on for temperature/top_k/top_p to take effect
+                 return_full_text=False,  # Important: we only want the generated text
+                 pad_token_id=pipe.tokenizer.eos_token_id,  # Prevent padding warning
+             )
+             # Pipeline returns a list of dictionaries, each with 'generated_text'
+             return "\n\n".join([res['generated_text'] for res in result])
+
+         except Exception as e:
+             return f"Error during generation: {e}"
+
+     else:  # Manual method
+         try:
+             inputs = tokenizer(prompt, return_tensors="pt")
+             # Move input tensors to the same device as the model
+             inputs = {k: v.to(model.device) for k, v in inputs.items()}
+
+             outputs = model.generate(
+                 **inputs,
+                 max_length=max_length,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+                 no_repeat_ngram_size=no_repeat_ngram_size,
+                 num_return_sequences=num_return_sequences,
+                 pad_token_id=tokenizer.eos_token_id,  # Ensure padding is correct
+                 do_sample=True,  # Ensure sampling happens
+             )
+
+             generated_texts = []
+             for i in range(outputs.shape[0]):
+                 generated_text = tokenizer.decode(outputs[i], skip_special_tokens=True)
+                 generated_texts.append(generated_text)
+
+             return "\n\n".join(generated_texts)
+         except Exception as e:
+             return f"Error during generation: {e}"
+
+
+ # --- Gradio Interface ---
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Text Generation with kakaocorp/kanana-nano-2.1b-base")
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
+             with gr.Accordion("Generation Parameters", open=False):
+                 max_length_slider = gr.Slider(label="Max Length", minimum=10, maximum=512, value=50, step=1)
+                 temperature_slider = gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, value=1.0, step=0.1)
+                 top_k_slider = gr.Slider(label="Top K", minimum=0, maximum=100, value=50, step=1)
+                 top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=1.0, step=0.05)
+                 no_repeat_ngram_size_slider = gr.Slider(label="No Repeat N-gram Size", minimum=0, maximum=10, value=0, step=1)
+                 num_return_sequences_slider = gr.Slider(label="Number of Return Sequences", minimum=1, maximum=5, value=1, step=1)
+
+             generate_button = gr.Button("Generate")
+
+         with gr.Column():
+             output_text = gr.Textbox(label="Generated Text", interactive=False)  # 'readonly' is not a valid Textbox argument
+
+     generate_button.click(
+         generate_text,
+         inputs=[
+             prompt_input,
+             max_length_slider,
+             temperature_slider,
+             top_k_slider,
+             top_p_slider,
+             no_repeat_ngram_size_slider,
+             num_return_sequences_slider,
+         ],
+         outputs=output_text,
+     )
+
+ demo.launch(share=True)
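
A quick way to sanity-check the updated app is to call its click endpoint from gradio_client once it is running. The sketch below is illustrative only: the local URL, the example prompt and parameter values, and the endpoint name /generate_text (Gradio's usual default, derived from the function name) are assumptions; check the app's "Use via API" footer for the actual name and argument order.

# Hypothetical smoke test against a locally running instance of this app.
from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # local URL printed by demo.launch()
result = client.predict(
    "Write a short poem about the sea.",  # prompt
    80,     # max_length
    0.8,    # temperature
    50,     # top_k
    0.95,   # top_p
    2,      # no_repeat_ngram_size
    1,      # num_return_sequences
    api_name="/generate_text",  # assumed default endpoint name; verify via "Use via API"
)
print(result)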