# Import necessary libraries
from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed
import gradio as gr
import torch
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device to GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model, and move the model to the selected device
# (a 7B model in full precision needs substantial memory; on CUDA, passing
# torch_dtype=torch.bfloat16 to from_pretrained is a common way to reduce it)
tokenizer = AutoTokenizer.from_pretrained("mrm8488/falcoder-7b")
model = AutoModelForCausalLM.from_pretrained("mrm8488/falcoder-7b").to(device)

# Falcon-style tokenizers may not define a pad token; fall back to EOS so generate() has one
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
def generate_text(prompt, max_length, do_sample, temperature, top_k, top_p):
    """
    Generates a text completion for a prompt using the specified sampling parameters.

    :param prompt: Input prompt for text generation.
    :type prompt: str
    :param max_length: Maximum length of the generated text (prompt included).
    :type max_length: int
    :param do_sample: Whether to use sampling for text generation.
    :type do_sample: bool
    :param temperature: Sampling temperature for text generation.
    :type temperature: float
    :param top_k: Value for top-k sampling.
    :type top_k: int
    :param top_p: Value for top-p sampling.
    :type top_p: float
    :return: Generated text completion.
    :rtype: str
    """
    # Format prompt: prepend a newline and append a comma if the prompt contains none
    formatted_prompt = "\n" + prompt
    if ',' not in prompt:
        formatted_prompt += ','

    # Tokenize prompt and move the tensors to the device
    inputs = tokenizer(formatted_prompt, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}

    # Generate a completion with the specified parameters
    # (cast UI-provided numbers to int where generate() requires integers)
    out = model.generate(**inputs, max_length=int(max_length), do_sample=do_sample, temperature=temperature,
                         no_repeat_ngram_size=3, top_k=int(top_k), top_p=top_p)
    clean_output = tokenizer.decode(out[0], skip_special_tokens=True)

    # Log the generated completion
    logger.info("Text generated: %s", clean_output)
    return clean_output
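# Illustrative usage of generate_text (a minimal sketch; the prompt and parameter values
# below are arbitrary examples, not values prescribed by the model card):
#
#   completion = generate_text("Write a short poem about the sea,", max_length=128,
#                              do_sample=True, temperature=0.7, top_k=50, top_p=0.9)
#   print(completion)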
# Custom CSS for the Gradio interface
custom_css = """
.gradio-container {
    background-color: #0D1525;
    color: white;
}
#orange-button {
    background: #F26207 !important;
    color: white;
}
.cm-gutters {
    border: none !important;
}
"""
def post_processing(prompt, completion):
    """
    Formats a generated completion for display by re-attaching the prompt.

    :param prompt: Input prompt for text or code generation.
    :type prompt: str
    :param completion: Generated text completion.
    :type completion: str
    :return: Formatted text completion.
    :rtype: str
    """
    return prompt + completion
def code_generation(prompt, max_new_tokens, temperature=0.2, seed=42, top_p=0.9, top_k=None, use_cache=True, repetition_penalty=1.0):
    """
    Generates a code completion for a prompt using the specified sampling parameters.

    :param prompt: Input prompt for code generation.
    :type prompt: str
    :param max_new_tokens: Maximum number of tokens to generate.
    :type max_new_tokens: int
    :param temperature: Sampling temperature for code generation.
    :type temperature: float
    :param seed: Random seed for code generation.
    :type seed: int
    :param top_p: Value for top-p sampling.
    :type top_p: float
    :param top_k: Value for top-k sampling.
    :type top_k: int
    :param use_cache: Whether to use the key/value cache during generation.
    :type use_cache: bool
    :param repetition_penalty: Value for the repetition penalty.
    :type repetition_penalty: float
    :return: Generated code completion.
    :rtype: str
    """
    # Roughly pre-truncate overly long prompts at the character level; the tokenizer call
    # below enforces the actual token-level limit via truncation
    MAX_INPUT_TOKENS = 2048
    if len(prompt) > MAX_INPUT_TOKENS:
        prompt = prompt[-MAX_INPUT_TOKENS:]

    # Tokenize prompt and move it to the device
    x = tokenizer.encode(prompt, return_tensors="pt", max_length=MAX_INPUT_TOKENS, truncation=True).to(device)
    logger.info("Prompt shape: %s", x.shape)

    # Generate a code completion with the specified parameters
    set_seed(int(seed))
    y = model.generate(x,
                       max_new_tokens=int(max_new_tokens),
                       do_sample=True,  # sampling must be enabled for temperature/top-p/top-k to take effect
                       temperature=temperature,
                       pad_token_id=tokenizer.pad_token_id,
                       eos_token_id=tokenizer.eos_token_id,
                       top_p=top_p,
                       top_k=int(top_k) if top_k is not None else None,
                       use_cache=use_cache,
                       repetition_penalty=repetition_penalty
                       )
    completion = tokenizer.decode(y[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Strip the echoed prompt from the decoded output, then re-attach it for display
    completion = completion[len(prompt):]
    return post_processing(prompt, completion)
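# Illustrative usage of code_generation (a minimal sketch; the prompt and parameter values
# below are arbitrary examples, not values prescribed by the model card):
#
#   result = code_generation("def fibonacci(n):", max_new_tokens=64, temperature=0.2, seed=42)
#   print(result)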
description = """
### Falcoder

Falcoder is a Falcon-7B model fine-tuned for code. It can be used for generating code (and general text) completions given a prompt.

### Text Generation

Use the text generation tab to generate text completions given a prompt. You can adjust the maximum length of the generated text, whether to use sampling, the sampling temperature, and the top-k and top-p values for sampling.

### Code Generation

Use the code generation tab to generate code completions given a prompt. You can adjust the maximum number of tokens to generate, the sampling temperature, the random seed, the top-p and top-k values for sampling, whether to use the cache, and the repetition penalty.
"""
# Define the Gradio interface: one tab per generation function, since the two take different inputs
text_tab = gr.Interface(
    generate_text,
    inputs=[gr.Textbox(lines=4, label="Prompt"), gr.Slider(16, 1024, value=256, step=1, label="Max length"),
            gr.Checkbox(value=True, label="Do sample"), gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),
            gr.Slider(0, 100, value=50, step=1, label="Top-k"), gr.Slider(0.0, 1.0, value=0.9, label="Top-p")],
    outputs=gr.Textbox(label="Completion"),
    description=description,
)
code_tab = gr.Interface(
    code_generation,
    inputs=[gr.Textbox(lines=4, label="Prompt"), gr.Slider(8, 512, value=128, step=1, label="Max new tokens"),
            gr.Slider(0.1, 2.0, value=0.2, label="Temperature"), gr.Number(value=42, precision=0, label="Seed"),
            gr.Slider(0.0, 1.0, value=0.9, label="Top-p"), gr.Slider(0, 100, value=50, step=1, label="Top-k"),
            gr.Checkbox(value=True, label="Use cache"), gr.Slider(1.0, 2.0, value=1.0, label="Repetition penalty")],
    outputs=gr.Textbox(label="Completion"),
    description=description,
)
demo = gr.TabbedInterface([text_tab, code_tab], tab_names=["Text Generation", "Code Generation"],
                          title="Falcoder", css=custom_css)

# Launch the Gradio interface
demo.launch()