Various fixes

This commit is contained in:
elijah 2024-07-19 03:12:21 +02:00
parent cd5ad0b61a
commit 7f0ee0e8c9
4 changed files with 397 additions and 279 deletions

63
app.py
View File

@ -1,57 +1,69 @@
from flask import Flask, request, jsonify, render_template, Response, stream_with_context
from transformers import AutoTokenizer
import os import os
import requests from flask import Flask, request, jsonify, render_template, Response, stream_with_context
from flask_limiter import Limiter from flask_limiter import Limiter
from flask_limiter.util import get_remote_address from flask_limiter.util import get_remote_address
from transformers import AutoTokenizer
import requests
import logging
app = Flask(__name__) app = Flask(__name__)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize rate limiter # Initialize rate limiter
limiter = Limiter( limiter = Limiter(
get_remote_address, get_remote_address,
app=app, app=app,
storage_uri="memory://" storage_uri="memory://",
default_limits=[os.getenv('RATE_LIMIT', '15 per minute')]
) )
# Load the tokenizer # Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(os.environ.get('TOKENIZER', 'gpt2')) tokenizer = AutoTokenizer.from_pretrained(os.getenv('TOKENIZER', 'gpt2'))
api_url = os.environ.get('API_URL', 'https://api.openai.com/v1') # API configuration
api_key = os.environ.get('API_KEY') API_URL = os.getenv('API_URL', 'https://api.openai.com/v1')
api_model = os.environ.get('API_MODEL', 'gpt-3.5-turbo') API_KEY = os.getenv('API_KEY')
temperature = int(os.environ.get('TEMPERATURE', 0)) API_MODEL = os.getenv('API_MODEL', 'gpt-3.5-turbo')
TEMPERATURE = float(os.getenv('TEMPERATURE', 0))
logger.info(f"Chat initialized using endpoint: {API_URL}, model: {API_MODEL}, temperature: {TEMPERATURE}")
@app.route('/v1/tokenizer/count', methods=['POST']) @app.route('/v1/tokenizer/count', methods=['POST'])
def token_count(): def token_count():
try:
data = request.json data = request.json
messages = data.get('messages', []) messages = data.get('messages', [])
full_text = " ".join([f"{msg['role']}: {msg['content']}" for msg in messages]) full_text = " ".join([f"{msg['role']}: {msg['content']}" for msg in messages])
tokens = tokenizer.encode(full_text) tokens = tokenizer.encode(full_text)
token_count = len(tokens) return jsonify({"token_count": len(tokens)})
return jsonify({"token_count": token_count}) except Exception as e:
logger.error(f"Error in token_count: {str(e)}")
return jsonify({"error": "Invalid request"}), 400
@app.route('/v1/chat/completions', methods=['POST']) @app.route('/v1/chat/completions', methods=['POST'])
@limiter.limit(os.environ.get('RATE_LIMIT', '20/minute')) @limiter.limit(os.getenv('RATE_LIMIT', '15/minute'))
def proxy_chat_completions(): def proxy_chat_completions():
headers = { headers = {
'Authorization': f'Bearer {api_key}', 'Authorization': f'Bearer {API_KEY}',
'Content-Type': 'application/json' 'Content-Type': 'application/json'
} }
try:
request_data = request.json request_data = request.json
request_data['model'] = API_MODEL
request_data['model'] = api_model request_data['temperature'] = TEMPERATURE
request_data['temperature'] = temperature
request_data['stream'] = True request_data['stream'] = True
response = requests.post(f"{api_url}/chat/completions", response = requests.post(f"{API_URL}/chat/completions",
json=request_data, json=request_data,
headers=headers, headers=headers,
stream=True) stream=True)
# Stream the response back to the client response.raise_for_status()
def generate(): def generate():
for chunk in response.iter_content(chunk_size=8): for chunk in response.iter_content(chunk_size=8):
if chunk: if chunk:
@ -60,6 +72,13 @@ def proxy_chat_completions():
return Response(stream_with_context(generate()), return Response(stream_with_context(generate()),
content_type=response.headers['content-type']) content_type=response.headers['content-type'])
except requests.RequestException as e:
logger.error(f"API request failed: {str(e)}")
return jsonify({"error": "Failed to connect to the API"}), 503
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
return jsonify({"error": "An unexpected error occurred"}), 500
@app.route('/') @app.route('/')
def index(): def index():
return render_template('index.html') return render_template('index.html')
@ -68,5 +87,9 @@ def index():
def serve_static(filename): def serve_static(filename):
return app.send_static_file(filename) return app.send_static_file(filename)
@app.errorhandler(429)
def ratelimit_handler(e):
    """Flask handler for HTTP 429 (raised by flask-limiter when the rate limit is hit).

    Returns a JSON error payload instead of the default HTML error page,
    so the frontend toast can display the message directly.
    """
    return jsonify({"error": "Rate limit exceeded. Please try again later."}), 429
