|
import io
import os
import re

from fastapi import FastAPI, File, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from groq import Groq
from pydantic import BaseModel
|
|
|
|
|
# Read the Groq API key once at import time.
# The previous self-assignment into os.environ was a no-op when the variable
# was set and raised an opaque TypeError (os.environ only accepts str values)
# when it was not; fail with an explicit message instead.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if GROQ_API_KEY is None:
    raise RuntimeError("GROQ_API_KEY environment variable is not set")

# Shared Groq SDK client used by every endpoint in this module.
client = Groq(api_key=GROQ_API_KEY)

# FastAPI application object the route decorators below register against.
app = FastAPI()
|
|
|
|
|
class TranscriptionResponse(BaseModel):
    # Payload shape serialized by the /transcribe endpoint.

    # Text produced by the speech-to-text call.
    transcription: str
|
|
|
class FeedbackResponse(BaseModel):
    # Payload shape serialized by the /check_grammar endpoint.

    # Raw model replies, returned to the caller verbatim (stripped).
    grammar_feedback: str
    vocabulary_feedback: str

    # Integer scores parsed out of the corresponding feedback text by
    # calculate_score(); 0 means no score marker was found.
    grammar_score: int
    vocabulary_score: int
|
|
|
@app.get("/")
async def index():
    """Root endpoint: return a static welcome message."""
    welcome = {"message": "Welcome to the Audio Transcription API!"}
    return welcome
|
|
|
@app.post("/transcribe")
async def transcribe_audio(audio_data: bytes = File(...), language: str = Form(...)):
    """Transcribe an uploaded audio file with Groq's Whisper model.

    Args:
        audio_data: Raw bytes of the uploaded audio (multipart file field).
        language: Language value forwarded unchanged to the transcription API
            (multipart form field).

    Returns:
        JSONResponse whose body is a serialized ``TranscriptionResponse``.

    Raises:
        HTTPException: 400 when the upload is empty; 500 (with the underlying
            error message as detail) when the transcription call fails.
    """
    # Reject empty uploads up front instead of surfacing the upstream API
    # error as a 500.
    if not audio_data:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")

    try:
        # The Groq client is synchronous; run it in the threadpool so the
        # event loop is not blocked for the duration of the network call.
        transcription = await run_in_threadpool(
            client.audio.transcriptions.create,
            file=("audio.wav", audio_data),
            model="whisper-large-v3",
            prompt="Transcribe the audio accurately based on the selected language.",
            response_format="text",
            language=language,
        )

        # NOTE(review): with response_format="text" the SDK is expected to
        # return a plain string; fall back to a ``.text`` attribute (or str())
        # in case an object is returned instead — confirm against the SDK.
        if isinstance(transcription, str):
            text = transcription
        else:
            text = getattr(transcription, "text", None)
            if not isinstance(text, str):
                text = str(transcription)

        return JSONResponse(content=TranscriptionResponse(transcription=text).dict())
    except Exception as e:
        # Top-level boundary: report any failure as a 500, keeping the cause
        # chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
def _ask_model(model: str, prompt: str) -> str:
    """Send ``prompt`` as a single user message to ``model`` and return the stripped reply."""
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    # Guard against a missing/None content so .strip() cannot raise.
    return (response.choices[0].message.content or "").strip()


@app.post("/check_grammar")
async def check_grammar(transcription: str = Form(...), language: str = Form(...)):
    """Ask two chat models for grammar and vocabulary feedback on a text.

    Args:
        transcription: Text to evaluate (form field).
        language: Language the text is supposed to be written in (form field).

    Returns:
        JSONResponse whose body is a serialized ``FeedbackResponse``: both
        feedback texts plus the integer scores parsed from them by
        ``calculate_score`` (0 when no score marker is found).

    Raises:
        HTTPException: 400 when either field is empty; 500 (with the
            underlying error message as detail) when a model call or the
            response construction fails.
    """
    if not transcription or not language:
        raise HTTPException(status_code=400, detail="Missing transcription or language selection")

    grammar_prompt = (
        f"Briefly check the grammar of the following text in {language}: {transcription}. "
        "Identify any word that does not belong to the selected language and flag it. Based on the number of incorrect words, "
        "also check the grammar deeply and carefully. Provide a score from 1 to 10 based on the grammar accuracy, reducing points for incorrect words, "
        "and make sure to output the score on a new line after two line breaks like \"SCORE=\"."
    )

    # NOTE(review): this prompt still says "check the grammar deeply" — it
    # looks copied from the grammar prompt. Left unchanged because rewording
    # alters model output; confirm the intended wording.
    vocabulary_prompt = (
        f"Check the vocabulary accuracy of the following text in {language}: {transcription}. "
        "Identify any word that does not belong to the selected language and flag it. Based on the number of incorrect words, "
        "also check the grammar deeply and carefully. Provide a score from 1 to 10 based on the vocabulary accuracy, reducing points for incorrect words, "
        "and make sure to output the score on a new line after two line breaks like \"SCORE=\"."
    )

    try:
        # The Groq client is synchronous; run each request in the threadpool
        # so the event loop stays responsive. The calls remain sequential, so
        # a failed grammar request still prevents the vocabulary request.
        # NOTE(review): confirm both model IDs are still served by Groq.
        grammar_feedback = await run_in_threadpool(
            _ask_model, "llama3-groq-70b-8192-tool-use-preview", grammar_prompt
        )
        vocabulary_feedback = await run_in_threadpool(
            _ask_model, "llama-3.1-70b-versatile", vocabulary_prompt
        )

        grammar_score = calculate_score(grammar_feedback)
        vocabulary_score = calculate_score(vocabulary_feedback)

        return JSONResponse(content=FeedbackResponse(
            grammar_feedback=grammar_feedback,
            vocabulary_feedback=vocabulary_feedback,
            grammar_score=grammar_score,
            vocabulary_score=vocabulary_score,
        ).dict())
    except Exception as e:
        # Top-level boundary: report any failure as a 500, keeping the cause
        # chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
# "score" in any letter case, optional whitespace, "=" or ":", optional
# whitespace, then the digits. Compiled once at import time.
_SCORE_PATTERN = re.compile(r"score\s*[=:]\s*(\d+)", re.IGNORECASE)


def calculate_score(feedback) -> int:
    """Extract the numeric score from a model feedback text.

    Looks for the first ``score`` label (any letter case) followed by ``=`` or
    ``:`` and an integer, with optional whitespace around the separator —
    e.g. ``SCORE=8``, ``Score: 7``, ``score = 9/10``.

    Args:
        feedback: Feedback text to scan; ``None`` or an empty string is
            treated as "no score present".

    Returns:
        The integer following the first score label, or 0 when no label with
        a number is found.
    """
    if not feedback:
        return 0

    match = _SCORE_PATTERN.search(feedback)
    return int(match.group(1)) if match else 0
|
|