Commit adf7dca2 authored by Dennis Willers

Extension to include optimization methods and activation functions.

parent 54ccd927
import yaml
from datetime import datetime
from src.data.copyData import copy_images_with_exclusion
from src.data.modelData import get_test_data, get_train_data, define_augmentation_rules
from src.enum.activierungsfunktionEnum import Aktivierungsfunktion
@@ -12,7 +11,6 @@ from src.knn.defineKNN import define_model
from src.result.createExcelFile import create_excel_result, get_excel_workbook, get_excel_worksheet, save_excel, \
    average_evaluate_cross_validation
from src.result.customCallback import CustomCallback
def run_cross_validation():
@@ -23,10 +21,14 @@ def run_cross_validation():
    # Initialize the KNN model-building properties
    for optimization_method in Optimierungsverfahren:
        for activation_function_2 in Aktivierungsfunktion:
            if activation_function_2.name in config['knn']['exception_activation_funktion_1']:
                continue
            for activation_function_128 in Aktivierungsfunktion:
                if check_if_kombination_is_not_allowed(
                        optimization_method.name, activation_function_128.name, activation_function_2.name, config
                ):
                    continue
                model_evaluate_metrics = []
                training_duration_models = []
                workbook = get_excel_workbook()
@@ -34,7 +36,7 @@ def run_cross_validation():
                r = 1
                config_knn = ConfigKNN(
                    excluded_folder=Markt.Kein_Markt,
                    activation_function_1_units=activation_function_2,
                    activation_function_128_units=activation_function_128,
                    optimization_method=optimization_method
                )
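Taken together with the config values and enum members added below, these nested loops sweep 3 optimization methods x 3 allowed outer activation functions (ReLU is filtered out via exception_activation_funktion_1) x 4 inner activation functions = 36 combinations, of which the 4 entries in ignoreKnnCombinations are skipped, leaving 32 trained model configurations. A quick, purely illustrative check (not part of the commit):

# Illustrative only: count the configurations the loops above will actually train,
# using the enum members and config values introduced in this commit.
optimizers = ['SGD', 'Adam', 'Ftrl']
activations = ['ReLU', 'sigmoid', 'tanh', 'softmax']
excluded_outer = ['ReLU']                      # exception_activation_funktion_1
ignored = [['SGD', 'ReLU', 'sigmoid'], ['SGD', 'sigmoid', 'sigmoid'],
           ['Adam', 'ReLU', 'sigmoid'], ['Adam', 'sigmoid', 'sigmoid']]

count = sum(
    1
    for opt in optimizers
    for act_2 in activations if act_2 not in excluded_outer
    for act_128 in activations
    if [opt, act_128, act_2] not in ignored
)
print(count)  # 32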
@@ -98,5 +100,13 @@ def run_model(config, config_knn, worksheet, r):
    return worksheet, r, evaluate_metrics, training_duration_model


def check_if_kombination_is_not_allowed(opt, act128, act2, config):
    currentKombination = [opt, act128, act2]
    ignoreKombinationList = config['knn']['ignoreKnnCombinations']
    if currentKombination in ignoreKombinationList:
        return True
    return False
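A minimal sketch of how the new exclusion check behaves, using a hand-built config dict that mirrors one ignoreKnnCombinations entry from the YAML below (the dict is illustrative, not the project's loader):

# Illustrative config dict; the real values come from config['knn']['ignoreKnnCombinations'].
demo_config = {'knn': {'ignoreKnnCombinations': [['SGD', 'ReLU', 'sigmoid']]}}

# Argument order matches the call site above: optimizer, 128-unit activation, outer activation.
print(check_if_kombination_is_not_allowed('SGD', 'ReLU', 'sigmoid', demo_config))   # True  -> skipped
print(check_if_kombination_is_not_allowed('Adam', 'tanh', 'sigmoid', demo_config))  # False -> trained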
# entry point
run_cross_validation()
bilder:
  # original_path: "assets/Bilder/Datengrundlage-Reduziert-Test/"
  original_path: "assets/Bilder/Datengrundlage/"
  # original_path: "assets/Bilder/Datengrundlage-Augmentiert/"
  knn_path: "assets/Bilder/AktuelleTrainingsUndTestdaten/"
knn:
  epochs: 10
  exception_activation_funktion_1:
    ['ReLU']
  ignoreKnnCombinations:
    [
      ['SGD','ReLU','sigmoid'],
      ['SGD','sigmoid','sigmoid'],
      ['Adam','ReLU','sigmoid'],
      ['Adam','sigmoid','sigmoid']
    ]
result:
  plot_path: "ressources/results/plot/"
  excel_path: "ressources/results/excel/"
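A minimal loading sketch, assuming the file above is saved as config.yaml (the actual file name and path used by the project are not visible in this diff):

import yaml

with open('config.yaml', 'r') as f:    # assumed file name
    config = yaml.safe_load(f)

# Flow-style YAML sequences load as plain Python lists of strings, which is what
# the list-membership test in check_if_kombination_is_not_allowed relies on.
print(config['bilder']['original_path'])           # assets/Bilder/Datengrundlage/
print(config['knn']['ignoreKnnCombinations'][0])   # ['SGD', 'ReLU', 'sigmoid']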
@@ -4,3 +4,5 @@ from enum import Enum
class Aktivierungsfunktion(Enum):
    ReLU = 0,
    sigmoid = 1,
    tanh = 2,
    softmax = 3
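Note that the trailing commas turn most member values into one-element tuples (ReLU's value is (0,), not 0); only softmax, written without a comma, gets a plain int. The code shown in this commit only compares enum members and reads .name, so this does not affect behavior. A quick illustration (not from the repository):

from src.enum.activierungsfunktionEnum import Aktivierungsfunktion

print(Aktivierungsfunktion.ReLU.value)     # (0,) -- tuple because of the trailing comma
print(Aktivierungsfunktion.softmax.value)  # 3    -- plain int, no trailing comma
print(Aktivierungsfunktion.tanh.name)      # 'tanh' -- what get_next_layer actually passes to Keras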
@@ -3,4 +3,5 @@ from enum import Enum
class Optimierungsverfahren(Enum):
    SGD = 0,
    Adam = 1,
    Ftrl = 2
@@ -38,6 +38,10 @@ def get_next_layer(units, activation_function, flat_or_dense):
                                     kernel_initializer='he_uniform')(flat_or_dense)
    if activation_function == Aktivierungsfunktion.sigmoid:
        return tf.keras.layers.Dense(units, activation=activation_function.name)(flat_or_dense)
    if activation_function == Aktivierungsfunktion.tanh:
        return tf.keras.layers.Dense(units, activation=activation_function.name)(flat_or_dense)
    if activation_function == Aktivierungsfunktion.softmax:
        return tf.keras.layers.Dense(units, activation=activation_function.name)(flat_or_dense)
    return None
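The sigmoid, tanh, and softmax branches differ only in the enum member they test; an equivalent, more compact formulation could look like the sketch below (helper name made up, not the committed code; the ReLU branch, which is cut off above, is left out):

import tensorflow as tf
from src.enum.activierungsfunktionEnum import Aktivierungsfunktion

# Keras resolves 'sigmoid', 'tanh' and 'softmax' directly from the enum member name.
PLAIN_ACTIVATIONS = (Aktivierungsfunktion.sigmoid,
                     Aktivierungsfunktion.tanh,
                     Aktivierungsfunktion.softmax)

def dense_for_activation(units, activation_function, flat_or_dense):
    if activation_function in PLAIN_ACTIVATIONS:
        return tf.keras.layers.Dense(units, activation=activation_function.name)(flat_or_dense)
    return None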
@@ -46,5 +50,6 @@ def get_optimization_method(optimization_method):
        return tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9)
    if optimization_method == Optimierungsverfahren.Adam:
        return tf.keras.optimizers.Adam(learning_rate=0.001)
    if optimization_method == Optimierungsverfahren.Ftrl:
        return tf.keras.optimizers.Ftrl(learning_rate=0.001)
    return None
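A minimal usage sketch for the new Ftrl branch, assuming get_optimization_method lives in src.knn.defineKNN next to define_model and that Optimierungsverfahren is importable from src.enum.optimierungsverfahrenEnum (import paths, input shape, and loss are assumptions, not taken from this diff):

import tensorflow as tf
from src.enum.optimierungsverfahrenEnum import Optimierungsverfahren   # assumed module path
from src.knn.defineKNN import get_optimization_method                  # assumed module path

inputs = tf.keras.Input(shape=(128, 128, 3))        # assumed input shape
x = tf.keras.layers.Flatten()(inputs)
outputs = tf.keras.layers.Dense(1, activation='sigmoid')(x)
model = tf.keras.Model(inputs, outputs)

model.compile(optimizer=get_optimization_method(Optimierungsverfahren.Ftrl),
              loss='binary_crossentropy',           # assumed loss
              metrics=['accuracy'])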