if __name__ == '__main__': if __name__ == '__main__':
app.run(debug=False, port=5000) app.run(debug=False, port=int(os.getenv('PORT', 5000)))

View File

@ -366,3 +366,51 @@ p {
transform: rotate(90deg); transform: rotate(90deg);
width: 2rem; width: 2rem;
} }
/* Floating error toast shown by the frontend when a request fails.
   Fixed to the bottom-right corner; clicking it dismisses the message. */
.error-toast {
position: fixed;
bottom: 20px;
right: 20px;
background-color: var(--red-color);
color: var(--foreground-color);
padding: 1rem;
border-radius: 10px;
max-width: 300px;
box-shadow: 0 8px 15px rgba(0, 0, 0, 0.2);
z-index: 1000; /* keep the toast above the chat UI */
cursor: pointer; /* affordance: the toast is click-to-dismiss */
transition: all 0.3s ease;
}
/* Lift the toast slightly on hover to signal it is interactive. */
.error-toast:hover {
transform: translateY(-5px);
box-shadow: 0 12px 20px rgba(0, 0, 0, 0.3);
}
/* On narrow screens, stretch the toast across the full width. */
@media (max-width: 640px) {
.error-toast {
left: 20px;
right: 20px;
max-width: none;
}
}
/* Horizontal shake used to draw attention when the toast appears
   (applied from the template via element.style.animation). */
@keyframes shake {
0%,
100% {
transform: translateX(0);
}
10%,
30%,
50%,
70%,
90% {
transform: translateX(-5px);
}
20%,
40%,
60%,
80% {
transform: translateX(5px);
}
}

View File

