# Import necessary libraries
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import gradio as gr
import torch
import logging
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Set device to GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("mrm8488/falcoder-7b")
model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b")
model.to(device)  # the input tensors are moved to `device` below, so the model must live there too
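# Memory note: falcoder-7b has ~7B parameters, so full fp32 weights need on the order
# of 28 GB. A hedged alternative for smaller GPUs (torch_dtype is a standard
# from_pretrained argument; whether fp16 is enough for this Space is an assumption):
#   model = AutoModelForCausalLM.from_pretrained(
#       "mrm8488/falcoder-7b", torch_dtype=torch.float16
#   ).to(device)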
def generate_text(prompt, max_length, do_sample, temperature, top_k, top_p):
"""
Generates text completion given a prompt and specified parameters.
:param prompt: Input prompt for text generation.
:type prompt: str
:param max_length: Maximum length of generated text.
:type max_length: int
:param do_sample: Whether to use sampling for text generation.
:type do_sample: bool
:param temperature: Sampling temperature for text generation.
:type temperature: float
:param top_k: Value for top-k sampling.
:type top_k: int
:param top_p: Value for top-p sampling.
:type top_p: float
:return: Generated text completion.
:rtype: str
"""
# Format prompt
formatted_prompt = "\n" + prompt
if not ',' in prompt:
formatted_prompt += ','
# Tokenize prompt and move to device
prompt = tokenizer(formatted_prompt, return_tensors='pt')
prompt = {key: value.to(device) for key, value in prompt.items()}
# Generate text completion using model and specified parameters
out = model.generate(**prompt, max_length=max_length, do_sample=do_sample, temperature=temperature,
no_repeat_ngram_size=3, top_k=top_k, top_p=top_p)
output = tokenizer.decode(out[0])
clean_output = output.replace('\n', '\n')
# Log generated text completion
logger.info("Text generated: %s", clean_output)
return clean_output
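
# A quick (hypothetical) smoke test outside the UI; the prompt and settings below are
# illustrative, not values shipped with the Space:
#   print(generate_text("Write a function that reverses a string,", max_length=64,
#                       do_sample=True, temperature=0.7, top_k=50, top_p=0.9))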
# Define Gradio interface
custom_css = """
.gradio-container {
background-color: #0D1525;
color:white
}
#orange-button {
background: #F26207 !important;
color: white;
}
.cm-gutters{
border: none !important;
}
"""
def post_processing(prompt, completion):
"""
Formats generated text completion for display.
:param prompt: Input prompt for text generation.
:type prompt: str
:param completion: Generated text completion.
:type completion: str
:return: Formatted text completion.
:rtype: str
"""
return prompt + completion
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
"""
Generates code completion given a prompt and specified parameters.
:param prompt: Input prompt for code generation.
:type prompt: str
:param max_new_tokens: Maximum number of tokens to generate.
:type max_new_tokens: int
:param temperature: Sampling temperature for code generation.
:type temperature: float
:param seed: Random seed for code generation.
:type seed: int
:param top_p: Value for top-p sampling.
:type top_p: float
:param top_k: Value for top-k sampling.
:type top_k: int
:param use_cache: Whether to use cache for code generation.
:type use_cache: bool
:param repetition_penalty: Value for repetition penalty.
:type repetition_penalty: float
:return: Generated code completion.
:rtype: str
"""
# Truncate prompt if too long
MAX_INPUT_TOKENS = 2048
if len(prompt) > MAX_INPUT_TOKENS:
prompt = prompt[-MAX_INPUT_TOKENS:]
# Tokenize prompt and move to device
x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
logger.info("Prompt shape: %s", x.shape)
# Generate code completion using model and specified parameters
set_seed(seed)
y = model.generate(x,
max_new_tokens=max_new_tokens,
temperature=temperature,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
top_p=top_p,
top_k=top_k,
use_cache=use_cache,
repetition_penalty=repetition_penalty
)
completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
completion = completion[len(prompt):]
return post_processing(prompt, completion)
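
# A quick (hypothetical) sanity check outside the UI; the prompt and token budget are illustrative:
#   print(code_generation("def quicksort(arr):", max_new_tokens=64))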
description = """
### Falcoder
Falcoder is a Falcon-7B model fine-tuned on code. It can be used to generate code completions from a prompt.
### Text Generation
Use the text generation section to generate text completions given a prompt. You can adjust the maximum length of the generated text, whether to use sampling, the sampling temperature, and the top-k and top-p values for sampling.
### Code Generation
Use the code generation section to generate code completions given a prompt. You can adjust the maximum number of tokens to generate, the sampling temperature, the random seed, the top-p and top-k values for sampling, whether to use cache, and the repetition penalty.
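For example, a prompt like `def fibonacci(n):` with a small token budget should yield a plausible function body (an illustrative prompt, not one shipped with this Space).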
"""
# Build one tabbed demo with a sub-interface per task, so each function gets input
# components matching its signature (slider ranges and defaults are reasonable assumptions)
text_demo = gr.Interface(
    generate_text,
    inputs=["textbox",
            gr.Slider(16, 512, value=128, step=1, label="Max length"),
            gr.Checkbox(value=True, label="Do sample"),
            gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
            gr.Slider(0, 100, value=50, step=1, label="Top-k"),
            gr.Slider(0.0, 1.0, value=0.9, label="Top-p")],
    outputs="textbox",
    description=description,
)
code_demo = gr.Interface(
    # expose only the prompt and token budget; the other parameters keep their defaults
    lambda prompt, max_new_tokens: code_generation(prompt, max_new_tokens),
    inputs=["textbox", gr.Slider(16, 512, value=128, step=1, label="Max new tokens")],
    outputs="textbox",
    description=description,
)
demo = gr.TabbedInterface([text_demo, code_demo], ["Text Generation", "Code Generation"],
                          title="Falcoder", css=custom_css)
# Launch Gradio interface
demo.launch()