@ -12,7 +12,7 @@ document.addEventListener("alpine:init", () => {
home: 0, home: 0,
generating: false, generating: false,
endpoint: window.location.origin + "/v1", endpoint: window.location.origin + "/v1",
model: "llama3-8b-8192", // This doesen't matter anymore as the backend handles it now model: "llama3-8b-8192", // This doesn't matter anymore as the backend handles it now
stopToken: "<|eot_id|>", // We may need this for some models stopToken: "<|eot_id|>", // We may need this for some models
// performance tracking // performance tracking
@ -20,6 +20,9 @@ document.addEventListener("alpine:init", () => {
tokens_per_second: 0, tokens_per_second: 0,
total_tokens: 0, total_tokens: 0,
// New property for error messages
errorMessage: null,
removeHistory(cstate) { removeHistory(cstate) {
const index = this.histories.findIndex((state) => { const index = this.histories.findIndex((state) => {
return state.time === cstate.time; return state.time === cstate.time;
@ -37,6 +40,7 @@ document.addEventListener("alpine:init", () => {
if (this.generating) return; if (this.generating) return;
this.generating = true; this.generating = true;
this.errorMessage = null; // Clear any previous error messages
if (this.home === 0) this.home = 1; if (this.home === 0) this.home = 1;
// ensure that going back in history will go back to home // ensure that going back in history will go back to home
@ -56,6 +60,7 @@ document.addEventListener("alpine:init", () => {
let tokens = 0; let tokens = 0;
this.tokens_per_second = 0; this.tokens_per_second = 0;
try {
// start receiving server sent events // start receiving server sent events
let gottenFirstChunk = false; let gottenFirstChunk = false;
for await (const chunk of this.openaiChatCompletion( for await (const chunk of this.openaiChatCompletion(
@ -67,7 +72,8 @@ document.addEventListener("alpine:init", () => {
} }
// add chunk to the last message // add chunk to the last message
this.cstate.messages[this.cstate.messages.length - 1].content += chunk; this.cstate.messages[this.cstate.messages.length - 1].content +=
chunk;
// calculate performance tracking // calculate performance tracking
tokens += 1; tokens += 1;
@ -96,8 +102,14 @@ document.addEventListener("alpine:init", () => {
} }
// update in local storage // update in local storage
localStorage.setItem("histories", JSON.stringify(this.histories)); localStorage.setItem("histories", JSON.stringify(this.histories));
} catch (error) {
console.error("Error in handleSend:", error);
this.showError(
error.message || "An error occurred processing your request.",
);
} finally {
this.generating = false; this.generating = false;
}
}, },
async handleEnter(event) { async handleEnter(event) {
@ -108,21 +120,36 @@ document.addEventListener("alpine:init", () => {
} }
}, },
// Show a transient error toast: sets errorMessage (rendered by the
// .error-toast element in the template) and clears it automatically.
showError(message) {
this.errorMessage = message;
setTimeout(() => {
this.errorMessage = null;
}, 3000); // Hide after 3 seconds
},
updateTotalTokens(messages) { updateTotalTokens(messages) {
fetch(`${window.location.origin}/v1/tokenizer/count`, { fetch(`${window.location.origin}/v1/tokenizer/count`, {
method: "POST", method: "POST",
headers: { "Content-Type": "application/json" }, headers: { "Content-Type": "application/json" },
body: JSON.stringify({ messages }), body: JSON.stringify({ messages }),
}) })
.then((response) => response.json()) .then((response) => {
if (!response.ok) {
throw new Error("Failed to count tokens");
}
return response.json();
})
.then((data) => { .then((data) => {
this.total_tokens = data.token_count; this.total_tokens = data.token_count;
}) })
.catch(console.error); .catch((error) => {
console.error("Error updating total tokens:", error);
this.showError("Failed to update token count. Please try again.");
});
}, },
async *openaiChatCompletion(messages) { async *openaiChatCompletion(messages) {
// stream response try {
const response = await fetch(`${this.endpoint}/chat/completions`, { const response = await fetch(`${this.endpoint}/chat/completions`, {
method: "POST", method: "POST",
headers: { headers: {
@ -136,8 +163,10 @@ document.addEventListener("alpine:init", () => {
stop: [this.stopToken], stop: [this.stopToken],
}), }),
}); });
if (!response.ok) { if (!response.ok) {
throw new Error("Failed to fetch"); const errorData = await response.json();
throw new Error(errorData.error || "Failed to fetch");
} }
const reader = response.body.getReader(); const reader = response.body.getReader();
@ -146,9 +175,8 @@ document.addEventListener("alpine:init", () => {
while (true) { while (true) {
const { done, value } = await reader.read(); const { done, value } = await reader.read();
if (done) { if (done) break;
break;
}
buffer += decoder.decode(value, { stream: true }); buffer += decoder.decode(value, { stream: true });
const lines = buffer.split("\n"); const lines = buffer.split("\n");
buffer = lines.pop(); buffer = lines.pop();
@ -156,9 +184,8 @@ document.addEventListener("alpine:init", () => {
for (const line of lines) { for (const line of lines) {
if (line.startsWith("data: ")) { if (line.startsWith("data: ")) {
const data = line.slice(6); const data = line.slice(6);
if (data === "[DONE]") { if (data === "[DONE]") return;
return;
}
try { try {
const json = JSON.parse(data); const json = JSON.parse(data);
if (json.choices && json.choices[0].delta.content) { if (json.choices && json.choices[0].delta.content) {
@ -170,6 +197,14 @@ document.addEventListener("alpine:init", () => {
} }
} }
} }
} catch (error) {
console.error("Error in openaiChatCompletion:", error);
this.showError(
error.message ||
"An error occurred while communicating with the server.",
);
throw error;
}
}, },
})); }));
}); });

View File

@ -1,6 +1,6 @@
<!DOCTYPE html> <!DOCTYPE html>
<head> <head>
<title>cchat</title> <title>cchat</title>
<meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="viewport" content="width=device-width, initial-scale=1">
<link rel="icon" href="{{ url_for('static', filename='favicon.ico') }}" type="image/svg+xml"> <link rel="icon" href="{{ url_for('static', filename='favicon.ico') }}" type="image/svg+xml">
@ -30,9 +30,9 @@
<link rel="stylesheet" href="{{ url_for('static', filename='css/index.css') }}"> <link rel="stylesheet" href="{{ url_for('static', filename='css/index.css') }}">
<link rel="stylesheet" href="{{ url_for('static', filename='css/common.css') }}"> <link rel="stylesheet" href="{{ url_for('static', filename='css/common.css') }}">
</head> </head>
<body> <body>
<main x-data="state" x-init="console.log(endpoint)"> <main x-data="state" x-init="console.log(endpoint)">
<button class="new-chat-button" @click=" <button class="new-chat-button" @click="
home = 0; home = 0;
@ -172,7 +172,19 @@
</button> </button>
</div> </div>
</div> </div>
<div x-show="errorMessage"
x-transition:enter="transition ease-out duration-500"
x-transition:enter-start="opacity-0 transform translate-y-10 scale-95"
x-transition:enter-end="opacity-100 transform translate-y-0 scale-100"
x-transition:leave="transition ease-in duration-300"
x-transition:leave-start="opacity-100 transform translate-y-0 scale-100"
x-transition:leave-end="opacity-0 transform translate-y-10 scale-95"
@click="errorMessage = null"
class="error-toast"
x-init="$el.style.animation = 'shake 0.82s cubic-bezier(.36,.07,.19,.97) both'">
<div x-text="errorMessage"></div>
</div>
</main> </main>
</body> </body>
</html> </html>